Skip to content

Commit ad74bb3

Browse files
committed
allow user to pass kwargs to DeepSpeedStrategy
1 parent be608fa commit ad74bb3

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
load_full_weights: bool = False,
9898
precision: Optional[Precision] = None,
9999
process_group_backend: Optional[str] = None,
100+
**kwargs: Any,
100101
) -> None:
101102
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
102103
billion parameter models. `For more information: https://pytorch-
@@ -239,6 +240,7 @@ def __init__(
239240
cluster_environment=cluster_environment,
240241
precision=precision,
241242
process_group_backend=process_group_backend,
243+
**kwargs,
242244
)
243245
self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally
244246

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
load_full_weights: bool = False,
120120
precision_plugin: Optional[Precision] = None,
121121
process_group_backend: Optional[str] = None,
122+
**kwargs: Any,
122123
) -> None:
123124
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
124125
billion parameter models. `For more information: https://pytorch-
@@ -263,6 +264,7 @@ def __init__(
263264
cluster_environment=cluster_environment,
264265
precision_plugin=precision_plugin,
265266
process_group_backend=process_group_backend,
267+
**kwargs,
266268
)
267269

268270
self.config = self._load_config(config)

0 commit comments

Comments
 (0)