@@ -282,10 +282,10 @@ def __init__(
282282 sub_group_size = sub_group_size ,
283283 )
284284
285+ import deepspeed
286+
285287 self ._config_initialized = False
286- # Defer importing and configuring DeepSpeed logging until it is actually needed.
287- # Store the desired logging level to be applied on first use.
288- self ._logging_level = logging_level
288+ deepspeed .utils .logging .logger .setLevel (logging_level )
289289
290290 self .remote_device = remote_device
291291 self .load_full_weights = load_full_weights
@@ -374,8 +374,6 @@ def module_sharded_context(self) -> AbstractContextManager:
374374
375375 import deepspeed
376376
377- deepspeed .utils .logging .logger .setLevel (self ._logging_level )
378-
379377 assert self ._config_initialized
380378 return deepspeed .zero .Init (
381379 enabled = self .zero_stage_3 ,
@@ -603,8 +601,6 @@ def _initialize_engine(
603601 """
604602 import deepspeed
605603
606- deepspeed .utils .logging .logger .setLevel (self ._logging_level )
607-
608604 model_parameters = filter (lambda p : p .requires_grad , model .parameters ())
609605 deepspeed_engine , deepspeed_optimizer , _ , deepspeed_scheduler = deepspeed .initialize (
610606 args = argparse .Namespace (device_rank = self .root_device .index ),
@@ -632,20 +628,14 @@ def _setup_distributed(self) -> None:
632628 _validate_device_index_selection (self .parallel_devices )
633629 reset_seed ()
634630 self ._set_world_ranks ()
635- # Avoid initializing DeepSpeed distributed for single-process runs. This also avoids importing
636- # DeepSpeed in environments where it may not be fully functional (e.g., missing nvcc),
637- # while still allowing configuration and dataloader setup logic to run.
638- if self .world_size > 1 :
639- self ._init_deepspeed_distributed ()
631+ self ._init_deepspeed_distributed ()
640632 if not self ._config_initialized :
641633 self ._format_config ()
642634 self ._config_initialized = True
643635
644636 def _init_deepspeed_distributed (self ) -> None :
645637 import deepspeed
646638
647- deepspeed .utils .logging .logger .setLevel (self ._logging_level )
648-
649639 assert self .cluster_environment is not None
650640 if platform .system () != "Windows" :
651641 # do not set env variables on windows, allow deepspeed to control setup
@@ -671,8 +661,6 @@ def _set_node_environment_variables(self) -> None:
671661 def _set_deepspeed_activation_checkpointing (self ) -> None :
672662 import deepspeed
673663
674- deepspeed .utils .logging .logger .setLevel (self ._logging_level )
675-
676664 assert isinstance (self .config , dict )
677665 if self .config .get ("activation_checkpointing" ):
678666 checkpoint_config = self .config ["activation_checkpointing" ]
0 commit comments