|
| 1 | +import copy |
1 | 2 | import os
|
2 | 3 | from typing import Callable, Optional, Union
|
3 | 4 |
|
@@ -74,6 +75,24 @@ def new_from_pretrained(
|
74 | 75 | subfolder = kwargs.pop("subfolder", "")
|
75 | 76 | commit_hash = kwargs.pop("_commit_hash", None)
|
76 | 77 | variant = kwargs.pop("variant", None)
|
| 78 | + |
| 79 | + kwargs.pop("state_dict", None) |
| 80 | + kwargs.pop("from_tf", False) |
| 81 | + kwargs.pop("from_flax", False) |
| 82 | + kwargs.pop("output_loading_info", False) |
| 83 | + kwargs.pop("trust_remote_code", None) |
| 84 | + kwargs.pop("low_cpu_mem_usage", None) |
| 85 | + kwargs.pop("device_map", None) |
| 86 | + kwargs.pop("max_memory", None) |
| 87 | + kwargs.pop("offload_folder", None) |
| 88 | + kwargs.pop("offload_state_dict", False) |
| 89 | + kwargs.pop("load_in_8bit", False) |
| 90 | + kwargs.pop("load_in_4bit", False) |
| 91 | + kwargs.pop("quantization_config", None) |
| 92 | + kwargs.pop("adapter_kwargs", {}) |
| 93 | + kwargs.pop("adapter_name", "default") |
| 94 | + kwargs.pop("use_flash_attention_2", False) |
| 95 | + |
77 | 96 | use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
78 | 97 |
|
79 | 98 | if len(kwargs) > 0:
|
@@ -108,6 +127,10 @@ def new_from_pretrained(
|
108 | 127 | **kwargs,
|
109 | 128 | )
|
110 | 129 | else:
|
| 130 | + config = copy.deepcopy(config) |
| 131 | + kwarg_attn_imp = kwargs.pop("attn_implementation", None) |
| 132 | + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: |
| 133 | + config._attn_implementation = kwarg_attn_imp |
111 | 134 | model_kwargs = kwargs
|
112 | 135 |
|
113 | 136 | if commit_hash is None:
|
|
0 commit comments