Skip to content

Commit 03796ad

Browse files
fixed checkpointing for tensors
1 parent 6221b33 commit 03796ad

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

dlio_benchmark/configs/workload/megatron_deepspeed.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dataset:
1717
reader:
1818
data_loader: pytorch
1919
batch_size: 1024
20-
read_threads: 8
20+
read_threads: 1
2121
file_shuffle: seed
2222
sample_shuffle: seed
2323

@@ -27,7 +27,7 @@ train:
2727

2828
checkpoint:
2929
checkpoint_folder: checkpoints/megatron-deepspeed
30-
checkpoint_after_epoch: 1000
30+
steps_between_checkpoints: 1000
3131
model_size: 30102
3232
type: independent
3333
optimization_groups: [1009254400, 865075200, 793600]

dlio_benchmark/framework/tf_framework.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,13 @@ def __init__(self, profiling):
6666
self.optimization_state = None
6767
if len(self.args.optimization_groups) > 0:
6868
self.optimization_state = dict()
69-
tensor_array = []
69+
tensor_array_size = 0
7070
for index, state in enumerate(self.args.optimization_groups):
7171
if state > 0:
72-
self.optimization_state[str(index)] = {'a': self._get_tensor(state*num_ranks),
73-
'b': self._get_tensor(state*num_ranks)}
74-
tensor_array.append(self._get_tensor(state*num_ranks))
75-
self.optimization_state["combined"] = tensor_array
72+
self.optimization_state[str(index)] = {'a': self._get_tensor(state),
73+
'b': self._get_tensor(state)}
74+
tensor_array_size += state
75+
self.optimization_state["combined"] = self._get_tensor(tensor_array_size)
7676
self.layer_state = None
7777
if len(self.args.layer_parameters) > 0:
7878
self.layer_state = dict()

dlio_benchmark/framework/torch_framework.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def __init__(self, profiling):
7171
self.optimization_state = None
7272
if len(self.args.optimization_groups) > 0:
7373
self.optimization_state = dict()
74-
tensor_array = []
74+
tensor_array_size = 0
7575
for index, state in enumerate(self.args.optimization_groups):
7676
if state > 0:
7777
self.optimization_state[str(index)] = {'a': self._get_tensor(state), 'b': self._get_tensor(state)}
78-
tensor_array.append(self._get_tensor(state))
79-
self.optimization_state["combined"] = tensor_array
78+
tensor_array_size += state
79+
self.optimization_state["combined"] = self._get_tensor(tensor_array_size)
8080
self.layer_state = None
8181
if len(self.args.layer_parameters) > 0:
8282
self.layer_state = dict()

0 commit comments

Comments
 (0)