30 | 30 |         PIPELINE_REPLICATED_PARAMETER_PATTERNS,  | 
31 | 31 |         TP_REPLICATED_PARAMETER_PATTERNS,  | 
32 | 32 |         PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,  | 
 | 33 | +        PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,  | 
33 | 34 |     )  | 
34 | 35 |     DS_UNIVERSAL_CHECKPOINT_INFO = True   | 
35 | 36 | except ImportError:  | 
@@ -338,36 +339,88 @@ def _logits_helper(embedding, lm_output):  | 
338 | 339 |                          activation_checkpoint_interval=interval,  | 
339 | 340 |                          partition_method='type:transformer')  | 
340 | 341 |  | 
 | 342 | +    @staticmethod  | 
 | 343 | +    def _get_vocab_param_patterns():  | 
 | 344 | +        args = get_args()  | 
 | 345 | +        if args.untie_embeddings_and_output_weights:  | 
 | 346 | +            patterns = [  | 
 | 347 | +                r"\d+.word_embeddings.weight",  | 
 | 348 | +                r"\d+.lm_head.weight"  | 
 | 349 | +            ]  | 
 | 350 | +        else:  | 
 | 351 | +            patterns = [  | 
 | 352 | +                r"tied_modules.embed.word_embeddings.weight"  | 
 | 353 | +            ]  | 
 | 354 | +        return patterns  | 
 | 355 | + | 
 | 356 | +    def _get_pp_replicated_param_patterns(self):  | 
 | 357 | +        args = get_args()  | 
 | 358 | +        if args.untie_embeddings_and_output_weights:  | 
 | 359 | +            return []  | 
 | 360 | +        patterns = self._get_vocab_param_patterns()  | 
 | 361 | +        if args.add_position_embedding:  | 
 | 362 | +            patterns.append(r"tied_modules.embed.position_embeddings.weight")  | 
 | 363 | +        return patterns  | 
 | 364 | + | 
 | 365 | +    @staticmethod  | 
 | 366 | +    def _get_tp_replicated_param_patterns():  | 
 | 367 | +        args = get_args()  | 
 | 368 | +        patterns = [  | 
 | 369 | +            r"\d+.input_layernorm.weight",  | 
 | 370 | +            r"\d+.post_attention_layernorm.weight",  | 
 | 371 | +            r"\d+.weight",  | 
 | 372 | +        ]  | 
 | 373 | +        if args.add_position_embedding:  | 
 | 374 | +            patterns.append(r"tied_modules.embed.position_embeddings.weight")  | 
 | 375 | +        if args.add_bias_linear:  | 
 | 376 | +            patterns.extend([  | 
 | 377 | +                r"\d+.self_attention.dense.bias",  | 
 | 378 | +                r"\d+.mlp.dense_4h_to_h.bias",  | 
 | 379 | +            ])  | 
 | 380 | +        if args.normalization == 'layernorm':  | 
 | 381 | +            patterns.extend([  | 
 | 382 | +                r"\d+.input_layernorm.bias",  | 
 | 383 | +                r"\d+.post_attention_layernorm.bias",  | 
 | 384 | +                r"\d+.bias",  | 
 | 385 | +            ])  | 
 | 386 | +        return patterns  | 
 | 387 | + | 
 | 388 | +    @staticmethod  | 
 | 389 | +    def _get_row_parallel_param_patterns():  | 
 | 390 | +        return [  | 
 | 391 | +            r"\d+.mlp.dense_4h_to_h.weight",  | 
 | 392 | +            r"\d+.self_attention.dense.weight",  | 
 | 393 | +        ]  | 
 | 394 | + | 
 | 395 | +    @staticmethod  | 
 | 396 | +    def _get_swiglu_col_parallel_param_patterns():  | 
 | 397 | +        args = get_args()  | 
 | 398 | +        if not args.swiglu:  | 
 | 399 | +            return []  | 
 | 400 | +        patterns = [  | 
 | 401 | +            r"\d+.mlp.dense_h_to_4h.weight",  | 
 | 402 | +        ]  | 
 | 403 | +        if args.add_bias_linear:  | 
 | 404 | +            patterns.append(r"\d+.mlp.dense_h_to_4h.bias")  | 
 | 405 | +        return patterns  | 
 | 406 | + | 
 | 407 | + | 
341 | 408 |     def universal_checkpoint_info(self):  | 
342 | 409 |         info = dict()  | 
343 | 410 |         if DS_UNIVERSAL_CHECKPOINT_INFO:  | 
344 | 411 |             # Vocabulary parameters (embeddings) that require special handling due to padding.  | 
345 |  | -            info[VOCABULARY_PARAMETER_PATTERNS] = [  | 
346 |  | -                r"tied_modules.embed.word_embeddings.weight"  | 
347 |  | -            ]  | 
 | 412 | +            info[VOCABULARY_PARAMETER_PATTERNS] = self._get_vocab_param_patterns()  | 
348 | 413 |  | 
349 | 414 |             # Replicated (shared) parameters on the pipeline dimension  | 
350 |  | -            info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = [  | 
351 |  | -                r"tied_modules.embed.word_embeddings.weight",  | 
352 |  | -                r"tied_modules.embed.position_embeddings.weight"  | 
353 |  | -            ]  | 
 | 415 | +            info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = self._get_pp_replicated_param_patterns()  | 
354 | 416 |  | 
355 | 417 |             # Parameter slices that should be averaged, not concatenated.  | 
356 |  | -            info[TP_REPLICATED_PARAMETER_PATTERNS] = [  | 
357 |  | -                r"tied_modules.embed.position_embeddings.weight",  | 
358 |  | -                r"\d+.input_layernorm.weight",  | 
359 |  | -                r"\d+.input_layernorm.bias",  | 
360 |  | -                r"\d+.post_attention_layernorm.weight",  | 
361 |  | -                r"\d+.post_attention_layernorm.bias",  | 
362 |  | -                r"\d+.self_attention.dense.bias",  | 
363 |  | -                r"\d+.mlp.dense_4h_to_h.bias",  | 
364 |  | -                r"\d+.weight",  | 
365 |  | -                r"\d+.bias",  | 
366 |  | -            ]  | 
 | 418 | +            info[TP_REPLICATED_PARAMETER_PATTERNS] = self._get_tp_replicated_param_patterns()  | 
367 | 419 |  | 
368 | 420 |             # Parameters that are sliced on the row dimension  | 
369 |  | -            info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = [  | 
370 |  | -                r"\d+.mlp.dense_4h_to_h.weight",  | 
371 |  | -                r"\d+.self_attention.dense.weight",  | 
372 |  | -            ]  | 
 | 421 | +            info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = self._get_row_parallel_param_patterns()  | 
 | 422 | + | 
 | 423 | +            # SwiGLU parameters are first sliced on dim=0 into TP slices.  | 
 | 424 | +            # Then, each TP slice is chunked into 2 to create the linear layers L1, L2 used in silu(L1(x)) * L2(x)  | 
 | 425 | +            info[PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0] = self._get_swiglu_col_parallel_param_patterns()  | 
373 | 426 |         return info  | 
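
For context on how these regex patterns are consumed (not part of the diff): the universal-checkpoint converter matches them against flattened parameter names from the pipeline module to decide how each tensor-parallel (TP) shard is merged. Below is a minimal sketch, assuming full-match anchoring and illustrative parameter names; DeepSpeed's actual matching semantics may differ.

import re

# Illustrative (hypothetical) parameter names from a pipeline module; the
# leading integer is the layer index assigned by the pipeline engine.
names = [
    "tied_modules.embed.word_embeddings.weight",  # vocab embedding (tied case)
    "3.input_layernorm.weight",                   # TP-replicated -> average
    "3.self_attention.dense.weight",              # row-parallel -> concatenate
    "3.mlp.dense_h_to_4h.weight",                 # column-parallel -> concatenate
]

tp_replicated = [
    r"\d+.input_layernorm.weight",
    r"\d+.post_attention_layernorm.weight",
    r"\d+.weight",
]

# A name matching any TP-replicated pattern is averaged across TP ranks;
# everything else is reassembled by concatenating the TP slices.
for name in names:
    averaged = any(re.fullmatch(p, name) for p in tp_replicated)
    print(f"{name}: {'average TP slices' if averaged else 'concatenate TP slices'}")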
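And a minimal sketch of the layout the new PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 entry describes: under SwiGLU, the fused dense_h_to_4h weight stacks the two projections L1 (gate) and L2 (value) along dim=0, so each TP slice must be chunked in half before the sub-parameters can be concatenated separately. Shapes here are illustrative assumptions, not Megatron's actual configuration.

import torch
import torch.nn.functional as F

hidden, ffn = 8, 16
# Fused column-parallel weight holding L1 and L2 stacked along dim=0
# (per TP rank this would be a [2 * ffn / tp, hidden] slice).
fused_weight = torch.randn(2 * ffn, hidden)

x = torch.randn(4, hidden)
w1, w2 = fused_weight.chunk(2, dim=0)    # split the fused weight into L1, L2
out = F.silu(x @ w1.t()) * (x @ w2.t())  # silu(L1(x)) * L2(x)
print(out.shape)                         # torch.Size([4, 16])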