Skip to content

Commit ddc92ff

Browse files
refactor enum for better naming
1 parent fc42c86 commit ddc92ff

File tree

6 files changed

+22
-21
lines changed

6 files changed

+22
-21
lines changed

dlio_benchmark/common/enumerations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
from enum import Enum
1919

20-
class CheckpointType(Enum):
20+
class CheckpointLocationType(Enum):
2121
"""
2222
Different types of underlying storage
2323
"""
24-
COLLECTIVE = 'collective'
25-
INDEPENDENT = 'independent'
24+
RANK_ZERO = 'rank_zero'
25+
ALL_RANKS = 'all_ranks'
2626

2727
def __str__(self):
2828
return self.value

dlio_benchmark/configs/workload/megatron_deepspeed.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ checkpoint:
2929
checkpoint_folder: checkpoints/megatron-deepspeed
3030
steps_between_checkpoints: 1000
3131
model_size: 30102
32-
type: independent
32+
type: all_ranks
3333
optimization_groups: [1009254400, 865075200, 793600]
3434
num_layers: 44
3535
layer_parameters: [129761280, 20971520]

dlio_benchmark/framework/tf_framework.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from dlio_benchmark.profiler.profiler_factory import ProfilerFactory
3030
from dlio_benchmark.storage.storage_factory import StorageFactory
3131
from dlio_benchmark.common.enumerations import FrameworkType, Profiler, FormatType, DatasetType, MetadataType, \
32-
DataLoaderType, CheckpointType
32+
DataLoaderType, CheckpointLocationType
3333

