Skip to content
14 changes: 13 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,16 @@ jobs:
run: |
source ${VENV_PATH}/bin/activate
rm -rf output data checkpoints
mpirun -np 2 ${DLIO_EXEC} workload=llama_8b_zero3 ++workload.model.parallelism.data=1024 ++workload.checkpoint.mode=subset
mpirun -np 2 ${DLIO_EXEC} workload=llama_8b_zero3 ++workload.model.parallelism.data=1024 ++workload.checkpoint.mode=subset
- name: test_model_comms
run: |
source ${VENV_PATH}/bin/activate
rm -rf output data checkpoints
mpirun -np 2 pytest -k test_resnet_model_with_comms_enabled -v
rm -rf data
- name: test_model_compute
run: |
source ${VENV_PATH}/bin/activate
rm -rf output data checkpoints
mpirun -np 2 pytest -k test_resnet_model_with_compute_enabled -v
rm -rf data
1 change: 1 addition & 0 deletions dlio_benchmark/checkpointing/base_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class BaseCheckpointing(ABC):

def __init__(self, ext):
#TODO(Huihuo): Add support for checkpointing rng states for transformer type of architecture
#TODO: Consider actual model instances - Model.SLEEP is default
self.ext = ext
self.args = ConfigArguments.get_instance()
self.checkpoint_storage = StorageFactory().get_storage(self.args.storage_type, self.args.checkpoint_folder,
Expand Down
21 changes: 21 additions & 0 deletions dlio_benchmark/common/enumerations.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,27 @@ class FrameworkType(Enum):

def __str__(self):
return self.value

class Model(Enum):
    """Enumeration of the model architectures the benchmark can emulate.

    The string value of each member is the workload name used in configs;
    ``SLEEP`` stands for the default sleep-based compute emulation.
    """
    RESNET = 'resnet50'
    UNET = 'unet3d'
    BERT = 'bert'
    SLEEP = 'sleep'
    DEFAULT = 'default'

    def __str__(self):
        # Render as the raw config string (e.g. 'resnet50'), not 'Model.RESNET'.
        return self.value

class Loss(Enum):
    """Enumeration of loss functions available to emulated models.

    ``NONE`` indicates that no loss computation is performed.
    """
    MSE = 'mse'
    CE = 'cross_entropy'
    NONE = 'none'

    def __str__(self):
        # Consistent with the sibling Model/FrameworkType enums: render the
        # raw config string instead of the default 'Loss.MSE' repr.
        return self.value

class ComputationType(Enum):
"""
Expand Down
4 changes: 2 additions & 2 deletions dlio_benchmark/configs/workload/resnet50_a100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ workflow:
train: True

dataset:
num_files_train: 1024
num_samples_per_file: 1251
num_files_train: 2
num_samples_per_file: 800
record_length_bytes: 114660.07
record_length_bytes_resize: 150528
data_folder: data/resnet50
Expand Down
7 changes: 4 additions & 3 deletions dlio_benchmark/framework/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ def stop_framework_profiler(self):
@abstractmethod
def trace_object(self, string, step, r):
pass

def model(epoch, batch, computation_time):
sleep(computation_time)

@abstractmethod
def model(self, epoch, batch, computation_time):
pass

@abstractmethod
def compute(self, batch, epoch_number, step, computation_time):
Expand Down
8 changes: 4 additions & 4 deletions dlio_benchmark/framework/framework_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
limitations under the License.
"""

from dlio_benchmark.common.enumerations import FrameworkType
from dlio_benchmark.common.enumerations import FrameworkType, Model
from dlio_benchmark.common.error_code import ErrorCodes


Expand All @@ -24,12 +24,12 @@ def __init__(self):
pass

@staticmethod
def get_framework(framework_type, profiling):
def get_framework(framework_type, profiling, model: Model = Model.SLEEP, communication: bool = False):
if framework_type == FrameworkType.TENSORFLOW:
from dlio_benchmark.framework.tf_framework import TFFramework
return TFFramework.get_instance(profiling)
return TFFramework.get_instance(profiling, model, communication)
elif framework_type == FrameworkType.PYTORCH:
from dlio_benchmark.framework.torch_framework import TorchFramework
return TorchFramework.get_instance(profiling)
return TorchFramework.get_instance(profiling, model, communication)
else:
raise Exception(str(ErrorCodes.EC1001))
25 changes: 20 additions & 5 deletions dlio_benchmark/framework/tf_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
from time import time, sleep
from dlio_benchmark.common.constants import MODULE_AI_FRAMEWORK
from dlio_benchmark.data_loader.data_loader_factory import DataLoaderFactory
from dlio_benchmark.model.model_factory import ModelFactory
from dlio_benchmark.utils.utility import utcnow, DLIOMPI
from dlio_benchmark.utils.utility import Profile, sleep
from dlio_benchmark.common.error_code import ErrorCodes
from dlio_benchmark.framework.framework import Framework
from dlio_benchmark.reader.reader_factory import ReaderFactory
from dlio_benchmark.profiler.profiler_factory import ProfilerFactory
from dlio_benchmark.storage.storage_factory import StorageFactory
from dlio_benchmark.common.enumerations import FrameworkType, Profiler, FormatType, DatasetType, MetadataType, \
from dlio_benchmark.common.enumerations import FrameworkType, Model, Profiler, FormatType, DatasetType, MetadataType, \
DataLoaderType

import tensorflow as tf
Expand All @@ -43,15 +44,19 @@ class TFFramework(Framework):
__instance = None

@dlp.log_init
def __init__(self, profiling):
def __init__(self, profiling, model: Model = Model.SLEEP, communication: bool = False):
super().__init__()
self.profiling = profiling
self._model = ModelFactory.create_model(FrameworkType.TENSORFLOW, model, communication, gpu_id=DLIOMPI.get_instance().local_rank())
# TODO: Temporary fix, need to separate the iostat profiler (needed for report gen) and the others
if profiling:
if self.args.profiler != Profiler.IOSTAT:
self.tensorboard = ProfilerFactory.get_profiler(Profiler.NONE)
else:
self.tensorboard = ProfilerFactory.get_profiler(Profiler.TENSORBOARD)


# self.model = DDP(model)
self.reader_handler = None

@dlp.log
Expand All @@ -64,10 +69,10 @@ def get_type(self):
return FrameworkType.TENSORFLOW

@staticmethod
def get_instance(profiling):
def get_instance(profiling, model: Model = Model.SLEEP, communication: bool = False):
""" Static access method. """
if TFFramework.__instance is None:
TFFramework.__instance = TFFramework(profiling)
TFFramework.__instance = TFFramework(profiling, model, communication)
return TFFramework.__instance

@dlp.log
Expand All @@ -87,9 +92,18 @@ def trace_object(self, string, step, r):

@dlp.log
def compute(self, batch, epoch_number, step, computation_time):
return self.model(batch, computation_time)
return self.model(epoch_number, batch, computation_time)
# tf.function(self.model)(epoch_number, step, computation_time)


def model(self, epoch, batch, computation_time):
    """Perform one step of (emulated) model computation.

    :param epoch: current epoch number (accepted for interface parity
        with the abstract ``Framework.model``; unused here)
    :param batch: the input batch handed to the concrete model, if any
    :param computation_time: seconds to sleep when emulating compute
    """
    if self._model is None:
        # No concrete model instance configured: emulate the compute
        # cost by sleeping for the configured duration.
        sleep(computation_time)
    else:
        # NOTE(review): computation_time is ignored when a real model is
        # present — the model's own compute dominates; confirm intended.
        self._model.compute(batch)


@dlp.log
def get_loader(self, dataset_type=DatasetType.TRAIN):
if dataset_type == DatasetType.TRAIN:
Expand Down Expand Up @@ -145,3 +159,4 @@ def get_data(self, id, data, offset=None, length=None):
@dlp.log
def isfile(self, id):
return tf.io.gfile.exists(id) and not tf.io.gfile.isdir(id)

57 changes: 37 additions & 20 deletions dlio_benchmark/framework/torch_framework.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,36 @@
"""
Copyright (c) 2025, UChicago Argonne, LLC
All Rights Reserved
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Copyright (c) 2025, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dlio_benchmark.common.error_code import ErrorCodes
from dlio_benchmark.common.enumerations import FormatType, FrameworkType, DatasetType, DataLoaderType
from dlio_benchmark.common.enumerations import (
FormatType,
FrameworkType,
DatasetType,
DataLoaderType,
Model,
)
from dlio_benchmark.data_loader.data_loader_factory import DataLoaderFactory
from dlio_benchmark.framework.framework import Framework, DummyTraceObject
from dlio_benchmark.common.constants import MODULE_AI_FRAMEWORK
import os
import torch
import functools
import logging
from dlio_benchmark.model.model_factory import ModelFactory
from dlio_benchmark.utils.utility import utcnow, DLIOMPI
from dlio_benchmark.utils.utility import Profile

Expand Down Expand Up @@ -58,10 +65,12 @@ class TorchFramework(Framework):
__instance = None

@dlp.log_init
def __init__(self, profiling):
def __init__(self, profiling, model: Model = Model.SLEEP, communication: bool = False):
super().__init__()
self.profiling = profiling
self.reader_handler = None
# TODO: Check if we need to add config for gpu, use local_rank or 0 always maybe
self._model = ModelFactory.create_model(FrameworkType.PYTORCH, model, communication, gpu_id=DLIOMPI.get_instance().local_rank())

@dlp.log
def init_loader(self, format_type, epoch=0, data_loader=None):
Expand All @@ -74,10 +83,10 @@ def get_type(self):
return FrameworkType.PYTORCH

@staticmethod
def get_instance(profiling):
""" Static access method. """
def get_instance(profiling, model: Model = Model.SLEEP, communication: bool = False):
"""Static access method."""
if TorchFramework.__instance is None:
TorchFramework.__instance = TorchFramework(profiling)
TorchFramework.__instance = TorchFramework(profiling, model, communication)
return TorchFramework.__instance

@dlp.log
Expand All @@ -94,7 +103,15 @@ def trace_object(self, string, step, r):

@dlp.log
def compute(self, batch, epoch_number, step, computation_time):
return self.model(batch, computation_time)
return self.model(epoch_number, batch, computation_time)

def model(self, epoch, batch, computation_time):
    """Perform one step of (emulated) model computation.

    Debug ``print`` calls removed: they spam stdout on every step under
    mpirun and the TensorFlow sibling implementation has none.

    :param epoch: current epoch number (accepted for interface parity
        with the abstract ``Framework.model``; unused here)
    :param batch: the input batch handed to the concrete model, if any
    :param computation_time: seconds to sleep when emulating compute
    """
    if self._model is None:
        # No concrete model instance configured: emulate the compute
        # cost by sleeping for the configured duration.
        sleep(computation_time)
    else:
        # NOTE(review): computation_time is ignored when a real model is
        # present — confirm that is intended.
        self._model.compute(batch)

@dlp.log
def get_loader(self, dataset_type=DatasetType.TRAIN):
Expand Down
20 changes: 13 additions & 7 deletions dlio_benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from dlio_benchmark.utils.statscounter import StatsCounter
from hydra.core.config_store import ConfigStore
from dlio_benchmark.utils.config import LoadConfig, ConfigArguments, GetConfig
from dlio_benchmark.common.enumerations import Profiler, DatasetType, StorageType, MetadataType, FormatType
from dlio_benchmark.common.enumerations import Model, Profiler, DatasetType, StorageType, MetadataType, FormatType
from dlio_benchmark.profiler.profiler_factory import ProfilerFactory
from dlio_benchmark.framework.framework_factory import FrameworkFactory
from dlio_benchmark.data_generator.generator_factory import GeneratorFactory
Expand Down Expand Up @@ -73,10 +73,8 @@ def __init__(self, cfg):
global dftracer, dftracer_initialize, dftracer_finalize

t0 = time()
self.args = ConfigArguments.get_instance()
self.args : ConfigArguments = ConfigArguments.get_instance() # type: ignore
LoadConfig(self.args, cfg)
self.storage = StorageFactory().get_storage(self.args.storage_type, self.args.storage_root,
self.args.framework)

self.output_folder = self.args.output_folder
os.makedirs(self.args.output_folder, mode=0o755, exist_ok=True)
Expand All @@ -85,10 +83,18 @@ def __init__(self, cfg):
self.comm_size = self.args.comm_size = DLIOMPI.get_instance().size()
self.data_folder = self.args.data_folder
self.storage_root = self.args.storage_root
try:
model_enum = Model(self.args.model)
except:
model_enum = Model.DEFAULT
if not self.args.compute:
model_enum = Model.DEFAULT
self.framework = FrameworkFactory().get_framework(self.args.framework,
self.args.do_profiling, model_enum, self.args.communication)
self.storage = StorageFactory().get_storage(self.args.storage_type, self.args.storage_root,
self.args.framework)
if self.args.storage_root:
self.storage.create_namespace(exist_ok=True)
self.framework = FrameworkFactory().get_framework(self.args.framework,
self.args.do_profiling)

# Delete previous logfile
if self.my_rank == 0:
Expand Down Expand Up @@ -126,10 +132,10 @@ def __init__(self, cfg):
self.num_subfolders_eval = self.args.num_subfolders_eval
self.num_samples = self.args.num_samples_per_file
self.total_training_steps = self.args.total_training_steps
self.computation_time = self.args.computation_time

self.epochs = self.args.epochs
self.batch_size = self.args.batch_size
self.computation_time = self.args.computation_time

if self.do_profiling:
self.profiler = ProfilerFactory().get_profiler(self.args.profiler)
Expand Down
2 changes: 2 additions & 0 deletions dlio_benchmark/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from dlio_benchmark.model.model_factory import ModelFactory
from dlio_benchmark.model.model import UnifiedModel
Loading