Skip to content

Commit 58fa112

Browse files
committed
update error types and add property forwarding
1 parent e4d8d5e commit 58fa112

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,3 +476,7 @@ def config(self) -> dict[str, Any]:
476476
@config.setter
477477
def config(self, config: dict[str, Any]) -> None:
478478
self.deepspeed_impl.config = config
479+
480+
@property
481+
def load_full_weights(self) -> bool:
482+
return self.deepspeed_impl.load_full_weights

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
445445
@property
446446
def config(self) -> dict[str, Any]:
447447
return self.deepspeed_strategy_impl.config
448+
449+
@property
450+
def load_full_weights(self) -> bool:
451+
return self.deepspeed_strategy_impl.load_full_weights

tests/tests_pytorch/strategies/test_deepspeed.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from lightning.pytorch.loggers import CSVLogger
3434
from lightning.pytorch.plugins import DeepSpeedPrecision
3535
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
36-
from lightning.pytorch.utilities.exceptions import MisconfigurationException
3736
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
3837
from tests_pytorch.helpers.datamodules import ClassifDataModule
3938
from tests_pytorch.helpers.runif import RunIf
@@ -1079,7 +1078,7 @@ def training_step(self, batch, batch_idx):
10791078
enable_progress_bar=False,
10801079
enable_model_summary=False,
10811080
)
1082-
with pytest.raises(MisconfigurationException, match="returning `None` .* is not supported"):
1081+
with pytest.raises(ValueError, match="returning `None` .* is not supported"):
10831082
trainer.fit(model)
10841083

10851084

@@ -1158,7 +1157,7 @@ def test_deepspeed_gradient_clip_by_value(tmp_path):
11581157
enable_progress_bar=False,
11591158
enable_model_summary=False,
11601159
)
1161-
with pytest.raises(MisconfigurationException, match="does not support clipping gradients by value"):
1160+
with pytest.raises(ValueError, match="does not support clipping gradients by value"):
11621161
trainer.fit(model)
11631162

11641163

0 commit comments

Comments
 (0)