Skip to content

Commit 8673d46

Browse files
authored
deploy elastic error handler (#258)
1 parent 58d9204 commit 8673d46

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

pretrain_gpt.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,12 @@
3232
from deepspeed.runtime.utils import see_memory_usage
3333
import os
3434

35+
try:
    # Use the decorator from torch.distributed.elastic when it is importable.
    from torch.distributed.elastic.multiprocessing.errors import record
except ImportError:
    def record(fn, error_handler=None):
        """No-op stand-in for ``record`` when the torch elastic import fails.

        Args:
            fn: The callable being decorated.
            error_handler: Accepted only so this fallback can be called the
                same way as the torch decorator; it is ignored here.

        Returns:
            ``fn`` itself, unchanged and unwrapped.
        """
        return fn
3541

3642
def model_provider(pre_process=True, post_process=True):
3743
"""Build the model."""
@@ -234,7 +240,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
234240
print_rank_0("> finished creating GPT datasets ...")
235241
return train_ds, valid_ds, test_ds
236242

237-
238-
if __name__ == "__main__":
243+
@record
def main():
    """Entry point: run ``pretrain`` with this module's providers.

    Passes the dataset provider, the model provider and the forward step to
    ``pretrain``, together with ``args_defaults`` selecting the
    ``GPT2BPETokenizer`` tokenizer type. The function is decorated with
    ``record`` (imported from torch elastic when available, otherwise the
    module's no-op fallback).
    """
    default_args = {'tokenizer_type': 'GPT2BPETokenizer'}
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        args_defaults=default_args,
    )


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)