Skip to content

Commit 9aca79f

Browse files
salccsayakpaul
andauthored
Replace transformers.deepspeed with transformers.integrations.deepspeed (#9281)
to avoid "FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations" Co-authored-by: Sayak Paul <[email protected]>
1 parent bbcf2a8 commit 9aca79f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/training_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,11 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
418418
one_minus_decay = 1 - decay
419419

420420
context_manager = contextlib.nullcontext
421-
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
421+
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
422422
import deepspeed
423423

424424
if self.foreach:
425-
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
425+
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
426426
context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
427427

428428
with context_manager():
@@ -444,7 +444,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
444444

445445
else:
446446
for s_param, param in zip(self.shadow_params, parameters):
447-
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
447+
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
448448
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
449449

450450
with context_manager():

0 commit comments

Comments
 (0)