Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions dlio_benchmark/data_generator/generator_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

from dlio_benchmark.common.enumerations import FormatType
from dlio_benchmark.common.error_code import ErrorCodes

from dlio_benchmark.common.enumerations import StorageType


class GeneratorFactory(object):
def __init__(self):
pass

@staticmethod
def get_generator(type):
def get_generator(type, storage_type):
if type == FormatType.TFRECORD:
from dlio_benchmark.data_generator.tf_generator import TFRecordGenerator
return TFRecordGenerator()
Expand All @@ -36,8 +36,12 @@ def get_generator(type):
from dlio_benchmark.data_generator.csv_generator import CSVGenerator
return CSVGenerator()
elif type == FormatType.NPZ:
from dlio_benchmark.data_generator.npz_generator import NPZGenerator
return NPZGenerator()
if storage_type == StorageType.S3:
from dlio_benchmark.data_generator.npz_s3_generator import NPZS3Generator
return NPZS3Generator()
else:
from dlio_benchmark.data_generator.npz_generator import NPZGenerator
return NPZGenerator()
elif type == FormatType.NPY:
from dlio_benchmark.data_generator.npy_generator import NPYGenerator
return NPYGenerator()
Expand All @@ -54,4 +58,4 @@ def get_generator(type):
from dlio_benchmark.data_generator.indexed_binary_generator import IndexedBinaryGenerator
return IndexedBinaryGenerator()
else:
raise Exception(str(ErrorCodes.EC1001))
raise Exception(str(ErrorCodes.EC1001))
1 change: 1 addition & 0 deletions dlio_benchmark/data_generator/npz_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import logging
import numpy as np
import io

from dlio_benchmark.utils.utility import progress, utcnow
from dlio_benchmark.utils.utility import Profile
Expand Down
64 changes: 64 additions & 0 deletions dlio_benchmark/data_generator/npz_s3_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Copyright (c) 2025 Dell Inc, or its subsidiaries.
Copyright (c) 2024, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dlio_benchmark.common.enumerations import Compression
from dlio_benchmark.data_generator.data_generator import DataGenerator

import logging
import numpy as np
import io

from dlio_benchmark.utils.utility import progress, utcnow
from dlio_benchmark.utils.utility import Profile
from shutil import copyfile
from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR
from dlio_benchmark.storage.storage_factory import StorageFactory

dlp = Profile(MODULE_DATA_GENERATOR)

"""
Generator for creating data in NPZ format.
"""
class NPZS3Generator(DataGenerator):
    """
    Generator that builds NPZ archives in memory and hands them to the
    configured storage backend via ``put_data`` instead of writing them
    to a local path.
    """

    def __init__(self):
        super().__init__()
        # Storage backend is selected from the run arguments (storage type,
        # storage root and framework).
        self.storage = StorageFactory().get_storage(self._args.storage_type,
                                                    self._args.storage_root,
                                                    self._args.framework)

    @dlp.log
    def generate(self):
        """
        Generator for creating data in NPZ format of 3d dataset.

        Each rank handles the file indices ``my_rank, my_rank + comm_size, ...``.
        For every file a random uint8 array ``x`` and a zero-filled label list
        ``y`` are serialized into an in-memory buffer and passed to
        ``self.storage.put_data``.
        """
        super().generate()
        # Fixed seed so the generated records are reproducible.
        np.random.seed(10)
        try:
            record_labels = [0] * self.num_samples
            # `dim` is read as interleaved pairs: entries 2*i and 2*i+1 give
            # the first two dimensions of file i.
            dim = self.get_dimension(self.total_files_to_generate)
            for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)):
                dim1 = dim[2*i]
                dim2 = dim[2*i+1]
                records = np.random.randint(255, size=(dim1, dim2, self.num_samples), dtype=np.uint8)
                out_path_spec = self.storage.get_uri(self._file_list[i])
                progress(i+1, self.total_files_to_generate, "Generating NPZ Data")
                # Serialize into memory; the buffer is released as soon as
                # the storage backend has consumed it.
                with io.BytesIO() as buffer:
                    if self.compression != Compression.ZIP:
                        np.savez(buffer, x=records, y=record_labels)
                    else:
                        np.savez_compressed(buffer, x=records, y=record_labels)
                    self.storage.put_data(out_path_spec, buffer)
        finally:
            # Re-seed from OS entropy even if an upload failed, so the global
            # RNG is not left pinned to the fixed seed.
            np.random.seed()
2 changes: 1 addition & 1 deletion dlio_benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(self, cfg):
self.profiler = ProfilerFactory().get_profiler(self.args.profiler)

