Skip to content

Commit 9d83c6d

Browse files
authored
[lazy] fix lazy cls init (#5720)
* fix * fix * fix * fix * fix * remove kernel install * rebase revert fix * fix * fix
1 parent 2011b13 commit 9d83c6d

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

.github/workflows/build_on_pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ jobs:
140140
141141
- name: Install Colossal-AI
142142
run: |
143-
BUILD_EXT=1 pip install -v -e .
143+
pip install -v -e .
144144
pip install -r requirements/requirements-test.txt
145145
146146
- name: Store Colossal-AI Cache

colossalai/lazy/pretrained.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import os
23
from typing import Callable, Optional, Union
34

@@ -74,6 +75,24 @@ def new_from_pretrained(
7475
subfolder = kwargs.pop("subfolder", "")
7576
commit_hash = kwargs.pop("_commit_hash", None)
7677
variant = kwargs.pop("variant", None)
78+
79+
kwargs.pop("state_dict", None)
80+
kwargs.pop("from_tf", False)
81+
kwargs.pop("from_flax", False)
82+
kwargs.pop("output_loading_info", False)
83+
kwargs.pop("trust_remote_code", None)
84+
kwargs.pop("low_cpu_mem_usage", None)
85+
kwargs.pop("device_map", None)
86+
kwargs.pop("max_memory", None)
87+
kwargs.pop("offload_folder", None)
88+
kwargs.pop("offload_state_dict", False)
89+
kwargs.pop("load_in_8bit", False)
90+
kwargs.pop("load_in_4bit", False)
91+
kwargs.pop("quantization_config", None)
92+
kwargs.pop("adapter_kwargs", {})
93+
kwargs.pop("adapter_name", "default")
94+
kwargs.pop("use_flash_attention_2", False)
95+
7796
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
7897

7998
if len(kwargs) > 0:
@@ -108,6 +127,10 @@ def new_from_pretrained(
108127
**kwargs,
109128
)
110129
else:
130+
config = copy.deepcopy(config)
131+
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
132+
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
133+
config._attn_implementation = kwarg_attn_imp
111134
model_kwargs = kwargs
112135

113136
if commit_hash is None:

0 commit comments

Comments (0)