Commit ebecbb9

eljandoubi authored and BernardZach committed
Skip DeepSpeed ZeRO Stage 3 model initialization when bnb (huggingface#34395)
* Skip DeepSpeed ZeRO Stage 3 model initialization when it is intended to be quantized.
* Propagate the quantization state using a context manager
* make fixup
1 parent 3618189 commit ebecbb9

File tree

1 file changed: +15 −1 lines changed


src/transformers/modeling_utils.py

Lines changed: 15 additions & 1 deletion

@@ -136,6 +136,7 @@
 
 
 _init_weights = True
+_is_quantized = False
 
 
 def is_fsdp_enabled():
@@ -213,6 +214,16 @@ def _skip_init(*args, **kwargs):
                setattr(torch.nn.init, name, init_func)
 
 
+@contextmanager
+def set_quantized_state():
+    global _is_quantized
+    _is_quantized = True
+    try:
+        yield
+    finally:
+        _is_quantized = False
+
+
 def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
@@ -1531,7 +1542,7 @@ def _from_config(cls, config, **kwargs):
             torch_dtype=torch_dtype,
         )
 
-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not _is_quantized:
             import deepspeed
 
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
@@ -4086,6 +4097,9 @@ def from_pretrained(
                )
            init_contexts.append(init_empty_weights())
 
+        if is_deepspeed_zero3_enabled() and is_quantized:
+            init_contexts.append(set_quantized_state())
+
        config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
        if not getattr(config, "_attn_implementation_autoset", False):
            config = cls._autoset_attn_implementation(
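
What follows is a minimal, self-contained sketch of the pattern this commit introduces, runnable on its own: a module-level flag, a context manager that flips it for the duration of model construction, and an init_contexts list that is entered as a group. Here build_submodel and the stubbed is_deepspeed_zero3_enabled are illustrative stand-ins, not the library code; set_quantized_state and the init_contexts wiring mirror the diff above, and ExitStack stands in for the helper transformers uses to enter the contexts together.

from contextlib import ExitStack, contextmanager

_is_quantized = False  # module-level flag, as in modeling_utils.py


def is_deepspeed_zero3_enabled():
    return True  # stub: pretend a ZeRO-3 config is active


@contextmanager
def set_quantized_state():
    # Flip the flag for the duration of the `with` block; try/finally
    # guarantees it is reset even if model construction raises.
    global _is_quantized
    _is_quantized = True
    try:
        yield
    finally:
        _is_quantized = False


def build_submodel():
    # Stand-in for `_from_config`: it consults the module-level flag
    # instead of taking an argument, because it may be called many
    # stack frames below `from_pretrained`.
    if is_deepspeed_zero3_enabled() and not _is_quantized:
        return "built under deepspeed.zero.Init()"
    return "built without zero.Init()"


is_quantized = True  # e.g. a bitsandbytes quantization config was passed
init_contexts = []
if is_deepspeed_zero3_enabled() and is_quantized:
    init_contexts.append(set_quantized_state())

with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    print(build_submodel())  # -> built without zero.Init()

Propagating the state through a global rather than a new parameter means _from_config can observe it however deep in the call stack it runs, and the try/finally keeps the flag from leaking past a failed load.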
