File tree Expand file tree Collapse file tree 6 files changed +36
-6
lines changed
Expand file tree Collapse file tree 6 files changed +36
-6
lines changed Original file line number Diff line number Diff line change @@ -212,7 +212,12 @@ def _setup_distributed(self) -> None:
212212 self ._set_world_ranks ()
213213 self ._process_group_backend = self ._get_process_group_backend ()
214214 assert self .cluster_environment is not None
215- _init_dist_connection (self .cluster_environment , self ._process_group_backend , timeout = self ._timeout )
215+ _init_dist_connection (
216+ self .cluster_environment ,
217+ self ._process_group_backend ,
218+ timeout = self ._timeout ,
219+ device_id = self .root_device if self .root_device .type != "cpu" else None ,
220+ )
216221
217222 def _get_process_group_backend (self ) -> str :
218223 return self ._process_group_backend or _get_default_process_group_backend_for_device (self .root_device )
Original file line number Diff line number Diff line change @@ -663,7 +663,12 @@ def _setup_distributed(self) -> None:
663663 self ._set_world_ranks ()
664664 self ._process_group_backend = self ._get_process_group_backend ()
665665 assert self .cluster_environment is not None
666- _init_dist_connection (self .cluster_environment , self ._process_group_backend , timeout = self ._timeout )
666+ _init_dist_connection (
667+ self .cluster_environment ,
668+ self ._process_group_backend ,
669+ timeout = self ._timeout ,
670+ device_id = self .root_device if self .root_device .type != "cpu" else None ,
671+ )
667672
668673 def _get_process_group_backend (self ) -> str :
669674 return self ._process_group_backend or _get_default_process_group_backend_for_device (self .root_device )
Original file line number Diff line number Diff line change @@ -302,7 +302,12 @@ def _setup_distributed(self) -> None:
302302 self ._set_world_ranks ()
303303 self ._process_group_backend = self ._get_process_group_backend ()
304304 assert self .cluster_environment is not None
305- _init_dist_connection (self .cluster_environment , self ._process_group_backend , timeout = self ._timeout )
305+ _init_dist_connection (
306+ self .cluster_environment ,
307+ self ._process_group_backend ,
308+ timeout = self ._timeout ,
309+ device_id = self .root_device if self .root_device .type != "cpu" else None ,
310+ )
306311
307312 def _get_process_group_backend (self ) -> str :
308313 return self ._process_group_backend or _get_default_process_group_backend_for_device (self .root_device )
Original file line number Diff line number Diff line change @@ -200,7 +200,12 @@ def setup_distributed(self) -> None:
200200 self .set_world_ranks ()
201201 self ._process_group_backend = self ._get_process_group_backend ()
202202 assert self .cluster_environment is not None
203- _init_dist_connection (self .cluster_environment , self ._process_group_backend , timeout = self ._timeout )
203+ _init_dist_connection (
204+ self .cluster_environment ,
205+ self ._process_group_backend ,
206+ timeout = self ._timeout ,
207+ device_id = self .root_device if self .root_device .type != "cpu" else None ,
208+ )
204209
205210 def _get_process_group_backend (self ) -> str :
206211 return self ._process_group_backend or _get_default_process_group_backend_for_device (self .root_device )
Original file line number Diff line number Diff line change @@ -260,7 +260,12 @@ def setup_environment(self) -> None:
260260
261261 self ._process_group_backend = self ._get_process_group_backend ()
262262 assert self .cluster_environment is not None
263- _init_dist_connection (self .cluster_environment , self ._process_group_backend , timeout = self ._timeout )
263+ _init_dist_connection (
264+ self .cluster_environment ,
265+ self ._process_group_backend ,
266+ timeout = self ._timeout ,
267+ device_id = self .root_device if self .root_device .type != "cpu" else None ,
268+ )
264269
265270 # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
266271 if isinstance (self .kwargs .get ("device_mesh" ), tuple ):
Original file line number Diff line number Diff line change @@ -350,7 +350,12 @@ def _setup_distributed(self) -> None:
350350 self .set_world_ranks ()
351351 self ._process_group_backend = self ._get_process_group_backend ()
352352 assert self .cluster_environment is not None
353- _init_dist_connection (self .cluster_environment , self ._process_group_backend , timeout = self ._timeout )
353+ _init_dist_connection (
354+ self .cluster_environment ,
355+ self ._process_group_backend ,
356+ timeout = self ._timeout ,
357+ device_id = self .root_device if self .root_device .type != "cpu" else None ,
358+ )
354359
355360 def _get_process_group_backend (self ) -> str :
356361 return self ._process_group_backend or _get_default_process_group_backend_for_device (self .root_device )
You can’t perform that action at this time.
0 commit comments