Commit ab7ad37

Sean Naren authored and lexierule committed
[FIX] Enable mixed precision in the Fully Sharded Strategy when precision=16 (#12965)
* Fix fully sharded mixed precision setter
* Add CHANGELOG.md
1 parent aed5f9d commit ab7ad37
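
For context on the user-facing effect of this fix: before this commit, setting `precision=16` with the Fully Sharded strategy did not propagate `mixed_precision=True` to FairScale's FSDP wrapper. The sketch below shows the intended configuration after the fix. It assumes Lightning 1.6-era APIs, the `"ddp_fully_sharded"` strategy alias, a single CUDA GPU, and fairscale being installed; the tiny model and dataset are made up for illustration only.

# Minimal sketch, assuming Lightning 1.6-era APIs; not taken from the commit itself.
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __init__(self, size=32, length=64):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # Autocast from the mixed-precision plugin handles the fp16/fp32 interplay.
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy="ddp_fully_sharded",  # DDPFullyShardedStrategy (FairScale FSDP); alias assumed
        precision=16,                  # after this fix, maps to mixed_precision=True in the wrapper
        fast_dev_run=True,
    )
    trainer.fit(TinyModel(), DataLoader(RandomDataset(), batch_size=8))
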

File tree

3 files changed: +7 -1 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Stopped `optimizer_zero_grad` from being called after IPU execution ([#12913](https://github.com/PyTorchLightning/pytorch-lightning/pull/12913))
 - Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891))
 - Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
+- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))
 
 
 ## [1.6.2] - 2022-04-27

pytorch_lightning/strategies/fully_sharded.py

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ def wrap_policy(*args, **kwargs):
             cpu_offload=self.cpu_offload,
             move_grads_to_cpu=self.move_grads_to_cpu,
             flatten_parameters=self.flatten_parameters,
-            mixed_precision=(precision == PrecisionType.MIXED),
+            mixed_precision=(precision in (PrecisionType.MIXED, PrecisionType.HALF)),
             reshard_after_forward=self.reshard_after_forward,
             fp32_reduce_scatter=self.fp32_reduce_scatter,
             compute_dtype=self.compute_dtype,
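
The behavioural change is just this condition: `precision=16` maps to `PrecisionType.HALF`, which the old check (`precision == PrecisionType.MIXED`) did not match, so the FairScale wrapper was silently built without mixed precision. Below is a self-contained sketch of the before/after logic; the `PrecisionType` enum here is a stand-in mirroring `pytorch_lightning.utilities.enums.PrecisionType`, with member values assumed.

# Stand-in sketch of the condition change; not the library's own enum.
from enum import Enum


class PrecisionType(str, Enum):
    HALF = "16"     # Trainer(precision=16)
    FLOAT = "32"
    MIXED = "mixed"


def should_enable_fsdp_mixed_precision(precision: PrecisionType) -> bool:
    # Before this commit: only PrecisionType.MIXED enabled the flag, so
    # Trainer(precision=16) left the FSDP wrapper in full precision.
    # After: HALF is treated as mixed precision as well.
    return precision in (PrecisionType.MIXED, PrecisionType.HALF)


assert should_enable_fsdp_mixed_precision(PrecisionType.HALF)
assert should_enable_fsdp_mixed_precision(PrecisionType.MIXED)
assert not should_enable_fsdp_mixed_precision(PrecisionType.FLOAT)
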

tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py

Lines changed: 5 additions & 0 deletions
@@ -90,6 +90,11 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert self.layer.module[0].reshard_after_forward is True
         assert self.layer.module[2].reshard_after_forward is True
 
+        if isinstance(self.trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin):
+            assert self.layer.mixed_precision
+            assert self.layer.module[0].mixed_precision
+            assert self.layer.module[2].mixed_precision
+
 
 @RunIf(min_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True)
 def test_fully_sharded_strategy_checkpoint(tmpdir):
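
The new assertions only run when the Trainer has selected the FairScale mixed-precision plugin, which is what should happen for `precision=16` with this strategy. A rough standalone sketch of that selection is below; the import path and strategy alias follow Lightning 1.6-era conventions and are assumptions, and it requires a CUDA GPU plus fairscale installed.

# Sketch only: checks that precision=16 with the Fully Sharded strategy picks the
# FairScale mixed-precision plugin, so the test's mixed-precision asserts are exercised.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin

trainer = Trainer(accelerator="gpu", devices=1, strategy="ddp_fully_sharded", precision=16)
assert isinstance(trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)
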
