Skip to content

Commit b3f4427

Browse files
Merge remote-tracking branch 'origin/feature/checkpoint' into feature/checkpoint
2 parents a397d11 + 63ff8c4 commit b3f4427

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

dlio_benchmark/configs/workload/megatron_deepspeed.yaml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1-
model: unet3d
1+
# 8 node run with 4 GPUs per node and TPSIZE=4 and PPSIZE=8
2+
model: megatron_deepspeed
2 3

3 4
framework: pytorch
4 5

dlio_benchmark/framework/tf_framework.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -58,11 +58,8 @@ def __init__(self, profiling):
58 58
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
59 59
rank_to_checkpoint = 0
60 60
if rank_to_checkpoint == self.args.my_rank:
61-
num_ranks = 1
62-
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
63-
num_ranks = self.args.comm_size
64 61
if self.args.model_size > 0:
65-
self.model_state = {"a": self._get_tensor(self.args.model_size*num_ranks)}
62+
self.model_state = {"a": self._get_tensor(self.args.model_size)}
66 63
self.optimization_state = None
67 64
if len(self.args.optimization_groups) > 0:
68 65
self.optimization_state = dict()

0 commit comments

Comments (0)