Skip to content

Commit 2a81718

Browse files
committed
refactor: defer DeepSpeed import and logging configuration until needed
1 parent d838098 commit 2a81718

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,10 @@ def __init__(
282282
sub_group_size=sub_group_size,
283283
)
284284

285-
import deepspeed
286-
287285
self._config_initialized = False
288-
deepspeed.utils.logging.logger.setLevel(logging_level)
286+
# Defer importing and configuring DeepSpeed logging until it is actually needed.
287+
# Store the desired logging level to be applied on first use.
288+
self._logging_level = logging_level
289289

290290
self.remote_device = remote_device
291291
self.load_full_weights = load_full_weights
@@ -374,6 +374,8 @@ def module_sharded_context(self) -> AbstractContextManager:
374374

375375
import deepspeed
376376

377+
deepspeed.utils.logging.logger.setLevel(self._logging_level)
378+
377379
assert self._config_initialized
378380
return deepspeed.zero.Init(
379381
enabled=self.zero_stage_3,
@@ -601,6 +603,8 @@ def _initialize_engine(
601603
"""
602604
import deepspeed
603605

606+
deepspeed.utils.logging.logger.setLevel(self._logging_level)
607+
604608
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
605609
deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
606610
args=argparse.Namespace(device_rank=self.root_device.index),
@@ -628,14 +632,20 @@ def _setup_distributed(self) -> None:
628632
_validate_device_index_selection(self.parallel_devices)
629633
reset_seed()
630634
self._set_world_ranks()
631-
self._init_deepspeed_distributed()
635+
# Avoid initializing DeepSpeed distributed for single-process runs. This also avoids importing
636+
# DeepSpeed in environments where it may not be fully functional (e.g., missing nvcc),
637+
# while still allowing configuration and dataloader setup logic to run.
638+
if self.world_size > 1:
639+
self._init_deepspeed_distributed()
632640
if not self._config_initialized:
633641
self._format_config()
634642
self._config_initialized = True
635643

636644
def _init_deepspeed_distributed(self) -> None:
637645
import deepspeed
638646

647+
deepspeed.utils.logging.logger.setLevel(self._logging_level)
648+
639649
assert self.cluster_environment is not None
640650
if platform.system() != "Windows":
641651
# do not set env variables on windows, allow deepspeed to control setup
@@ -661,6 +671,8 @@ def _set_node_environment_variables(self) -> None:
661671
def _set_deepspeed_activation_checkpointing(self) -> None:
662672
import deepspeed
663673

674+
deepspeed.utils.logging.logger.setLevel(self._logging_level)
675+
664676
assert isinstance(self.config, dict)
665677
if self.config.get("activation_checkpointing"):
666678
checkpoint_config = self.config["activation_checkpointing"]

0 commit comments

Comments
 (0)