@@ -954,13 +954,17 @@ def _create_fp32_partitions(self):
                     unpinned_fp32_buffer = torch.empty(num_elements, device=self.device, dtype=torch.float)
                     self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
                     self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer)
+                elif self.offload_optimizer:
+                    self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
+                        self.subgroup_to_device[i]).clone().float().detach())
+                elif self.fp16_partitioned_groups_flat[i].dtype == torch.float32:
+                    # When torch autocast is enabled, weights in the provided model (and thus groups in the so-called
+                    # "fp16" partitioned groups) are already stored and updated in fp32. In such cases we don't need
+                    # another copy of the weights.
+                    self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i])
                 else:
-                    if self.offload_optimizer:
-                        self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
-                            self.subgroup_to_device[i]).clone().float().detach())
-                    else:
-                        self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
-                            self.device).clone().float().detach())
+                    self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
+                        self.device).clone().float().detach())
             self.fp32_partitioned_groups_flat[i].ds_id = ds_id
 
             self.fp32_partitioned_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
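
A minimal standalone sketch of the new selection logic in this hunk may help reviewers; the `make_fp32_partition` helper and its `offload_device` parameter are illustrative, not DeepSpeed's API. The point is that the fp32 master partition is only cloned when the flat "fp16" group is genuinely low precision; under autocast it is an alias of the existing fp32 weights:

```python
import torch

# Illustrative sketch of the branch added above (not DeepSpeed's API).
def make_fp32_partition(fp16_flat: torch.Tensor,
                        offload_device: torch.device = None) -> torch.Tensor:
    if offload_device is not None:
        # Optimizer offload: materialize the fp32 master copy on the offload device.
        return fp16_flat.to(offload_device).clone().float().detach()
    if fp16_flat.dtype == torch.float32:
        # Autocast case: weights are already fp32, so alias them instead of cloning.
        return fp16_flat
    # Classic mixed precision: clone a separate fp32 master copy.
    return fp16_flat.clone().float().detach()

flat = torch.zeros(8, dtype=torch.float32)  # "fp16" group that is really fp32
master = make_fp32_partition(flat)
assert master is flat                       # same storage: no extra copy of the weights
```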
@@ -2114,6 +2118,8 @@ def _post_step(self, timer_names):
     @instrument_w_nvtx
     def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
         if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
+            # When torch autocast is enabled, groups in fp16_partitioned_groups are in fp32 already and those in
+            # fp32_partitioned_groups are aliases. Calling tensor.data.copy_ will not trigger any copy in that case.
             self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
                 self.fp32_partitioned_groups_flat[sub_group_id].data)
 
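
The comment in this hunk relies on `Tensor.copy_` being a no-op when source and destination are the same tensor. A small demo of that aliasing behavior (variable names are illustrative, not DeepSpeed's):

```python
import torch

# When both flat groups share storage, copy_ is a self-copy and moves no data.
fp16_flat = torch.zeros(4, dtype=torch.float32)  # already fp32 under autocast
fp32_flat = fp16_flat                            # alias created in _create_fp32_partitions

fp32_flat.add_(1.0)                              # an optimizer step updates the master weights
fp16_flat.data.copy_(fp32_flat.data)             # self-copy: source and destination alias
assert fp16_flat.data_ptr() == fp32_flat.data_ptr()
assert torch.equal(fp16_flat, torch.ones(4))     # the "fp16" group already has the update
```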