File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -163,10 +163,9 @@ def main(**kwargs):
163
163
wandb_run .config .update (peft_config )
164
164
model .print_trainable_parameters ()
165
165
166
-
167
- hsdp_device_mesh = None
166
+ hsdp_device_mesh_plan = None
168
167
if fsdp_config .hsdp and fsdp_config .sharding_strategy == ShardingStrategy .HYBRID_SHARD :
169
- hsdp_device_mesh = hsdp_device_mesh (replica_group_size = fsdp_config .replica_group_size , sharding_group_size = fsdp_config .sharding_group_size )
168
+ hsdp_device_mesh_plan = hsdp_device_mesh (replica_group_size = fsdp_config .replica_group_size , sharding_group_size = fsdp_config .sharding_group_size )
170
169
print ("HSDP device mesh is ready" )
171
170
172
171
#setting up FSDP if enable_fsdp is enabled
@@ -189,7 +188,7 @@ def main(**kwargs):
189
188
cpu_offload = CPUOffload (offload_params = True ) if fsdp_config .fsdp_cpu_offload else None ,
190
189
mixed_precision = mixed_precision_policy if not fsdp_config .pure_bf16 else None ,
191
190
sharding_strategy = fsdp_config .sharding_strategy ,
192
- device_mesh = hsdp_device_mesh ,
191
+ device_mesh = hsdp_device_mesh_plan ,
193
192
device_id = device_id ,
194
193
limit_all_gathers = True ,
195
194
sync_module_states = train_config .low_cpu_fsdp ,
You can’t perform that action at this time.
0 commit comments