3434
import tensorflow as tf
3535
from tensorflow.python.framework import errors
@@ -55,11 +55,11 @@ def __init__(self, profiling):
5555
self.reader_handler = None
5656
self.model_state = None
5757
rank_to_checkpoint = self.args.my_rank
58-
if self.args.checkpoint_type == CheckpointType.COLLECTIVE:
58+
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
5959
rank_to_checkpoint = 0
6060
if rank_to_checkpoint == self.args.my_rank:
6161
num_ranks = 1
62-
if self.args.checkpoint_type == CheckpointType.COLLECTIVE:
62+
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
6363
num_ranks = self.args.comm_size
6464
if self.args.model_size > 0:
6565
self.model_state = {"a": self._get_tensor(self.args.model_size*num_ranks)}
@@ -121,7 +121,7 @@ def checkpoint(self, epoch, step_number):
121121
"""
122122
my_rank = DLIOMPI.get_instance().rank()
123123
rank_to_checkpoint = my_rank
124-
if self.args.checkpoint_type == CheckpointType.COLLECTIVE:
124+
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
125125
rank_to_checkpoint = 0
126126
if rank_to_checkpoint == my_rank:
127127
if self.model_state:

dlio_benchmark/framework/torch_framework.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717

1818
from dlio_benchmark.common.error_code import ErrorCodes
19-
from dlio_benchmark.common.enumerations import FormatType, FrameworkType, DatasetType, DataLoaderType, CheckpointType
19+
from dlio_benchmark.common.enumerations import FormatType, FrameworkType, DatasetType, DataLoaderType, CheckpointLocationType
2020
from dlio_benchmark.data_loader.data_loader_factory import DataLoaderFactory
2121
from dlio_benchmark.framework.framework import Framework, DummyTraceObject
2222
from dlio_benchmark.common.constants import MODULE_AI_FRAMEWORK
@@ -62,7 +62,7 @@ def __init__(self, profiling):
6262
self.profiling = profiling
6363
self.reader_handler = None
6464
rank_to_checkpoint = self.args.my_rank
65-
if self.args.checkpoint_type == CheckpointType.COLLECTIVE:
65+
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
6666
rank_to_checkpoint = 0
6767
if rank_to_checkpoint == self.args.my_rank:
6868
self.model_state = None
@@ -120,7 +120,7 @@ def trace_object(self, string, step, r):
120120
def checkpoint(self, epoch, step_number):
121121

122122
rank_to_checkpoint = DLIOMPI.get_instance().rank()
123-
if self.args.checkpoint_type == CheckpointType.COLLECTIVE:
123+
if self.args.checkpoint_type == CheckpointLocationType.RANK_ZERO:
124124
rank_to_checkpoint = 0
125125
if rank_to_checkpoint == DLIOMPI.get_instance().rank():
126126
my_rank = DLIOMPI.get_instance().rank()

dlio_benchmark/utils/config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626

2727
from dlio_benchmark.common.constants import MODULE_CONFIG
2828
from dlio_benchmark.common.enumerations import StorageType, FormatType, Shuffle, ReadType, FileAccess, Compression, \
29-
FrameworkType, DataLoaderType, Profiler, DatasetType, DataLoaderSampler, CheckpointType
29+
FrameworkType, \
30+
DataLoaderType, Profiler, DatasetType, DataLoaderSampler, CheckpointLocationType
3031
from dlio_benchmark.utils.utility import DLIOMPI
3132
from dataclasses import dataclass
3233
import math
@@ -99,7 +100,7 @@ class ConfigArguments:
99100
eval_time_stdev: float = 0.0
100101
eval_after_epoch: int = 1
101102
epochs_between_evals: int = 1
102-
checkpoint_type: CheckpointType = CheckpointType.COLLECTIVE
103+
checkpoint_type: CheckpointLocationType = CheckpointLocationType.RANK_ZERO
103104
model_size: int = 10240
104105
optimization_groups: ClassVar[List[int]] = []
105106
num_layers: int = 1
@@ -453,7 +454,7 @@ def LoadConfig(args, config):
453454
if 'steps_between_checkpoints' in config['checkpoint']:
454455
args.steps_between_checkpoints = config['checkpoint']['steps_between_checkpoints']
455456
if 'type' in config['checkpoint']:
456-
args.checkpoint_type = CheckpointType(config['checkpoint']['type'])
457+
args.checkpoint_type = CheckpointLocationType(config['checkpoint']['type'])
457458
if 'model_size' in config['checkpoint']:
458459
args.model_size = config['checkpoint']['model_size']
459460
if 'optimization_groups' in config['checkpoint']:

tests/dlio_benchmark_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,12 @@ def test_iostat_profiling() -> None:
209209
clean()
210210

211211
@pytest.mark.timeout(60, method="thread")
212-
@pytest.mark.parametrize("framework, model_size, optimizers, num_layers, layer_params, type", [("tensorflow", 1024, [1024, 128], 2, [16], "independent"),
213-
("pytorch", 1024, [1024, 128], 2, [16], "independent"),
214-
("tensorflow", 1024, [1024, 128], 2, [16], "collective"),
215-
("pytorch", 1024, [1024, 128], 2, [16], "collective"),
216-
("tensorflow", 1024, [128], 1, [], "independent"),
217-
("pytorch", 1024, [128], 1, [], "independent")])
212+
@pytest.mark.parametrize("framework, model_size, optimizers, num_layers, layer_params, type", [("tensorflow", 1024, [1024, 128], 2, [16], "all_ranks"),
213+
("pytorch", 1024, [1024, 128], 2, [16], "all_ranks"),
214+
("tensorflow", 1024, [1024, 128], 2, [16], "rank_zero"),
215+
("pytorch", 1024, [1024, 128], 2, [16], "rank_zero"),
216+
("tensorflow", 1024, [128], 1, [], "all_ranks"),
217+
("pytorch", 1024, [128], 1, [], "all_ranks")])
218218
def test_checkpoint_epoch(framework, model_size, optimizers, num_layers, layer_params, type) -> None:
219219
clean()
220220
if comm.rank == 0:
@@ -249,7 +249,7 @@ def test_checkpoint_epoch(framework, model_size, optimizers, num_layers, layer_p
249249
if len(layer_params) > 0:
250250
n = num_layers
251251
nranks = 1
252-
if type == "independent":
252+
if type == "all_ranks":
253253
nranks = comm.size
254254
if framework == "tensorflow":
255255
num_check_files = 8 / 2 * (2 + 2 + 2*n) * nranks + 1

0 commit comments

Comments (0)