File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change 1010from transformers .utils .versions import require_version
1111
1212from swift .llm .argument .base_args import to_abspath
13- from swift .utils import get_logger , json_parse_to_dict
13+ from swift .utils import get_dist_setting , get_logger , json_parse_to_dict
1414
1515logger = get_logger ()
1616
@@ -160,6 +160,7 @@ class MegatronArguments(ExtraMegatronArguments):
160160
161161 # dist
162162 distributed_backend : Literal ['nccl' , 'gloo' ] = 'nccl'
163+ local_rank : Optional [int ] = None
163164 use_distributed_optimizer : bool = True
164165 tensor_model_parallel_size : int = 1
165166 pipeline_model_parallel_size : int = 1
@@ -273,6 +274,8 @@ class MegatronArguments(ExtraMegatronArguments):
273274 def _set_default (self ):
274275 if self .mlp_padding_free and self .sequence_parallel :
275276 raise ValueError ('mlp_padding_free is not compatible with sequence_parallel.' )
277+ if self .local_rank is None :
278+ self .local_rank = get_dist_setting ()[1 ]
276279 if self .lr is None :
277280 if self .train_type == 'full' :
278281 self .lr = 1e-5
You can’t perform that action at this time.
0 commit comments