1 file changed: +6 -5 lines changed
Original file line number | Diff line number | Diff line change
24 24 if is_transformers_available ():
2525 import transformers
2626
27+ if transformers .integrations .deepspeed .is_deepspeed_zero3_enabled ():
28+ import deepspeed
29+
2730if is_peft_available ():
2831 from peft import set_peft_model_state_dict
2932
@@ -442,15 +445,13 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
442445 self .cur_decay_value = decay
443446 one_minus_decay = 1 - decay
444447
445- context_manager = contextlib .nullcontext
446- if is_transformers_available () and transformers .integrations .deepspeed .is_deepspeed_zero3_enabled ():
447- import deepspeed
448+ context_manager = contextlib .nullcontext ()
448449
449450 if self .foreach :
450451 if is_transformers_available () and transformers .integrations .deepspeed .is_deepspeed_zero3_enabled ():
451452 context_manager = deepspeed .zero .GatheredParameters (parameters , modifier_rank = None )
452453
453- with context_manager () :
454+ with context_manager :
454455 params_grad = [param for param in parameters if param .requires_grad ]
455456 s_params_grad = [
456457 s_param for s_param , param in zip (self .shadow_params , parameters ) if param .requires_grad
@@ -472,7 +473,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
472473 if is_transformers_available () and transformers .integrations .deepspeed .is_deepspeed_zero3_enabled ():
473474 context_manager = deepspeed .zero .GatheredParameters (param , modifier_rank = None )
474475
475- with context_manager () :
476+ with context_manager :
476477 if param .requires_grad :
477478 s_param .sub_ (one_minus_decay * (s_param - param ))
478479 else :
You can’t perform that action at this time.
0 commit comments