Skip to content

Commit ef816b6

Browse files
committed
Revert "refactor: defer DeepSpeed import and logging configuration until needed"
This reverts commit 59dda02.
1 parent 6c1554a commit ef816b6

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,10 @@ def __init__(
282282
sub_group_size=sub_group_size,
283283
)
284284

285+
import deepspeed
286+
285287
self._config_initialized = False
286-
# Defer importing and configuring DeepSpeed logging until it is actually needed.
287-
# Store the desired logging level to be applied on first use.
288-
self._logging_level = logging_level
288+
deepspeed.utils.logging.logger.setLevel(logging_level)
289289

290290
self.remote_device = remote_device
291291
self.load_full_weights = load_full_weights
@@ -374,8 +374,6 @@ def module_sharded_context(self) -> AbstractContextManager:
374374

375375
import deepspeed
376376

377-
deepspeed.utils.logging.logger.setLevel(self._logging_level)
378-
379377
assert self._config_initialized
380378
return deepspeed.zero.Init(
381379
enabled=self.zero_stage_3,
@@ -603,8 +601,6 @@ def _initialize_engine(
603601
"""
604602
import deepspeed
605603

606-
deepspeed.utils.logging.logger.setLevel(self._logging_level)
607-
608604
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
609605
deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
610606
args=argparse.Namespace(device_rank=self.root_device.index),
@@ -632,20 +628,14 @@ def _setup_distributed(self) -> None:
632628
_validate_device_index_selection(self.parallel_devices)
633629
reset_seed()
634630
self._set_world_ranks()
635-
# Avoid initializing DeepSpeed distributed for single-process runs. This also avoids importing
636-
# DeepSpeed in environments where it may not be fully functional (e.g., missing nvcc),
637-
# while still allowing configuration and dataloader setup logic to run.
638-
if self.world_size > 1:
639-
self._init_deepspeed_distributed()
631+
self._init_deepspeed_distributed()
640632
if not self._config_initialized:
641633
self._format_config()
642634
self._config_initialized = True
643635

644636
def _init_deepspeed_distributed(self) -> None:
645637
import deepspeed
646638

647-
deepspeed.utils.logging.logger.setLevel(self._logging_level)
648-
649639
assert self.cluster_environment is not None
650640
if platform.system() != "Windows":
651641
# do not set env variables on windows, allow deepspeed to control setup
@@ -671,8 +661,6 @@ def _set_node_environment_variables(self) -> None:
671661
def _set_deepspeed_activation_checkpointing(self) -> None:
672662
import deepspeed
673663

674-
deepspeed.utils.logging.logger.setLevel(self._logging_level)
675-
676664
assert isinstance(self.config, dict)
677665
if self.config.get("activation_checkpointing"):
678666
checkpoint_config = self.config["activation_checkpointing"]

0 commit comments

Comments
 (0)