-
Notifications
You must be signed in to change notification settings - Fork 58
S3 support with Boto3 for PyTorch NPZ data generator and dataloader #264
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
7c88b48
dc768a4
7842e1d
d2bf507
411dbce
b5ee362
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| """ | ||
| Copyright (c) 2025 Dell Inc, or its subsidiaries. | ||
| Copyright (c) 2024, UChicago Argonne, LLC | ||
| All Rights Reserved | ||
|
|
||
|
|
@@ -20,11 +21,13 @@ | |
|
|
||
| import logging | ||
| import numpy as np | ||
| import io | ||
|
|
||
| from dlio_benchmark.utils.utility import progress, utcnow | ||
| from dlio_benchmark.utils.utility import Profile | ||
| from shutil import copyfile | ||
| from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR | ||
| from dlio_benchmark.common.enumerations import StorageType | ||
|
|
||
| dlp = Profile(MODULE_DATA_GENERATOR) | ||
|
|
||
|
|
@@ -51,8 +54,19 @@ def generate(self): | |
| out_path_spec = self.storage.get_uri(self._file_list[i]) | ||
| progress(i+1, self.total_files_to_generate, "Generating NPZ Data") | ||
| prev_out_spec = out_path_spec | ||
| if self.compression != Compression.ZIP: | ||
| np.savez(out_path_spec, x=records, y=record_labels) | ||
|
|
||
|
||
| if self._args.storage_type == StorageType.S3: | ||
| buffer = io.BytesIO() | ||
| if self.compression != Compression.ZIP: | ||
| np.savez(buffer, x=records, y=record_labels) | ||
| else: | ||
| np.savez_compressed(buffer, x=records, y=record_labels) | ||
| self.storage.put_data(out_path_spec, buffer) | ||
|
|
||
|
|
||
| else: | ||
| np.savez_compressed(out_path_spec, x=records, y=record_labels) | ||
| if self.compression != Compression.ZIP: | ||
| np.savez(out_path_spec, x=records, y=record_labels) | ||
| else: | ||
| np.savez_compressed(out_path_spec, x=records, y=record_labels) | ||
| np.random.seed() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| """ | ||
| Copyright (c) 2025 Dell Inc, or its subsidiaries. | ||
|
||
| Copyright (c) 2024, UChicago Argonne, LLC | ||
| All Rights Reserved | ||
|
|
||
|
|
@@ -16,9 +17,11 @@ | |
| """ | ||
| import numpy as np | ||
|
|
||
| import io | ||
| from dlio_benchmark.common.constants import MODULE_DATA_READER | ||
| from dlio_benchmark.reader.reader_handler import FormatReader | ||
| from dlio_benchmark.utils.utility import Profile | ||
| from dlio_benchmark.storage.s3_storage import S3PytorchStorage | ||
|
|
||
| dlp = Profile(MODULE_DATA_READER) | ||
|
|
||
|
|
@@ -34,6 +37,12 @@ def __init__(self, dataset_type, thread_index, epoch): | |
|
|
||
| @dlp.log | ||
| def open(self, filename): | ||
| if isinstance(self.storage, S3PytorchStorage): | ||
|
||
| print(filename) | ||
| data = self.storage.get_data(filename) | ||
| image = io.BytesIO(data) | ||
| return np.load(image, allow_pickle=True)["x"] | ||
|
|
||
| super().open(filename) | ||
| return np.load(filename, allow_pickle=True)['x'] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| """ | ||
| Copyright (c) 2025 Dell Inc, or its subsidiaries. | ||
| Copyright (c) 2024, UChicago Argonne, LLC | ||
| All Rights Reserved | ||
|
|
||
|
|
@@ -44,7 +45,8 @@ def __init__(self, dataset_type, thread_index): | |
| f"Loading {self.__class__.__qualname__} reader on thread {self.thread_index} from rank {self._args.my_rank}") | ||
| self.dataset_type = dataset_type | ||
| self.open_file_map = {} | ||
|
|
||
| self.storage = StorageFactory().get_storage(self._args.storage_type, self._args.storage_root, | ||
|
||
| self._args.framework) | ||
| if FormatReader.read_images is None: | ||
| FormatReader.read_images = 0 | ||
| self.step = 1 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| """ | ||
| Copyright (c) 2025 Dell Inc, or its subsidiaries. | ||
|
||
| Copyright (c) 2024, UChicago Argonne, LLC | ||
| All Rights Reserved | ||
|
|
||
|
|
@@ -20,11 +21,92 @@ | |
| from dlio_benchmark.storage.storage_handler import DataStorage, Namespace | ||
| from dlio_benchmark.common.enumerations import NamespaceType, MetadataType | ||
| import os | ||
| import boto3 | ||
| from botocore.exceptions import ClientError | ||
|
|
||
| from dlio_benchmark.utils.utility import Profile | ||
|
|
||
| dlp = Profile(MODULE_STORAGE) | ||
|
|
||
| class S3PytorchStorage(DataStorage): | ||
| """ | ||
| PyTorch Storage APIs for creating files. | ||
| It uses Boto3 client to read and write data | ||
| """ | ||
|
|
||
| @dlp.log_init | ||
| def __init__(self, namespace, framework=None): | ||
| super().__init__(framework) | ||
| self.namespace = Namespace(namespace, NamespaceType.FLAT) | ||
| self.s3_client = boto3.client('s3') | ||
|
|
||
|
|
||
| @dlp.log | ||
| def get_uri(self, id): | ||
| return id | ||
|
|
||
| @dlp.log | ||
| def create_namespace(self, exist_ok=False): | ||
| # Assume the S3 bucket exists | ||
| return True | ||
|
|
||
| @dlp.log | ||
| def get_namespace(self): | ||
| return self.get_node(self.namespace.name) | ||
|
|
||
| @dlp.log | ||
| def create_node(self, id, exist_ok=False): | ||
| return super().create_node(self.get_uri(id), exist_ok) | ||
|
|
||
| @dlp.log | ||
| def get_node(self, id=""): | ||
| return super().get_node(self.get_uri(id)) | ||
|
|
||
| @dlp.log | ||
| def walk_node(self, id, use_pattern=False): | ||
| return self.list_objects(self.namespace.name, id) | ||
|
|
||
| @dlp.log | ||
| def put_data(self, id, data, offset=None, length=None): | ||
| self.s3_client.put_object(Bucket=self.namespace.name, Key=id, Body=data.getvalue()) | ||
| return None | ||
|
|
||
| @dlp.log | ||
| def get_data(self, id, offset=None, length=None): | ||
| obj_name = os.path.relpath(id) | ||
| if offset: | ||
| byte_range = f"bytes={offset}-{offset + length - 1}" | ||
| return self.s3_client.get_object(Bucket=self.namespace.name, Key=id, Range=byte_range)['Body'].read() | ||
| else: | ||
| return self.s3_client.get_object(Bucket=self.namespace.name, Key=obj_name)['Body'].read() | ||
|
|
||
|
|
||
| @dlp.log | ||
| def list_objects(self, bucket_name, prefix=None): | ||
| params = {'Bucket': bucket_name} | ||
| if prefix: | ||
| params['Prefix'] = prefix | ||
| paths = [] | ||
| try: | ||
| ## Need to implement pagination | ||
| response = self.s3_client.list_objects_v2(**params) | ||
|
|
||
| if 'Contents' in response: | ||
| for key in response['Contents']: | ||
| paths.append(key['Key'][len(prefix)+1:]) | ||
| except self.s3_client.exceptions.NoSuchBucket: | ||
| print(f"Bucket '{bucket_name}' does not exist.") | ||
|
|
||
| return paths | ||
|
|
||
|
|
||
| @dlp.log | ||
| def delete_node(self, id): | ||
| return super().delete_node(self.get_uri(id)) | ||
|
|
||
| def get_basename(self, id): | ||
| return os.path.basename(id) | ||
|
|
||
|
|
||
| class S3Storage(DataStorage): | ||
| """ | ||
|
|
@@ -73,4 +155,4 @@ def get_data(self, id, data, offset=None, length=None): | |
| return super().get_data(self.get_uri(id), data, offset, length) | ||
|
|
||
| def get_basename(self, id): | ||
| return os.path.basename(id) | ||
| return os.path.basename(id) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,3 +16,4 @@ tensorflow>=2.11.0 | |
| torch>=2.2.0 | ||
| torchaudio | ||
| torchvision | ||
| boto3 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We cannot have company copyrights here.