if self.args.generate_data:
self.data_generator = GeneratorFactory.get_generator(self.args.format)
self.data_generator = GeneratorFactory.get_generator(self.args.format, self.args.storage_type)
# Checkpointing support
self.do_checkpoint = self.args.do_checkpoint
self.steps_between_checkpoints = self.args.steps_between_checkpoints
Expand Down
2 changes: 1 addition & 1 deletion dlio_benchmark/reader/npz_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
"""
import numpy as np

import io
from dlio_benchmark.common.constants import MODULE_DATA_READER
from dlio_benchmark.reader.reader_handler import FormatReader
from dlio_benchmark.utils.utility import Profile

dlp = Profile(MODULE_DATA_READER)


class NPZReader(FormatReader):
"""
Reader for NPZ files
Expand Down
72 changes: 72 additions & 0 deletions dlio_benchmark/reader/npz_s3_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Copyright (c) 2024, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

import io
from dlio_benchmark.common.constants import MODULE_DATA_READER
from dlio_benchmark.reader.reader_handler import FormatReader
from dlio_benchmark.utils.utility import Profile
from dlio_benchmark.storage.storage_factory import StorageFactory

dlp = Profile(MODULE_DATA_READER)


class NPZS3Reader(FormatReader):
    """
    Reader for NPZ files using S3 protocol.

    File contents are fetched as bytes through the storage backend and
    decoded in memory with NumPy.
    """

    @dlp.log_init
    def __init__(self, dataset_type, thread_index, epoch):
        # `epoch` is accepted for signature parity with the reader factory
        # call but is not used here.
        super().__init__(dataset_type, thread_index)
        self.storage = StorageFactory().get_storage(self._args.storage_type,
                                                    self._args.storage_root,
                                                    self._args.framework)

    @dlp.log
    def open(self, filename):
        """
        Fetch `filename` from storage and return the array stored under
        the "x" key of the NPZ archive.
        """
        data = self.storage.get_data(filename)
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code if the bucket content is untrusted.
        # Kept as-is to preserve behavior; consider allow_pickle=False if the
        # archives only ever contain plain numeric arrays.
        with np.load(io.BytesIO(data), allow_pickle=True) as archive:
            # Indexing materializes the array, so it stays valid after the
            # archive is closed.
            return archive["x"]

    @dlp.log
    def close(self, filename):
        super().close(filename)

    @dlp.log
    def get_sample(self, filename, sample_index):
        """
        Select sample `sample_index` along the last axis of the opened array
        and report its size in bytes to the profiler.
        """
        super().get_sample(filename, sample_index)
        image = self.open_file_map[filename][..., sample_index]
        dlp.update(image_size=image.nbytes)

    def next(self):
        """Yield every batch produced by the base reader."""
        for batch in super().next():
            yield batch

    @dlp.log
    def read_index(self, image_idx, step):
        dlp.update(step=step)
        return super().read_index(image_idx, step)

    @dlp.log
    def finalize(self):
        return super().finalize()

    def is_index_based(self):
        return True

    def is_iterator_based(self):
        return True

7 changes: 5 additions & 2 deletions dlio_benchmark/reader/reader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from dlio_benchmark.common.enumerations import FormatType, DataLoaderType
from dlio_benchmark.common.error_code import ErrorCodes

from dlio_benchmark.common.enumerations import StorageType

class ReaderFactory(object):
def __init__(self):
Expand Down Expand Up @@ -61,6 +61,9 @@ def get_reader(type, dataset_type, thread_index, epoch_number):
elif type == FormatType.NPZ:
if _args.data_loader == DataLoaderType.NATIVE_DALI:
raise Exception("Loading data of %s format is not supported without framework data loader; please use npy format instead." %type)
elif _args.storage_type == StorageType.S3:
from dlio_benchmark.reader.npz_s3_reader import NPZS3Reader
return NPZS3Reader(dataset_type, thread_index, epoch_number)
else:
from dlio_benchmark.reader.npz_reader import NPZReader
return NPZReader(dataset_type, thread_index, epoch_number)
Expand All @@ -82,4 +85,4 @@ def get_reader(type, dataset_type, thread_index, epoch_number):
return SyntheticReader(dataset_type, thread_index, epoch_number)

else:
raise Exception("Loading data of %s format is not supported without framework data loader" %type)
raise Exception("Loading data of %s format is not supported without framework data loader" %type)
1 change: 0 additions & 1 deletion dlio_benchmark/reader/reader_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(self, dataset_type, thread_index):
f"Loading {self.__class__.__qualname__} reader on thread {self.thread_index} from rank {self._args.my_rank}")
self.dataset_type = dataset_type
self.open_file_map = {}

if FormatReader.read_images is None:
FormatReader.read_images = 0
self.step = 1
Expand Down
83 changes: 82 additions & 1 deletion dlio_benchmark/storage/s3_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,92 @@
from dlio_benchmark.storage.storage_handler import DataStorage, Namespace
from dlio_benchmark.common.enumerations import NamespaceType, MetadataType
import os
import boto3
from botocore.exceptions import ClientError

from dlio_benchmark.utils.utility import Profile

dlp = Profile(MODULE_STORAGE)

