|
33 | 33 | ) |
34 | 34 | DS_UNIVERSAL_CHECKPOINT_INFO = True |
35 | 35 | except ImportError: |
36 | | - DS_UNIVERSAL_CHECKPOINT_INFO = None |
| 36 | + DS_UNIVERSAL_CHECKPOINT_INFO = False |
| 37 | + |
37 | 38 |
|
38 | 39 | def post_language_model_processing(lm_output, labels, logit_weights, |
39 | 40 | parallel_output, |
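For context, the hunk above changes the ImportError fallback from None to False, so DS_UNIVERSAL_CHECKPOINT_INFO is always a plain boolean before it is used as a guard in the second hunk. A minimal sketch of that guard, assuming the pattern constants are imported from deepspeed.checkpoint as in the surrounding file:

```python
# Sketch only: the exact import list is assumed from the surrounding file.
try:
    from deepspeed.checkpoint import (
        VOCABULARY_PARAMETER_PATTERNS,
        PIPELINE_REPLICATED_PARAMETER_PATTERNS,
        TP_REPLICATED_PARAMETER_PATTERNS,
        PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
    )
    DS_UNIVERSAL_CHECKPOINT_INFO = True
except ImportError:
    # Older DeepSpeed releases do not export these constants; fall back to
    # False so `if DS_UNIVERSAL_CHECKPOINT_INFO:` reads as a simple feature flag.
    DS_UNIVERSAL_CHECKPOINT_INFO = False
```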
@@ -341,37 +342,37 @@ def _logits_helper(embedding, lm_output): |
341 | 342 |
|
342 | 343 | def universal_checkpoint_info(self): |
343 | 344 | info = dict() |
| 345 | + if DS_UNIVERSAL_CHECKPOINT_INFO: |
| 346 | + # Vocabulary parameters (embeddings) that require special handling due to padding. |
| 347 | + info[VOCABULARY_PARAMETER_PATTERNS] = [ |
| 348 | + r"tied_modules.embed.word_embeddings.weight" |
| 349 | + ] |
| 350 | + |
| 351 | + # Replicated (shared) parameters on the pipeline dimension |
| 352 | + info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = [ |
| 353 | + r"tied_modules.embed.word_embeddings.weight", |
| 354 | + r"tied_modules.embed.position_embeddings.weight" |
| 355 | + ] |
| 356 | + |
| 357 | + # Parameter slices that should be averaged, not concatenated. |
| 358 | + info[TP_REPLICATED_PARAMETER_PATTERNS] = [ |
| 359 | + r"tied_modules.embed.word_embeddings.norm.weight", |
| 360 | + r"tied_modules.embed.word_embeddings.norm.bias", |
| 361 | + r"tied_modules.embed.position_embeddings.weight", |
| 362 | + r"\d+.input_layernorm.weight", |
| 363 | + r"\d+.input_layernorm.bias", |
| 364 | + r"\d+.post_attention_layernorm.weight", |
| 365 | + r"\d+.post_attention_layernorm.bias", |
| 366 | + r"\d+.self_attention.dense.bias", |
| 367 | + r"\d+.mlp.dense_4h_to_h.bias", |
| 368 | + r"\d+.weight", |
| 369 | + r"\d+.bias", |
| 370 | + ] |
344 | 371 |
|
345 | | - # Vocabulary parameters (embeddings) that require special handling due to padding. |
346 | | - info[VOCABULARY_PARAMETER_PATTERNS] = [ |
347 | | - r"tied_modules.embed.word_embeddings.weight" |
348 | | - ] |
349 | | - |
350 | | - # Replicated (shared) parameters on the pipeline dimension |
351 | | - info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = [ |
352 | | - r"tied_modules.embed.word_embeddings.weight", |
353 | | - r"tied_modules.embed.position_embeddings.weight" |
354 | | - ] |
355 | | - |
356 | | - # Parameter slices that should be averaged, not concatenated. |
357 | | - info[TP_REPLICATED_PARAMETER_PATTERNS] = [ |
358 | | - r"tied_modules.embed.word_embeddings.norm.weight", |
359 | | - r"tied_modules.embed.word_embeddings.norm.bias", |
360 | | - r"tied_modules.embed.position_embeddings.weight", |
361 | | - r"\d+.input_layernorm.weight", |
362 | | - r"\d+.input_layernorm.bias", |
363 | | - r"\d+.post_attention_layernorm.weight", |
364 | | - r"\d+.post_attention_layernorm.bias", |
365 | | - r"\d+.self_attention.dense.bias", |
366 | | - r"\d+.mlp.dense_4h_to_h.bias", |
367 | | - r"\d+.weight", |
368 | | - r"\d+.bias", |
369 | | - ] |
370 | | - |
371 | | - # Parameters that are sliced on the row dimension |
372 | | - info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = [ |
373 | | - r"\d+.mlp.dense_4h_to_h.weight", |
374 | | - r"\d+.mlp.self_attention.dense.weight", |
375 | | - ] |
| 372 | + # Parameters that are sliced on the row dimension |
| 373 | + info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = [ |
| 374 | + r"\d+.mlp.dense_4h_to_h.weight", |
| 375 | + r"\d+.mlp.self_attention.dense.weight", |
| 376 | + ] |
376 | 377 |
|
377 | 378 | return info |
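To illustrate what the patterns in universal_checkpoint_info() describe, here is a small, self-contained sketch. The parameter names are hypothetical examples of the PipelineModule naming scheme ("<layer index>.<param path>"), and the re.match call only demonstrates what a subset of the TP-replicated patterns from this diff would select; it is not DeepSpeed's actual checkpoint-conversion logic.

```python
import re

# Illustrative sketch: hypothetical parameter names checked against a subset of
# the TP_REPLICATED_PARAMETER_PATTERNS added in this diff.
TP_REPLICATED = [
    r"tied_modules.embed.word_embeddings.norm.weight",
    r"\d+.input_layernorm.weight",
    r"\d+.self_attention.dense.bias",
]

example_names = [
    "tied_modules.embed.word_embeddings.norm.weight",
    "11.input_layernorm.weight",
    "11.self_attention.dense.bias",
    "11.self_attention.query_key_value.weight",  # tensor-parallel sliced, no match
]

for name in example_names:
    replicated = any(re.match(p, name) for p in TP_REPLICATED)
    action = "average across TP ranks" if replicated else "concatenate TP slices"
    print(f"{name}: {action}")
```

With the new guard in place, universal_checkpoint_info() simply returns an empty dict on DeepSpeed versions that do not export these constants, instead of failing on undefined pattern names.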