
Commit 706f6e8

Deduplicate fp32 weights under torch autocast and ZeRO3 (#7651)
When torch autocast is enabled, model weights are already in fp32 and can be updated directly by the optimizer with fp32 gradients. Keeping another copy, also in fp32, as the master weights is a waste of accelerator memory. Instead, use aliases to the so-called "fp16" params as the master weights to save memory. This applies only when no optimizer offloading (CPU or NVMe) or swapping mechanism is enabled.

Using https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 (which enables torch autocast) as an example, the memory profile of the training startup phase is as follows:

[Figure: memory profile of the training startup phase without this PR]

With this PR, the master weights are no longer instantiated:

[Figure: memory profile with this PR; no separate master weights appear]

The same holds when DeepCompile is enabled:

[Figure: memory profile with this PR and DeepCompile enabled]

When torch autocast is disabled, the master weights are preserved:

[Figure: memory profile with torch autocast disabled; master weights are allocated]

Signed-off-by: Junjie Mao <[email protected]>
Parent: 9a012d2 · Commit: 706f6e8
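As a quick illustration of the premise in the commit message (a standalone sketch, not part of this commit): under torch.autocast only the compute is downcast, while module parameters keep their fp32 dtype, which is why the optimizer can update them directly in fp32.

import torch

# Minimal sketch: the matmul runs in bf16 under autocast, but the parameters
# themselves stay fp32 and are what the optimizer updates.
model = torch.nn.Linear(8, 8)          # parameters are created in fp32
x = torch.randn(2, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)                     # compute is downcast to bf16

assert out.dtype == torch.bfloat16
assert all(p.dtype == torch.float32 for p in model.parameters())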

File tree: 1 file changed (+12, -6 lines)


deepspeed/runtime/zero/stage3.py

Lines changed: 12 additions & 6 deletions
@@ -954,13 +954,17 @@ def _create_fp32_partitions(self):
                         unpinned_fp32_buffer = torch.empty(num_elements, device=self.device, dtype=torch.float)
                         self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
                         self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer)
+                elif self.offload_optimizer:
+                    self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
+                        self.subgroup_to_device[i]).clone().float().detach())
+                elif self.fp16_partitioned_groups_flat[i].dtype == torch.float32:
+                    # When torch autocast is enabled, weights in the provided model (and thus groups in the so-called
+                    # "fp16" partitioned groups) are already in and updated using fp32. In such cases we don't need
+                    # another copy of the weights.
+                    self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i])
                 else:
-                    if self.offload_optimizer:
-                        self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
-                            self.subgroup_to_device[i]).clone().float().detach())
-                    else:
-                        self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
-                            self.device).clone().float().detach())
+                    self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
+                        self.device).clone().float().detach())
             self.fp32_partitioned_groups_flat[i].ds_id = ds_id

             self.fp32_partitioned_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
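To confirm at runtime that the deduplication in the hunk above took effect, one can compare storage pointers of the two partition lists. This is an illustrative check only; it assumes direct access to the ZeRO-3 optimizer object and reuses the attribute names visible in the diff.

def master_weights_are_aliased(zero3_optimizer) -> bool:
    # True when every fp32 master partition shares storage with the corresponding
    # model ("fp16") partition, i.e. no duplicate fp32 copy was allocated.
    return all(
        fp32.data_ptr() == fp16.data_ptr()
        for fp32, fp16 in zip(zero3_optimizer.fp32_partitioned_groups_flat,
                              zero3_optimizer.fp16_partitioned_groups_flat)
        if fp16 is not None)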
@@ -2114,6 +2118,8 @@ def _post_step(self, timer_names):
     @instrument_w_nvtx
     def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
         if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
+            # When torch autocast is enabled, groups in fp16_partitioned_groups are in fp32 already and those in
+            # fp32_partitioned_groups are aliases. Calling tensor.data.copy_ will not trigger any copy in that case.
             self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
                 self.fp32_partitioned_groups_flat[sub_group_id].data)

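To illustrate the comment in this hunk (a standalone sketch, not DeepSpeed code): when the master and model partitions alias the same storage, the copy_ above has nothing to move.

import torch

model_part = torch.randn(4, dtype=torch.float32)
master_part = model_part                           # alias, no second fp32 buffer
assert master_part.data_ptr() == model_part.data_ptr()
model_part.data.copy_(master_part.data)            # effectively a no-op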