class S3PytorchStorage(DataStorage):
    """
    PyTorch Storage APIs for creating files.
    It uses a Boto3 S3 client to read and write data; the namespace name
    is used as the bucket name.
    """

    @dlp.log_init
    def __init__(self, namespace, framework=None):
        super().__init__(framework)
        self.namespace = Namespace(namespace, NamespaceType.FLAT)
        self.s3_client = boto3.client('s3')

    @dlp.log
    def get_uri(self, id):
        # Object ids are used as-is; no scheme or bucket is prepended.
        return id

    @dlp.log
    def create_namespace(self, exist_ok=False):
        # Assume the S3 bucket already exists; nothing is created here.
        return True

    @dlp.log
    def get_namespace(self):
        return self.get_node(self.namespace.name)

    @dlp.log
    def create_node(self, id, exist_ok=False):
        return super().create_node(self.get_uri(id), exist_ok)

    @dlp.log
    def get_node(self, id=""):
        return super().get_node(self.get_uri(id))

    @dlp.log
    def walk_node(self, id, use_pattern=False):
        # `use_pattern` is accepted for interface compatibility but ignored:
        # the listing is always a plain prefix listing.
        return self.list_objects(self.namespace.name, id)

    @dlp.log
    def put_data(self, id, data, offset=None, length=None):
        """
        Upload the full contents of `data` (an object exposing `getvalue()`,
        e.g. io.BytesIO) as object `id`. `offset` and `length` are ignored.
        """
        self.s3_client.put_object(Bucket=self.namespace.name, Key=id, Body=data.getvalue())
        return None

    @dlp.log
    def get_data(self, id, offset=None, length=None):
        """
        Return the bytes of object `id`.

        When `offset` is given, only the byte range starting at `offset` is
        requested: `length` bytes if `length` is given, otherwise through the
        end of the object. An offset of 0 is a valid range start.
        """
        # Same key normalization for ranged and full reads.
        obj_name = os.path.relpath(id)
        request = {'Bucket': self.namespace.name, 'Key': obj_name}
        if offset is not None:
            # HTTP ranges are inclusive; an empty end means "to end of object".
            last_byte = "" if length is None else offset + length - 1
            request['Range'] = f"bytes={offset}-{last_byte}"
        return self.s3_client.get_object(**request)['Body'].read()

    @dlp.log
    def list_objects(self, bucket_name, prefix=None):
        """
        List object keys in `bucket_name`, optionally restricted to `prefix`.

        Returned names have the prefix and its separator stripped. All result
        pages are consumed. A missing bucket is reported on stdout and yields
        an empty list.
        """
        params = {'Bucket': bucket_name}
        skip = 0
        if prefix:
            params['Prefix'] = prefix
            # Keys look like "<prefix>/<name>": drop the prefix plus the
            # separator, unless the prefix already ends with the separator.
            skip = len(prefix) if prefix.endswith('/') else len(prefix) + 1
        paths = []
        try:
            # A single list_objects_v2 call returns at most one page of keys;
            # the paginator follows continuation tokens for us.
            paginator = self.s3_client.get_paginator('list_objects_v2')
            for page in paginator.paginate(**params):
                for entry in page.get('Contents', []):
                    paths.append(entry['Key'][skip:])
        except self.s3_client.exceptions.NoSuchBucket:
            print(f"Bucket '{bucket_name}' does not exist.")

        return paths

    @dlp.log
    def delete_node(self, id):
        return super().delete_node(self.get_uri(id))

    def get_basename(self, id):
        return os.path.basename(id)

class S3Storage(DataStorage):
"""
Expand Down Expand Up @@ -73,4 +154,4 @@ def get_data(self, id, data, offset=None, length=None):
return super().get_data(self.get_uri(id), data, offset, length)

def get_basename(self, id):
return os.path.basename(id)
return os.path.basename(id)
7 changes: 6 additions & 1 deletion dlio_benchmark/storage/storage_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from dlio_benchmark.storage.s3_storage import S3Storage
from dlio_benchmark.common.enumerations import StorageType
from dlio_benchmark.common.error_code import ErrorCodes
from dlio_benchmark.common.enumerations import FrameworkType
from dlio_benchmark.storage.s3_storage import S3PytorchStorage

class StorageFactory(object):
def __init__(self):
Expand All @@ -28,6 +30,9 @@ def get_storage(storage_type, namespace, framework=None):
if storage_type == StorageType.LOCAL_FS:
return FileStorage(namespace, framework)
elif storage_type == StorageType.S3:
return S3Storage(namespace, framework)
if framework == FrameworkType.PYTORCH:
return S3PytorchStorage(namespace, framework)
else:
return S3Storage(namespace, framework)
else:
raise Exception(str(ErrorCodes.EC1001))
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ tensorflow>=2.11.0
torch>=2.2.0
torchaudio
torchvision
boto3
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"pandas>=1.5.1",
"psutil>=5.9.8",
"pydftracer==1.0.8",
"boto3",
]
x86_deps = [
f"hydra-core>={HYDRA_VERSION}",
Expand Down
Loading