Skip to content

Commit 3727e5a

Browse files
Merge remote-tracking branch 'origin/feature/checkpoint' into feature/checkpoint
2 parents a397d11 + 63ff8c4 commit 3727e5a

File tree

3 files changed

+5
-7
lines changed

3 files changed

+5
-7
lines changed

dlio_benchmark/configs/workload/megatron_deepspeed.yaml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1-
model: unet3d
1+
# 8 node run with 4 GPUs per node and TPSIZE=4 and PPSIZE=8
2+
model: megatron_deepspeed
2 3

3 4
framework: pytorch
4 5

dlio_benchmark/data_generator/hdf5_generator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def generate(self):
44 44
"""
45 45
super().generate()
46 46
np.random.seed(10)
47-
samples_per_iter=max(1, int(32*1024*1024/self._args.record_length))
47+
samples_per_iter=max(1, int(self._args.generation_buffer_size/self._args.record_length))
48 48
record_labels = [0] * self.num_samples
49 49
for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)):
50 50
progress(i, self.total_files_to_generate, "Generating HDF5 Data")

dlio_benchmark/framework/tf_framework.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -58,11 +58,8 @@ def __init__(self, profiling):
58 58
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
59 59
rank_to_checkpoint = 0
60 60
if rank_to_checkpoint == self.args.my_rank:
61-
num_ranks = 1
62-
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
63-
num_ranks = self.args.comm_size
64 61
if self.args.model_size > 0:
65-
self.model_state = {"a": self._get_tensor(self.args.model_size*num_ranks)}
62+
self.model_state = {"a": self._get_tensor(self.args.model_size)}
66 63
self.optimization_state = None
67 64
if len(self.args.optimization_groups) > 0:
68 65
self.optimization_state = dict()
@@ -78,7 +75,7 @@ def __init__(self, profiling):
78 75
self.layer_state = dict()
79 76
for index, state in enumerate(self.args.layer_parameters):
80 77
if state > 0:
81-
self.layer_state[str(index)] = self._get_tensor(state*num_ranks)
78+
self.layer_state[str(index)] = self._get_tensor(state)
82 79

83 80
def _get_tensor(self, size):
84 81
return tf.random.uniform((int(size / 4),), maxval=100, dtype=tf.dtypes.int32)

0 commit comments

Comments
 (0)