Skip to content

Commit 8e1ff89

Browse files
tjruwase and mrwyattii
authored
Universal Checkpoint: BC for older DeepSpeed (bigscience-workshop#271)
* Enable universal ckpting * Update run scripts * Address PR feedback * Remove line * Fix white lines * Remove redudant changes * Apply to gpt_model only * Code cleanup * Code cleanup * Update training.py Co-authored-by: Michael Wyatt <[email protected]> * Update training.py Co-authored-by: Michael Wyatt <[email protected]> * Log loss_scale only valid for fp16 * Add README and bf16 scripts * Visualization docsts * Support older DS * Handle uni_ckpt import error * Revert changes --------- Co-authored-by: Michael Wyatt <[email protected]>
1 parent ad0e1fd commit 8e1ff89

File tree

2 files changed

+35
-34
lines changed

2 files changed

+35
-34
lines changed

examples_deepspeed/universal_checkpointing/ds_config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"train_batch_size" : 16,
3-
"train_micro_batch_size_per_gpu": 1,
3+
"train_micro_batch_size_per_gpu": 16,
44
"steps_per_print": 1,
55

66
"zero_optimization": {
@@ -11,7 +11,7 @@
1111
"enabled": true
1212
},
1313

14-
"data_types": {
14+
"data_types": {
1515
"grad_accum_dtype": "fp32"
1616
},
1717

megatron/model/gpt_model.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
)
3434
DS_UNIVERSAL_CHECKPOINT_INFO = True
3535
except ImportError:
36-
DS_UNIVERSAL_CHECKPOINT_INFO = None
36+
DS_UNIVERSAL_CHECKPOINT_INFO = False
37+
3738

3839
def post_language_model_processing(lm_output, labels, logit_weights,
3940
parallel_output,
@@ -341,37 +342,37 @@ def _logits_helper(embedding, lm_output):
341342

342343
def universal_checkpoint_info(self):
343344
info = dict()
345+
if DS_UNIVERSAL_CHECKPOINT_INFO:
346+
# Vocabulary parameters (embeddings) that require special handling due to padding.
347+
info[VOCABULARY_PARAMETER_PATTERNS] = [
348+
r"tied_modules.embed.word_embeddings.weight"
349+
]
350+
351+
# Replicated (shared) parameters on the pipeline dimension
352+
info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = [
353+
r"tied_modules.embed.word_embeddings.weight",
354+
r"tied_modules.embed.position_embeddings.weight"
355+
]
356+
357+
# Parameter slices that should be averaged not concatenated.
358+
info[TP_REPLICATED_PARAMETER_PATTERNS] = [
359+
r"tied_modules.embed.word_embeddings.norm.weight",
360+
r"tied_modules.embed.word_embeddings.norm.bias",
361+
r"tied_modules.embed.position_embeddings.weight",
362+
r"\d+.input_layernorm.weight",
363+
r"\d+.input_layernorm.bias",
364+
r"\d+.post_attention_layernorm.weight",
365+
r"\d+.post_attention_layernorm.bias",
366+
r"\d+.self_attention.dense.bias",
367+
r"\d+.mlp.dense_4h_to_h.bias",
368+
r"\d+.weight",
369+
r"\d+.bias",
370+
]
344371

345-
# Vocabulary parameters (embeddings) that require special handling due to padding.
346-
info[VOCABULARY_PARAMETER_PATTERNS] = [
347-
r"tied_modules.embed.word_embeddings.weight"
348-
]
349-
350-
# Replicated (shared) parameters on the pipeline dimension
351-
info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = [
352-
r"tied_modules.embed.word_embeddings.weight",
353-
r"tied_modules.embed.position_embeddings.weight"
354-
]
355-
356-
# Parameter slices that should be averaged not concatenated.
357-
info[TP_REPLICATED_PARAMETER_PATTERNS] = [
358-
r"tied_modules.embed.word_embeddings.norm.weight",
359-
r"tied_modules.embed.word_embeddings.norm.bias",
360-
r"tied_modules.embed.position_embeddings.weight",
361-
r"\d+.input_layernorm.weight",
362-
r"\d+.input_layernorm.bias",
363-
r"\d+.post_attention_layernorm.weight",
364-
r"\d+.post_attention_layernorm.bias",
365-
r"\d+.self_attention.dense.bias",
366-
r"\d+.mlp.dense_4h_to_h.bias",
367-
r"\d+.weight",
368-
r"\d+.bias",
369-
]
370-
371-
# Parameter that are sliced on the row dimension
372-
info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = [
373-
r"\d+.mlp.dense_4h_to_h.weight",
374-
r"\d+.mlp.self_attention.dense.weight",
375-
]
372+
# Parameter that are sliced on the row dimension
373+
info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = [
374+
r"\d+.mlp.dense_4h_to_h.weight",
375+
r"\d+.mlp.self_attention.dense.weight",
376+
]
376377

377378
return info

0 commit comments

Comments (0)