diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f7f40729..5146733c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,6 +4,7 @@ on: pull_request: branches: [main, dev] push: + workflow_dispatch: jobs: build-and-test: @@ -59,6 +60,8 @@ jobs: key: ${{ matrix.venv }}-gcc${{ matrix.gcc }}-python${{ matrix.python }}-${{ hashFiles('requirements.txt', 'setup.py') }} - name: Install system dependencies run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo rm -f /etc/apt/sources.list.d/azure-cli.list sudo apt update sudo apt-get install -y $CC $CXX libc6 git sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev python3-dev @@ -381,3 +384,57 @@ jobs: mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-0] -v mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-1] -v mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-2] -v + + # ADLS Gen2-specific setup and tests + - name: Install ADLS Gen2 dependencies + run: | + source ${VENV_PATH}/bin/activate + pip install .[adls] + - name: test_adls_gen_data + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_gen_data[npy-pytorch] -v + mpirun -np 1 pytest -k test_adls_gen_data[npz-pytorch] -v + - name: test_adls_train + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_train[npy-pytorch-pytorch-True] -v + mpirun -np 1 pytest -k test_adls_train[npz-pytorch-pytorch-True] -v + mpirun -np 1 pytest -k test_adls_train[npy-pytorch-pytorch-False] -v + mpirun -np 1 pytest -k test_adls_train[npz-pytorch-pytorch-False] -v + - name: test_adls_eval + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_eval -v + - name: test_adls_multi_threads + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-0] -v + mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-1] -v + mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-2] -v + - name: test_adls_pytorch_multiprocessing_context + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_pytorch_multiprocessing_context[0-None] -v + mpirun -np 1 pytest -k test_adls_pytorch_multiprocessing_context[1-fork] -v + - name: test_adls_subset + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_subset -v + - name: test_adls_checkpoint_epoch + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers0-2-layer_params0-0-True] -v + mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers1-2-layer_params1-3-True] -v + mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers2-1-layer_params2-0-True] -v + mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers3-2-layer_params3-0-False] -v + mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers4-2-layer_params4-3-False] -v + mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers5-1-layer_params5-0-False] -v + - name: test_adls_checkpoint_ksm_config + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_checkpoint_ksm_config -v + - name: test_adls_checkpoint_step + run: | + source ${VENV_PATH}/bin/activate + mpirun -np 1 pytest -k test_adls_checkpoint_step -v diff --git a/dlio_benchmark/checkpointing/checkpointing_factory.py b/dlio_benchmark/checkpointing/checkpointing_factory.py index 845dccb1..588a751c 100644 --- a/dlio_benchmark/checkpointing/checkpointing_factory.py +++ b/dlio_benchmark/checkpointing/checkpointing_factory.py @@ -42,5 +42,8 @@ def get_mechanism(checkpoint_mechanism_type): elif checkpoint_mechanism_type == CheckpointMechanismType.PT_S3_SAVE: from dlio_benchmark.checkpointing.pytorch_s3_checkpointing import PyTorchS3Checkpointing return PyTorchS3Checkpointing.get_instance() + elif checkpoint_mechanism_type == CheckpointMechanismType.PT_ADLS_SAVE: + from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing + return PyTorchADLSCheckpointing.get_instance() else: raise Exception(str(ErrorCodes.EC1005)) diff --git a/dlio_benchmark/checkpointing/pytorch_adls_checkpointing.py b/dlio_benchmark/checkpointing/pytorch_adls_checkpointing.py new file mode 100644 index 00000000..f568addc --- /dev/null +++ b/dlio_benchmark/checkpointing/pytorch_adls_checkpointing.py @@ -0,0 +1,279 @@ +""" + Copyright (c) 2025, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from datetime import datetime, timedelta, timezone +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + +import torch +from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing +from dlio_benchmark.checkpointing.pytorch_checkpointing import PyTorchCheckpointing +from dlio_benchmark.utils.utility import Profile, dft_ai + +from dlio_benchmark.common.constants import MODULE_CHECKPOINT + +dlp = Profile(MODULE_CHECKPOINT) + +# Import BlobIO at module level to allow test patching +try: + from azstoragetorch.io import BlobIO +except ImportError: + BlobIO = None + +try: + from azure.storage.blob import ContainerSasPermissions, generate_container_sas +except ImportError: + ContainerSasPermissions = None + generate_container_sas = None + +class PyTorchADLSCheckpointing(PyTorchCheckpointing): + __instance = None + + @staticmethod + def get_instance(): + """ Static access method. """ + if PyTorchADLSCheckpointing.__instance is None: + PyTorchADLSCheckpointing.__instance = PyTorchADLSCheckpointing() + return PyTorchADLSCheckpointing.__instance + + @dft_ai.checkpoint.init + def __init__(self): + BaseCheckpointing.__init__(self, "ptadls") + + # Check if BlobIO is available + if BlobIO is None: + raise ImportError( + "azstoragetorch is required for ADLS Gen2 checkpointing support. " + "Install with: pip install 'azstoragetorch>=0.1.0'" + ) + + # Access config values from self.args (inherited from BaseCheckpointing) + storage_options = getattr(self.args, "storage_options", {}) or {} + self._checkpoint_folder = self.args.checkpoint_folder + self._account_name = None + self._account_key = None + self._shared_access_signature = None + self._container_sas_tokens = {} + + if not isinstance(storage_options, dict): + storage_options = dict(storage_options) + + self._container_sas_ttl = self._get_duration_option( + storage_options, + "container_sas_ttl", + self.args.adls_container_sas_ttl, + ) + self._container_sas_refresh_margin = self._get_duration_option( + storage_options, + "sas_refresh_margin", + self.args.adls_sas_refresh_margin, + ) + + # Support both connection string and account URL authentication + connection_string = storage_options.get("connection_string") + account_url = storage_options.get("account_url") + account_name = storage_options.get("account_name") + + if connection_string: + # Parse connection string and generate SAS-based blob URLs for BlobIO. + self._load_connection_string(connection_string) + elif account_url: + # Use account URL and derive account name for SAS-backed checkpoint URLs. + self._account_name = self._extract_account_name_from_url(account_url) + elif account_name: + # Use explicit account name for SAS-backed checkpoint URLs. + self._account_name = account_name + else: + raise ValueError( + "ADLS Gen2 checkpointing requires authentication configuration. " + "Provide 'connection_string', 'account_url', or 'account_name' in storage_options." + ) + + if self._account_name is None: + self._account_name = self._extract_account_name_from_abfs(self._checkpoint_folder) + + if self._account_name is None: + raise ValueError( + "Unable to determine ADLS account name for checkpointing. " + "Provide storage_options.account_name/account_url or use canonical ABFS checkpoint URI." + ) + + def _get_duration_option(self, storage_options, option_name, default_value): + value = storage_options.get(option_name) + if value is None: + return default_value + + if isinstance(value, timedelta): + return value + + if isinstance(value, (int, float)): + return timedelta(seconds=float(value)) + + if isinstance(value, str): + normalized = value.strip().lower() + if not normalized: + return default_value + suffix_multipliers = { + "s": 1, + "m": 60, + "h": 3600, + "d": 86400, + } + if normalized[-1] in suffix_multipliers: + amount = float(normalized[:-1]) + return timedelta(seconds=amount * suffix_multipliers[normalized[-1]]) + return timedelta(seconds=float(normalized)) + + raise ValueError( + f"Invalid duration for storage_options.{option_name}: {value!r}. " + "Use seconds or a string with suffix s, m, h, or d." + ) + + def _load_connection_string(self, connection_string): + parts = {} + for segment in connection_string.split(';'): + if '=' in segment: + key, value = segment.split('=', 1) + parts[key] = value + + self._account_name = parts.get("AccountName") + self._account_key = parts.get("AccountKey") + self._shared_access_signature = parts.get("SharedAccessSignature") + + def _extract_account_name_from_url(self, account_url): + parsed = urlparse(account_url) + host = parsed.netloc + if not host: + return None + return host.split('.')[0] + + def _extract_account_name_from_abfs(self, uri): + parsed = urlparse(uri) + if parsed.scheme != "abfs" or '@' not in parsed.netloc: + return None + _, account_fqdn = parsed.netloc.split('@', 1) + return account_fqdn.split('.')[0] + + def _to_blob_url(self, checkpoint_name, for_write): + parsed = urlparse(checkpoint_name) + + if parsed.scheme == "https": + blob_url = checkpoint_name + elif parsed.scheme == "abfs": + if '@' not in parsed.netloc: + raise ValueError( + "Invalid ABFS checkpoint path. Expected format: " + "abfs://@.dfs.core.windows.net/" + ) + file_system, account_fqdn = parsed.netloc.split('@', 1) + account_name = account_fqdn.split('.')[0] + blob_path = parsed.path.lstrip('/') + blob_url = f"https://{account_name}.blob.core.windows.net/{file_system}/{blob_path}" + else: + raise ValueError( + f"Unsupported checkpoint URI '{checkpoint_name}'. Expected abfs:// or https://" + ) + + if self._shared_access_signature: + return self._append_query(blob_url, self._shared_access_signature) + + if self._account_key: + if generate_container_sas is None or ContainerSasPermissions is None: + raise ImportError( + "azure-storage-blob is required for connection-string-based ADLS checkpointing." + ) + blob_parsed = urlparse(blob_url) + path_parts = blob_parsed.path.lstrip('/').split('/', 1) + if len(path_parts) != 2: + raise ValueError(f"Invalid blob URL for checkpointing: {blob_url}") + container_name, _ = path_parts + token = self._get_container_sas(container_name) + return self._append_query(blob_url, token) + + return blob_url + + def _get_container_sas(self, container_name): + cache_entry = self._container_sas_tokens.get(container_name) + now = datetime.now(timezone.utc) + refresh_margin = self._container_sas_refresh_margin + + if isinstance(cache_entry, dict): + token = cache_entry.get("token") + expires_at = cache_entry.get("expires_at") + if token and expires_at and (expires_at - now) > refresh_margin: + return token + + ttl = self._container_sas_ttl + expiry = now + ttl + + token = generate_container_sas( + account_name=self._account_name, + container_name=container_name, + account_key=self._account_key, + permission=ContainerSasPermissions( + read=True, + write=True, + create=True, + add=True, + list=True, + ), + expiry=expiry, + ) + self._container_sas_tokens[container_name] = { + "token": token, + "expires_at": expiry, + } + return token + + def _append_query(self, url, query_string): + parsed = urlparse(url) + existing = parse_qs(parsed.query, keep_blank_values=True) + incoming = parse_qs(query_string.lstrip('?'), keep_blank_values=True) + for key, values in incoming.items(): + existing[key] = values + merged_query = urlencode(existing, doseq=True) + return urlunparse(parsed._replace(query=merged_query)) + + @dft_ai.checkpoint.capture + def save_state(self, suffix, state, fsync = False): + name = self.get_name(suffix) + blob_url = self._to_blob_url(name, for_write=True) + # Save checkpoint to ADLS using azstoragetorch BlobIO + with BlobIO(blob_url, "wb", credential=None) as writer: + torch.save(state, writer) + + @dft_ai.checkpoint.restart + def load_state(self, suffix, state): + name = self.get_name(suffix) + blob_url = self._to_blob_url(name, for_write=False) + state = dict() # clear up + # Load checkpoint from ADLS using azstoragetorch BlobIO + with BlobIO(blob_url, "rb", credential=None) as reader: + state = torch.load(reader) + self.logger.debug(f"checkpoint state loaded: {state}") + assert(len(state.keys())>0) + + @dlp.log + def save_checkpoint(self, epoch, step_number): + super().save_checkpoint(epoch, step_number) + + @dlp.log + def load_checkpoint(self, epoch, step_number): + super().load_checkpoint(epoch, step_number) + + @dlp.log + def finalize(self): + super().finalize() + diff --git a/dlio_benchmark/common/enumerations.py b/dlio_benchmark/common/enumerations.py index 2c61475d..41934217 100644 --- a/dlio_benchmark/common/enumerations.py +++ b/dlio_benchmark/common/enumerations.py @@ -27,6 +27,7 @@ class CheckpointMechanismType(Enum): TF_SAVE = 'tf_save' PT_SAVE = 'pt_save' PT_S3_SAVE = 'pt_s3_save' + PT_ADLS_SAVE = 'pt_adls_save' def __str__(self): return self.value @@ -59,6 +60,7 @@ class StorageType(Enum): PARALLEL_FS = 'parallel_fs' S3 = 's3' AISTORE = 'aistore' + ADLS_GEN2 = 'adls_gen2' def __str__(self): return self.value @@ -70,6 +72,7 @@ class MetadataType(Enum): FILE = 'file' DIRECTORY = 'directory' S3_OBJECT = 's3_object' + ADLS_OBJECT = 'adls_object' def __str__(self): return self.value diff --git a/dlio_benchmark/configs/workload/unet3d_a100_adlsgen2.yaml b/dlio_benchmark/configs/workload/unet3d_a100_adlsgen2.yaml new file mode 100644 index 00000000..3dc3441a --- /dev/null +++ b/dlio_benchmark/configs/workload/unet3d_a100_adlsgen2.yaml @@ -0,0 +1,46 @@ +model: + name: unet3d + type: cnn + model_size: 499153191 + +framework: pytorch + +workflow: + generate_data: True + train: True + checkpoint: False + +dataset: + data_folder: abfs://dliobenchmark@dlio.dfs.core.windows.net + format: npz + num_files_train: 168 + num_samples_per_file: 1 + record_length_bytes: 146600628 + record_length_bytes_stdev: 0 + record_length_bytes_resize: 2097152 + +storage: + storage_type: adls_gen2 + storage_root: dliobenchmark + storage_options: + connection_string: DefaultEndpointsProtocol=https;AccountName=;AccountKey=;EndpointSuffix=core.windows.net + +reader: + data_loader: pytorch + batch_size: 7 + read_threads: 4 + file_shuffle: seed + sample_shuffle: seed + +train: + epochs: 5 + computation_time: 0.636 + +checkpoint: + checkpoint_mechanism: pt_adls_save + checkpoint_folder: checkpoints/unet3d + checkpoint_after_epoch: 5 + epochs_between_checkpoints: 2 + +metric: + au: 0.90 diff --git a/dlio_benchmark/configs/workload/unet3d_h100_adlsgen2.yaml b/dlio_benchmark/configs/workload/unet3d_h100_adlsgen2.yaml new file mode 100644 index 00000000..c515bdf1 --- /dev/null +++ b/dlio_benchmark/configs/workload/unet3d_h100_adlsgen2.yaml @@ -0,0 +1,46 @@ +model: + name: unet3d + type: cnn + model_size: 499153191 + +framework: pytorch + +workflow: + generate_data: True + train: True + checkpoint: False + +dataset: + data_folder: abfs://dliobenchmark@dlio.dfs.core.windows.net + format: npz + num_files_train: 168 + num_samples_per_file: 1 + record_length_bytes: 146600628 + record_length_bytes_stdev: 0 + record_length_bytes_resize: 2097152 + +storage: + storage_type: adls_gen2 + storage_root: dliobenchmark + storage_options: + connection_string: DefaultEndpointsProtocol=https;AccountName=;AccountKey=;EndpointSuffix=core.windows.net + +reader: + data_loader: pytorch + batch_size: 7 + read_threads: 4 + file_shuffle: seed + sample_shuffle: seed + +train: + epochs: 7 + computation_time: 0.323 + +checkpoint: + checkpoint_mechanism: pt_adls_save + checkpoint_folder: checkpoints/unet3d + checkpoint_after_epoch: 5 + epochs_between_checkpoints: 2 + +metric: + au: 0.90 diff --git a/dlio_benchmark/configs/workload/unet3d_v100_adlsgen2.yaml b/dlio_benchmark/configs/workload/unet3d_v100_adlsgen2.yaml new file mode 100644 index 00000000..01595411 --- /dev/null +++ b/dlio_benchmark/configs/workload/unet3d_v100_adlsgen2.yaml @@ -0,0 +1,43 @@ +model: + name: unet3d + type: cnn + model_size: 499153191 + +framework: pytorch + +workflow: + generate_data: True + train: True + checkpoint: False + +dataset: + data_folder: abfs://dliobenchmark@dlio.dfs.core.windows.net + format: npy + num_files_train: 168 + num_samples_per_file: 1 + record_length_bytes: 146600628 + record_length_bytes_stdev: 0 + record_length_bytes_resize: 2097152 + +storage: + storage_type: adls_gen2 + storage_root: dliobenchmark + storage_options: + connection_string: DefaultEndpointsProtocol=https;AccountName=;AccountKey=;EndpointSuffix=core.windows.net + +reader: + data_loader: pytorch + batch_size: 4 + read_threads: 4 + file_shuffle: seed + sample_shuffle: seed + +train: + epochs: 5 + computation_time: 1.3604 + +checkpoint: + checkpoint_mechanism: pt_adls_save + checkpoint_folder: checkpoints/unet3d + checkpoint_after_epoch: 5 + epochs_between_checkpoints: 2 diff --git a/dlio_benchmark/reader/reader_factory.py b/dlio_benchmark/reader/reader_factory.py index abcbbd14..8bf75803 100644 --- a/dlio_benchmark/reader/reader_factory.py +++ b/dlio_benchmark/reader/reader_factory.py @@ -67,8 +67,7 @@ def get_reader(type, dataset_type, thread_index, epoch_number): if _args.odirect == True: from dlio_benchmark.reader.npy_reader_odirect import NPYReaderODirect return NPYReaderODirect(dataset_type, thread_index, epoch_number) - # Use S3 readers for both S3 and AIStore - elif _args.storage_type in (StorageType.S3, StorageType.AISTORE): + elif _args.storage_type in (StorageType.S3, StorageType.AISTORE, StorageType.ADLS_GEN2): from dlio_benchmark.reader.npy_reader_s3 import NPYReaderS3 return NPYReaderS3(dataset_type, thread_index, epoch_number) else: @@ -81,8 +80,7 @@ def get_reader(type, dataset_type, thread_index, epoch_number): if _args.odirect == True: from dlio_benchmark.reader.npz_reader_odirect import NPZReaderODIRECT return NPZReaderODIRECT(dataset_type, thread_index, epoch_number) - # Use S3 readers for both S3 and AIStore - elif _args.storage_type in (StorageType.S3, StorageType.AISTORE): + elif _args.storage_type in (StorageType.S3, StorageType.AISTORE, StorageType.ADLS_GEN2): from dlio_benchmark.reader.npz_reader_s3 import NPZReaderS3 return NPZReaderS3(dataset_type, thread_index, epoch_number) else: diff --git a/dlio_benchmark/storage/adls_gen2_storage.py b/dlio_benchmark/storage/adls_gen2_storage.py new file mode 100644 index 00000000..afdb2a43 --- /dev/null +++ b/dlio_benchmark/storage/adls_gen2_storage.py @@ -0,0 +1,374 @@ +""" + Copyright (c) 2025, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from time import time +from urllib.parse import urlparse + +from dlio_benchmark.common.constants import MODULE_STORAGE +from dlio_benchmark.storage.storage_handler import DataStorage, Namespace +from dlio_benchmark.common.enumerations import NamespaceType, MetadataType +import os + +from dlio_benchmark.utils.utility import Profile + +# Import Azure SDK libraries at module level for patching in tests +try: + from azure.storage.filedatalake import DataLakeServiceClient + from azure.identity import DefaultAzureCredential +except ImportError: + DataLakeServiceClient = None + DefaultAzureCredential = None + +dlp = Profile(MODULE_STORAGE) + + +class ADLSGen2Storage(DataStorage): + """ + Storage APIs for ADLS Gen2 (Azure Data Lake Storage Gen2). + Uses Azure Data Lake Storage Gen2 Python SDK to interact with Azure storage. + """ + + @dlp.log_init + def __init__(self, namespace, framework=None): + super().__init__(framework) + self.namespace = Namespace(namespace, NamespaceType.HIERARCHICAL) + self.container_name, self.account_fqdn, self.base_path = self._parse_namespace(namespace) + + # Check if Azure SDK libraries are available + if DataLakeServiceClient is None: + raise ImportError( + "Azure Storage libraries are required for ADLS Gen2 support. " + "Install with: pip install azure-storage-file-datalake azure-identity" + ) + + # Import exception types locally as they're only used in this class + from azure.core.exceptions import ResourceNotFoundError, ResourceExistsError + + # Store exception types for use in methods + self.ResourceNotFoundError = ResourceNotFoundError + self.ResourceExistsError = ResourceExistsError + + # Get storage configuration from args + storage_options = getattr(self._args, "storage_options", {}) or {} + if not isinstance(storage_options, dict): + storage_options = dict(storage_options) + + # Support both connection string and account URL authentication + connection_string = storage_options.get("connection_string") + account_url = storage_options.get("account_url") + account_name = storage_options.get("account_name") + + if self.account_fqdn is None and account_name: + self.account_fqdn = f"{account_name}.dfs.core.windows.net" + elif self.account_fqdn is None and account_url: + self.account_fqdn = urlparse(account_url).netloc + + if connection_string: + # Use connection string authentication + self.service_client = DataLakeServiceClient.from_connection_string(connection_string) + elif account_url: + # Use account URL with default credential + credential = DefaultAzureCredential() + self.service_client = DataLakeServiceClient(account_url=account_url, credential=credential) + elif account_name: + # Construct account URL from account name + account_url = f"https://{account_name}.dfs.core.windows.net" + credential = DefaultAzureCredential() + self.service_client = DataLakeServiceClient(account_url=account_url, credential=credential) + else: + raise ValueError( + "ADLS Gen2 requires authentication configuration. " + "Provide 'connection_string', 'account_url', or 'account_name' in storage_options." + ) + + # Get or create file system client for the namespace (container) + self.file_system_client = self.service_client.get_file_system_client(file_system=self.container_name) + + def _parse_namespace(self, namespace): + parsed = urlparse(namespace) + if parsed.scheme == "abfs": + netloc = parsed.netloc + if "@" not in netloc: + raise ValueError( + "Invalid ABFS namespace URI. Expected format: " + "abfs://@.dfs.core.windows.net[/]" + ) + container_name, account_fqdn = netloc.split("@", 1) + base_path = parsed.path.lstrip('/').rstrip('/') + return container_name, account_fqdn, base_path + + namespace = namespace.strip('/') + if '/' in namespace: + container_name, base_path = namespace.split('/', 1) + return container_name, None, base_path.strip('/') + return namespace, None, "" + + def _build_abfs_uri(self, path): + if self.account_fqdn: + if path: + return f"abfs://{self.container_name}@{self.account_fqdn}/{path}" + return f"abfs://{self.container_name}@{self.account_fqdn}" + + if path: + return f"abfs://{self.container_name}/{path}" + return f"abfs://{self.container_name}" + + @dlp.log + def get_uri(self, id): + # If id is already a full URI, return as-is + # Otherwise, construct the URI + if id.startswith("abfs://"): + return id + path = self._resolve_path(id) + return self._build_abfs_uri(path) + + def _resolve_path(self, uri): + """ + Resolve URI or relative path to a path inside the ADLS container. + """ + parsed = urlparse(uri) + if parsed.scheme == 'abfs': + if "@" not in parsed.netloc: + raise ValueError( + "Invalid ABFS URI. Expected format: " + "abfs://@.dfs.core.windows.net//" + ) + return parsed.path.lstrip('/') + + relative_path = uri.lstrip('/') + if not self.base_path: + return relative_path + + if not relative_path: + return self.base_path + + if relative_path == self.base_path or relative_path.startswith(f"{self.base_path}/"): + return relative_path + + return f"{self.base_path}/{relative_path}" + + @dlp.log + def create_namespace(self, exist_ok=False): + """ + Create the file system (container) for ADLS Gen2. + """ + try: + self.file_system_client.create_file_system() + return True + except self.ResourceExistsError: + if exist_ok: + return True + raise + except Exception as e: + print(f"Error creating namespace '{self.namespace.name}': {e}") + return False + + @dlp.log + def get_namespace(self): + """ + Get the namespace (file system/container) information. + """ + try: + properties = self.file_system_client.get_file_system_properties() + return MetadataType.DIRECTORY + except self.ResourceNotFoundError: + return None + + @dlp.log + def create_node(self, id, exist_ok=False): + """ + Create a directory in ADLS Gen2. + """ + try: + dir_path = self._resolve_path(id) + if not dir_path: + return True + directory_client = self.file_system_client.get_directory_client(dir_path) + directory_client.create_directory() + return True + except self.ResourceExistsError: + if exist_ok: + return True + raise + except Exception as e: + print(f"Error creating node '{id}': {e}") + return False + + @dlp.log + def get_node(self, id=""): + """ + Get metadata about a path (file or directory). + """ + if not id or id == "": + return self.get_namespace() + + node_path = self._resolve_path(id) + try: + file_client = self.file_system_client.get_file_client(node_path) + properties = file_client.get_file_properties() + metadata = properties.get("metadata") or {} + is_directory = str(metadata.get("hdi_isfolder", "")).lower() == "true" + if is_directory: + return MetadataType.DIRECTORY + return MetadataType.FILE + except self.ResourceNotFoundError: + return None + except Exception: + return None + + @dlp.log + def walk_node(self, id, use_pattern=False): + """ + List files and directories under a path. + """ + try: + dir_path = self._resolve_path(id) + if not use_pattern: + # List all items in the directory + paths = self.file_system_client.get_paths(path=dir_path, recursive=False) + result = [] + prefix_len = len(dir_path.rstrip('/') + '/') if dir_path else 0 + + for path in paths: + path_name = path.name + # Get only immediate children (not nested) + if prefix_len > 0: + relative_path = path_name[prefix_len:] + else: + relative_path = path_name + + # Only include immediate children (no slashes in relative path) + if '/' not in relative_path: + result.append(relative_path) + + return result + else: + # Pattern matching for file extensions + format_ext = dir_path.split(".")[-1] + if format_ext != format_ext.lower(): + raise Exception(f"Unknown file format {format_ext}") + + search_path = os.path.dirname(dir_path) + while any(token in search_path for token in ["*", "?", "["]): + search_path = os.path.dirname(search_path) + + # List files matching the pattern + paths = self.file_system_client.get_paths(path=search_path) + result = [] + + # Match files with both lowercase and uppercase extensions + lower_pattern = dir_path + upper_pattern = dir_path.replace(format_ext, format_ext.upper()) + + for path in paths: + path_name = path.name + if (path_name.endswith(format_ext) or + path_name.endswith(format_ext.upper())): + result.append(self.get_uri(path_name)) + + return result + except Exception as e: + print(f"Error walking node '{id}': {e}") + return [] + + @dlp.log + def delete_node(self, id): + """ + Delete a file or directory from ADLS Gen2. + """ + try: + file_path = self._resolve_path(id) + file_client = self.file_system_client.get_file_client(file_path) + file_client.delete_file() + return True + except Exception as e: + print(f"Error deleting node '{id}': {e}") + return False + + @dlp.log + def put_data(self, id, data, offset=None, length=None): + """ + Upload data to a file in ADLS Gen2. + """ + try: + file_path = self._resolve_path(id) + file_client = self.file_system_client.get_file_client(file_path) + + # Handle different data types + if hasattr(data, 'getvalue'): + # BytesIO or StringIO object + data_bytes = data.getvalue() + elif isinstance(data, bytes): + data_bytes = data + elif isinstance(data, str): + data_bytes = data.encode('utf-8') + else: + data_bytes = str(data).encode('utf-8') + + if offset is not None and length is not None: + # Partial write - append to existing file + file_client.append_data(data_bytes, offset=offset, length=length) + file_client.flush_data(offset + length) + else: + # Full write - create/overwrite file + file_client.create_file() + file_client.upload_data(data_bytes, overwrite=True) + + return True + except Exception as e: + print(f"Error putting data to '{id}': {e}") + return False + + @dlp.log + def get_data(self, id, data, offset=None, length=None): + """ + Download data from a file in ADLS Gen2. + """ + try: + file_path = self._resolve_path(id) + file_client = self.file_system_client.get_file_client(file_path) + + if offset is not None and length is not None: + # Partial read + download_stream = file_client.download_file(offset=offset, length=length) + else: + # Full read + download_stream = file_client.download_file() + + return download_stream.readall() + except Exception as e: + print(f"Error getting data from '{id}': {e}") + return None + + @dlp.log + def isfile(self, id): + """ + Check if the path is a file. + """ + try: + file_path = self._resolve_path(id) + file_client = self.file_system_client.get_file_client(file_path) + properties = file_client.get_file_properties() + metadata = properties.get("metadata") or {} + is_directory = str(metadata.get("hdi_isfolder", "")).lower() == "true" + return not is_directory + except self.ResourceNotFoundError: + return False + except Exception: + return False + + def get_basename(self, id): + return os.path.basename(id) diff --git a/dlio_benchmark/storage/storage_factory.py b/dlio_benchmark/storage/storage_factory.py index e346187c..0145031d 100644 --- a/dlio_benchmark/storage/storage_factory.py +++ b/dlio_benchmark/storage/storage_factory.py @@ -16,6 +16,7 @@ """ from dlio_benchmark.storage.file_storage import FileStorage from dlio_benchmark.storage.s3_storage import S3Storage +from dlio_benchmark.storage.adls_gen2_storage import ADLSGen2Storage from dlio_benchmark.common.enumerations import StorageType from dlio_benchmark.common.error_code import ErrorCodes @@ -48,5 +49,7 @@ def get_storage(storage_type, namespace, framework=None): from dlio_benchmark.storage.s3_torch_storage import S3PyTorchConnectorStorage return S3PyTorchConnectorStorage(namespace, framework) return S3Storage(namespace, framework) + elif storage_type == StorageType.ADLS_GEN2: + return ADLSGen2Storage(namespace, framework) else: raise Exception(str(ErrorCodes.EC1001)) diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index 15a1071d..4dc1738e 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -17,6 +17,7 @@ import importlib import inspect import hydra +from datetime import timedelta import logging @@ -196,6 +197,9 @@ class ConfigArguments: s3_region: str = "us-east-1" s3_force_path_style = False s3_max_attempts: int = 5 + # adls gen2 defaults + adls_container_sas_ttl = timedelta(hours=1) + adls_sas_refresh_margin = timedelta(minutes=5) def __init__(self): """ Virtually private constructor. """ @@ -377,6 +381,25 @@ def validate(self): "AIStore with NPZ requires dlio_benchmark.reader.npz_reader_s3.NPZReaderS3" ) + # ADLS Gen2 specific checks (uses S3 generators/readers) + if self.storage_type == StorageType.ADLS_GEN2 and self.framework == FrameworkType.PYTORCH: + if self.format not in (FormatType.NPZ, FormatType.NPY): + raise Exception(f"For ADLS Gen2 using PyTorch framework, only NPZ or NPY formats are supported. Got format {self.format}") + storage_options = self.storage_options or {} + if not any([ + storage_options.get("connection_string"), + storage_options.get("account_url"), + storage_options.get("account_name"), + ]): + raise Exception( + "ADLS Gen2 requires authentication configuration. " + "Provide 'connection_string', 'account_url', or 'account_name' in storage_options." + ) + if self.do_checkpoint == True and self.checkpoint_mechanism != CheckpointMechanismType.PT_ADLS_SAVE: + raise Exception( + f"For ADLS Gen2 checkpointing using PyTorch framework, invalid mechanism type supported. Got mechanism type as {self.checkpoint_mechanism}" + ) + # S3 specific checks if self.storage_type == StorageType.S3 and self.framework == FrameworkType.PYTORCH: if self.format not in (FormatType.NPZ, FormatType.NPY): @@ -671,10 +694,17 @@ def GetConfig(args, key): elif keys[1] == "storage_root": value = args.storage_root elif keys[1] == "storage_options" and len(keys) > 2: - if args.storage_type == "s3": - option_key = keys[2] + option_key = keys[2] + if args.storage_type == StorageType.S3: if option_key in ["access_key_id", "secret_access_key", "endpoint_url", "region", "s3_force_path_style", "s3_max_attempts"]: - value = config["storage"].get("storage_options", {}).get(option_key) + value = (args.storage_options or {}).get(option_key) + elif args.storage_type == StorageType.ADLS_GEN2: + if option_key in ["connection_string", "account_url", "account_name", "container_sas_ttl", "sas_refresh_margin"]: + value = (args.storage_options or {}).get(option_key) + if value is None and option_key == "container_sas_ttl": + value = args.adls_container_sas_ttl + elif value is None and option_key == "sas_refresh_margin": + value = args.adls_sas_refresh_margin if len(keys) > 1 and keys[0] == "dataset": if keys[1] == "record_length_bytes": @@ -792,6 +822,8 @@ def GetConfig(args, key): value = args.steps_between_checkpoints elif keys[1] == "type": value = args.checkpoint_type + elif keys[1] == "checkpoint_mechanism": + value = args.checkpoint_mechanism elif keys[1] == 'mode': value = args.checkpoint_mode elif keys[1] == "checkpoint_mechanism_classname": @@ -1074,6 +1106,8 @@ def LoadConfig(args, config): args.steps_between_checkpoints = config['checkpoint']['steps_between_checkpoints'] if 'type' in config['checkpoint']: args.checkpoint_type = CheckpointLocationType(config['checkpoint']['type']) + if 'checkpoint_mechanism' in config['checkpoint']: + args.checkpoint_mechanism = CheckpointMechanismType(config['checkpoint']['checkpoint_mechanism']) if 'checkpoint_mechanism_classname' in config['checkpoint']: args.checkpoint_mechanism_classname = config['checkpoint']['checkpoint_mechanism_classname'] if 'fsync' in config['checkpoint']: diff --git a/setup.py b/setup.py index 9a69fc92..75c2da07 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,11 @@ "aistore": [ "aistore", ], + "adls": [ + "azure-storage-file-datalake>=12.0.0", + "azure-identity>=1.12.0", + "azstoragetorch>=0.1.0", + ], } here = pathlib.Path(__file__).parent.resolve() diff --git a/tests/dlio_adls_benchmark_test.py b/tests/dlio_adls_benchmark_test.py new file mode 100644 index 00000000..45cca886 --- /dev/null +++ b/tests/dlio_adls_benchmark_test.py @@ -0,0 +1,938 @@ +""" + Copyright (c) 2025, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +#!/usr/bin/env python +from hydra import initialize, initialize_config_dir, compose +from omegaconf import OmegaConf +import unittest +from datetime import datetime, timedelta, timezone +import uuid +import io +from io import BytesIO +import glob +from mpi4py import MPI +from tests.utils import TEST_TIMEOUT_SECONDS + +comm = MPI.COMM_WORLD + +import pytest +import time +import subprocess +import logging +import os +from dlio_benchmark.utils.config import ConfigArguments, GetConfig, LoadConfig +from dlio_benchmark.utils.utility import DLIOMPI +import dlio_benchmark + +from unittest.mock import patch, MagicMock +try: + from azstoragetorch.io import BlobIO +except ImportError as e: + BlobIO = None +from urllib.parse import urlparse + +config_dir=os.path.dirname(dlio_benchmark.__file__)+"/configs/" + +logging.basicConfig( + level=logging.INFO, + handlers=[ + logging.FileHandler("dlio_benchmark_test.log", mode="a", encoding='utf-8'), + logging.StreamHandler() + ], format='[%(levelname)s] %(message)s [%(pathname)s:%(lineno)d]' + # logging's max timestamp resolution is msecs, we will pass in usecs in the message +) + +from dlio_benchmark.main import DLIOBenchmark, set_dftracer_initialize, set_dftracer_finalize + +def finalize(): + # DLIOMPI.get_instance().finalize() + pass + +def clean_adls(mock_file_system_client, prefixes: list[str]) -> None: + """Clean up mock ADLS Gen2 storage after tests""" + comm.Barrier() + if comm.rank == 0: + for prefix in prefixes: + # Get all paths starting with prefix + keys = [k for k in mock_file_system_client.storage.keys() if k.startswith(prefix)] + for key in keys: + del mock_file_system_client.storage[key] + comm.Barrier() + +def get_adls_prefixes_from_uri(uri: str, subdirs=("train", "valid")): + parsed = urlparse(uri) + base_prefix = parsed.path.lstrip("/") + return [f"{base_prefix}/{subdir}" for subdir in subdirs] + +def run_benchmark(cfg, verify=True): + comm.Barrier() + t0 = time.time() + ConfigArguments.reset() + benchmark = DLIOBenchmark(cfg["workload"]) + benchmark.initialize() + benchmark.run() + benchmark.finalize() + t1 = time.time() + if (comm.rank==0): + logging.info("Time for the benchmark: %.10f" %(t1-t0)) + if (verify): + assert(len(glob.glob(benchmark.output_folder+"./*_output.json"))==benchmark.comm_size) + return benchmark + +class MockADLSFileClient: + """Mock Azure Data Lake Storage Gen2 file client""" + def __init__(self, file_system_client, file_path): + self.file_system_client = file_system_client + self.file_path = file_path + self.storage = file_system_client.storage + + def create_file(self): + """Create a file (no-op for mock)""" + pass + + def upload_data(self, data, overwrite=True): + """Upload data to the file""" + if isinstance(data, bytes): + self.storage[self.file_path] = data + elif isinstance(data, str): + self.storage[self.file_path] = data.encode('utf-8') + else: + self.storage[self.file_path] = bytes(data) + + def download_file(self, offset=None, length=None): + """Download file data""" + data = self.storage.get(self.file_path, b"") + if offset is not None and length is not None: + return MockDownloadStream(data[offset:offset+length]) + return MockDownloadStream(data) + + def delete_file(self): + """Delete a file""" + if self.file_path in self.storage: + del self.storage[self.file_path] + + def get_file_properties(self): + """Get file properties""" + if self.file_path in self.storage: + return {'is_directory': False} + raise Exception(f"File not found: {self.file_path}") + +class MockDownloadStream: + """Mock download stream""" + def __init__(self, data): + self.data = data + + def readall(self): + return self.data + +class MockADLSDirectoryClient: + """Mock Azure Data Lake Storage Gen2 directory client""" + def __init__(self, file_system_client, directory_path): + self.file_system_client = file_system_client + self.directory_path = directory_path + self.storage = file_system_client.storage + + def create_directory(self): + """Create a directory (mark in storage)""" + # Store directory marker + if not self.directory_path.endswith('/'): + dir_key = self.directory_path + '/' + else: + dir_key = self.directory_path + self.storage[dir_key] = b"" + + def delete_directory(self): + """Delete a directory and all its contents""" + prefix = self.directory_path if self.directory_path.endswith('/') else self.directory_path + '/' + keys_to_delete = [k for k in self.storage.keys() if k.startswith(prefix)] + for key in keys_to_delete: + del self.storage[key] + + def get_directory_properties(self): + """Get directory properties""" + return {'is_directory': True} + +class MockPathItem: + """Mock path item returned by get_paths""" + def __init__(self, name, is_directory=False): + self.name = name + self.is_directory = is_directory + +class MockADLSFileSystemClient: + """Mock Azure Data Lake Storage Gen2 file system client""" + def __init__(self, file_system_name): + self.file_system_name = file_system_name + self.storage = {} + + def create_file_system(self): + """Create file system (no-op for mock)""" + pass + + def get_file_system_properties(self): + """Get file system properties""" + return {'name': self.file_system_name} + + def get_file_client(self, file_path): + """Get a file client""" + return MockADLSFileClient(self, file_path) + + def get_directory_client(self, directory_path): + """Get a directory client""" + return MockADLSDirectoryClient(self, directory_path) + + def get_paths(self, path="", recursive=True): + """List paths under a given path""" + prefix = path if path.endswith('/') or path == "" else path + '/' + if path == "": + # List all items + paths = [] + seen = set() + for key in self.storage.keys(): + if key: # Skip empty keys + # Get the top-level name + first_part = key.split('/')[0] + if first_part not in seen: + seen.add(first_part) + is_dir = '/' in key[len(first_part):] + # Return full path from root (matching Azure SDK behavior) + paths.append(MockPathItem(first_part, is_directory=is_dir)) + return paths + else: + # List items under specific path + paths = [] + seen = set() + for key in self.storage.keys(): + if key.startswith(prefix): + # Get relative path + relative = key[len(prefix):] + if relative: + # Get first component + first_part = relative.split('/')[0] + if first_part and first_part not in seen: + seen.add(first_part) + # Check if this is a directory + full_path = prefix + first_part + is_dir = any(k.startswith(full_path + '/') for k in self.storage.keys()) + # Return full path from root (matching Azure SDK behavior) + paths.append(MockPathItem(full_path, is_directory=is_dir)) + return paths + +class MockDataLakeServiceClient: + """Mock Azure Data Lake Service Client""" + def __init__(self, account_url=None, credential=None): + self.account_url = account_url + self.credential = credential + self._file_systems = {} + + @classmethod + def from_connection_string(cls, connection_string): + """Create from connection string""" + return cls(account_url="mock_url") + + def get_file_system_client(self, file_system): + """Get or create a file system client""" + if file_system not in self._file_systems: + self._file_systems[file_system] = MockADLSFileSystemClient(file_system) + return self._file_systems[file_system] + +@pytest.fixture +def setup_test_env(): + DLIOMPI.get_instance().initialize() + if comm.rank == 0: + now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") + storage_root = f"adls-test-container-{now}-{str(uuid.uuid4())}" + storage_type = "adls_gen2" + else: + storage_root = None + storage_type = None + + storage_root = comm.bcast(storage_root, root=0) + storage_type = comm.bcast(storage_type, root=0) + + # Create mock ADLS Gen2 service client + if comm.rank == 0: + mock_service_client = MockDataLakeServiceClient() + mock_file_system_client = mock_service_client.get_file_system_client(storage_root) + mock_file_system_client.create_file_system() + # Initialize with a marker file + mock_file_system_client.storage["init.txt"] = b"container initialized" + mock_storage = mock_file_system_client.storage + else: + mock_storage = None + mock_service_client = MockDataLakeServiceClient() + mock_file_system_client = mock_service_client.get_file_system_client(storage_root) + + # Broadcast the mock_storage dictionary to all ranks + mock_storage = comm.bcast(mock_storage, root=0) + mock_file_system_client.storage = mock_storage + + adls_overrides = [ + f"++workload.storage.storage_type={storage_type}", + f"++workload.storage.storage_root={storage_root}", + f"++workload.dataset.data_folder=abfs://{storage_root}@test.dfs.core.windows.net", + "++workload.storage.storage_options.account_name=test", + "++workload.dataset.num_subfolders_train=0", + "++workload.dataset.num_subfolders_eval=0" + ] + + comm.Barrier() + yield storage_root, storage_type, mock_file_system_client, adls_overrides + comm.Barrier() + +@pytest.fixture +def patch_adls_checkpoint(setup_test_env): + storage_root, storage_type, mock_file_system_client, adls_overrides = setup_test_env + adls_overrides += [f"++workload.checkpoint.checkpoint_folder=abfs://{storage_root}@test.dfs.core.windows.net/checkpoints"] + checkpoint_storage = {} + + from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing + PyTorchADLSCheckpointing._PyTorchADLSCheckpointing__instance = None + + class MockBlobIO: + """Mock BlobIO for testing""" + def __init__(self, blob_url, mode, credential=None, **kwargs): + self.blob_url = blob_url + self.mode = mode + self.credential = credential + self._mock_storage = checkpoint_storage + self._buffer = None + + def __enter__(self): + if self.mode == "wb": + self._buffer = io.BytesIO() + return self._buffer + data = self._mock_storage.get(self.blob_url, b'') + return io.BytesIO(data) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None and self.mode == "wb" and self._buffer is not None: + self._mock_storage[self.blob_url] = self._buffer.getvalue() + return False + + # Always mock BlobIO for tests (whether azstoragetorch is installed or not) + with patch("dlio_benchmark.checkpointing.pytorch_adls_checkpointing.BlobIO", MockBlobIO): + mock_file_system_client.checkpoint_storage = checkpoint_storage + yield setup_test_env + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +@pytest.mark.parametrize("fmt, framework", [("npy", "pytorch"), ("npz", "pytorch")]) +def test_adls_gen_data(setup_test_env, fmt, framework) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = setup_test_env + + # Patch both DataLakeServiceClient and DefaultAzureCredential + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if (comm.rank == 0): + logging.info("") + logging.info("=" * 80) + logging.info(f" DLIO test for generating {fmt} dataset on ADLS Gen2") + logging.info("=" * 80) + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', overrides=adls_overrides + [f'++workload.framework={framework}', + f'++workload.reader.data_loader={framework}', + '++workload.workflow.train=False', + '++workload.workflow.generate_data=True', + f"++workload.dataset.format={fmt}", + "++workload.dataset.num_files_train=8", + "++workload.dataset.num_files_eval=8"]) + benchmark = run_benchmark(cfg, verify=False) + + # Verify files were created + train_keys = [k for k in mock_file_system_client.storage.keys() if k.startswith("train/") and k.endswith(f".{fmt}")] + valid_keys = [k for k in mock_file_system_client.storage.keys() if k.startswith("valid/") and k.endswith(f".{fmt}")] + assert len(train_keys) == cfg.workload.dataset.num_files_train, f"Expected {cfg.workload.dataset.num_files_train} train files, got {len(train_keys)}" + assert len(valid_keys) == cfg.workload.dataset.num_files_eval, f"Expected {cfg.workload.dataset.num_files_eval} valid files, got {len(valid_keys)}" + + # Clean up mock ADLS after test + clean_adls(mock_file_system_client, ["train/", "valid/"]) + finalize() + + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +def test_adls_rejects_unsupported_format(setup_test_env) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = setup_test_env + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', overrides=adls_overrides + [ + '++workload.framework=pytorch', + '++workload.reader.data_loader=pytorch', + '++workload.workflow.train=False', + '++workload.workflow.generate_data=True', + '++workload.dataset.format=jpeg' + ]) + + with pytest.raises( + Exception, + match="For ADLS Gen2 using PyTorch framework, only NPZ or NPY formats are supported", + ): + DLIOBenchmark(cfg['workload']) + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +def test_adls_subset(setup_test_env) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = setup_test_env + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if comm.rank == 0: + logging.info("") + logging.info("=" * 80) + logging.info(f" DLIO training test for subset on ADLS Gen2") + logging.info("=" * 80) + with initialize_config_dir(version_base=None, config_dir=config_dir): + set_dftracer_finalize(False) + # Generate data + cfg = compose(config_name='config', overrides=adls_overrides + [ + '++workload.workflow.train=False', + '++workload.workflow.generate_data=True']) + benchmark = run_benchmark(cfg, verify=False) + + # Train on subset + set_dftracer_initialize(False) + cfg = compose(config_name='config', overrides=adls_overrides + [ + '++workload.workflow.train=True', + '++workload.workflow.generate_data=False', + '++workload.dataset.num_files_train=8', + '++workload.train.computation_time=0.01']) + benchmark = run_benchmark(cfg, verify=True) + + # Clean up + clean_adls(mock_file_system_client, ["train/", "valid/"]) + finalize() + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +def test_adls_eval(setup_test_env) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = setup_test_env + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if comm.rank == 0: + logging.info("") + logging.info("=" * 80) + logging.info(f" DLIO evaluation test on ADLS Gen2") + logging.info("=" * 80) + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', overrides=adls_overrides + ['++workload.workflow.train=True', + '++workload.workflow.generate_data=True', + 'workload.train.computation_time=0.01', + 'workload.evaluation.eval_time=0.005', + '++workload.train.epochs=4', + '++workload.workflow.evaluation=True']) + benchmark = run_benchmark(cfg) + + # Clean up + clean_adls(mock_file_system_client, ["train/", "valid/"]) + finalize() + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +@pytest.mark.parametrize("framework, nt", [("pytorch", 0), ("pytorch", 1), ("pytorch", 2)]) +def test_adls_multi_threads(setup_test_env, framework, nt) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = setup_test_env + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if comm.rank == 0: + logging.info("") + logging.info("=" * 80) + logging.info(f" DLIO multi-threaded test on ADLS Gen2: {framework} with {nt} threads") + logging.info("=" * 80) + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', overrides=adls_overrides + ['++workload.workflow.train=True', + '++workload.workflow.generate_data=True', + f'++workload.framework={framework}', + f'++workload.reader.data_loader={framework}', + '++workload.dataset.format=npz', + '++workload.train.computation_time=0.01', + '++workload.evaluation.eval_time=0.005', + '++workload.train.epochs=1', + '++workload.dataset.num_files_train=8', + f'++workload.reader.read_threads={nt}']) + benchmark = run_benchmark(cfg) + + # Clean up + clean_adls(mock_file_system_client, ["train/", "valid/"]) + finalize() + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +@pytest.mark.parametrize("nt, context", [(0, None), (1, "fork")]) +def test_adls_pytorch_multiprocessing_context(setup_test_env, nt, context, monkeypatch) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = setup_test_env + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if comm.rank == 0: + logging.info("") + logging.info("=" * 80) + logging.info(f" DLIO PyTorch multiprocessing context test on ADLS Gen2: threads={nt}, context={context}") + logging.info("=" * 80) + + overrides = adls_overrides + ['++workload.workflow.train=True', + '++workload.workflow.generate_data=True', + '++workload.framework=pytorch', + '++workload.reader.data_loader=pytorch', + '++workload.dataset.format=npz', + '++workload.train.computation_time=0.01', + '++workload.evaluation.eval_time=0.005', + '++workload.train.epochs=1', + '++workload.dataset.num_files_train=8', + f'++workload.reader.read_threads={nt}'] + + if context is not None: + overrides.append(f'++workload.reader.multiprocessing_context={context}') + + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', overrides=overrides) + benchmark = run_benchmark(cfg) + + # Clean up + clean_adls(mock_file_system_client, ["train/", "valid/"]) + finalize() + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +@pytest.mark.parametrize("fmt, framework, dataloader, is_even", [ + ("npy", "pytorch", "pytorch", True), + ("npz", "pytorch", "pytorch", True), + ("npy", "pytorch", "pytorch", False), + ("npz", "pytorch", "pytorch", False) +]) +def test_adls_train(setup_test_env, fmt, framework, dataloader, is_even) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = setup_test_env + if is_even: + num_files = 16 + else: + num_files = 17 + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if comm.rank == 0: + logging.info("") + logging.info("=" * 80) + logging.info(f" DLIO training test on ADLS Gen2: Generating data for {fmt} format") + logging.info("=" * 80) + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', overrides=adls_overrides + ['++workload.workflow.train=True', + '++workload.workflow.generate_data=True', + f"++workload.framework={framework}", + f"++workload.reader.data_loader={dataloader}", + f"++workload.dataset.format={fmt}", + 'workload.train.computation_time=0.01', + 'workload.evaluation.eval_time=0.005', + '++workload.train.epochs=1', + f'++workload.dataset.num_files_train={num_files}', + '++workload.reader.read_threads=1']) + benchmark = run_benchmark(cfg) + + # Clean up + clean_adls(mock_file_system_client, ["train/", "valid/"]) + finalize() + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +@pytest.mark.parametrize("framework, model_size, optimizers, num_layers, layer_params, zero_stage, randomize", [ + ("pytorch", 1024, [1024, 128], 2, [16], 0, True), + ("pytorch", 1024, [1024, 128], 2, [16], 3, True), + ("pytorch", 1024, [128], 1, [16], 0, True), + ("pytorch", 1024, [1024, 128], 2, [16], 0, False), + ("pytorch", 1024, [1024, 128], 2, [16], 3, False), + ("pytorch", 1024, [128], 1, [16], 0, False) +]) +def test_adls_checkpoint_epoch(patch_adls_checkpoint, framework, model_size, optimizers, num_layers, layer_params, zero_stage, randomize) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = patch_adls_checkpoint + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if comm.rank == 0: + logging.info("") + logging.info("=" * 80) + logging.info(f" DLIO test for checkpointing at the end of epochs on ADLS Gen2") + logging.info("=" * 80) + + with initialize_config_dir(version_base=None, config_dir=config_dir): + epochs = 8 + epoch_per_ckp = 2 + cfg = compose(config_name='config', + overrides=adls_overrides + [ + f'++workload.framework={framework}', + f'++workload.reader.data_loader={framework}', + '++workload.workflow.train=True', + '++workload.workflow.generate_data=True', + f'++workload.checkpoint.randomize_tensor={randomize}', + '++workload.train.computation_time=0.01', + '++workload.evaluation.eval_time=0.005', + f'++workload.train.epochs={epochs}', + '++workload.workflow.checkpoint=True', + '++workload.checkpoint.checkpoint_mechanism=pt_adls_save', + f'++workload.checkpoint.epochs_between_checkpoints={epoch_per_ckp}', + f'++workload.model.model_size={model_size}', + f'++workload.model.optimization_groups={optimizers}', + f'++workload.model.num_layers={num_layers}', + f'++workload.model.parallelism.zero_stage={zero_stage}', + f'++workload.model.layer_parameters={layer_params}', + f'++workload.model.parallelism.tensor={comm.size}' + ]) + ConfigArguments.reset() + benchmark = DLIOBenchmark(cfg['workload']) + benchmark.initialize() + benchmark.run() + benchmark.finalize() + + checkpoint_keys = list(mock_file_system_client.checkpoint_storage.keys()) + n = 0 + if len(layer_params) > 0: + n = num_layers + nranks = comm.size + num_model_files = 1 + num_optimizer_files = 1 + num_layer_files = 1 + files_per_checkpoint = (num_model_files + num_optimizer_files + num_layer_files) * nranks + if framework == "pytorch": + num_check_files = epochs / epoch_per_ckp * files_per_checkpoint + assert (len(checkpoint_keys) == num_check_files), f"files produced are {len(checkpoint_keys)} {num_check_files} {checkpoint_keys}" + + # Clean up + clean_adls(mock_file_system_client, ["checkpoints/"]) + finalize() + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +def test_adls_checkpoint_step(patch_adls_checkpoint) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = patch_adls_checkpoint + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if comm.rank == 0: + logging.info("") + logging.info("=" * 80) + logging.info(f" DLIO test for checkpointing at the end of steps on ADLS Gen2") + logging.info("=" * 80) + + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', + overrides=adls_overrides + [ + '++workload.workflow.train=True', + '++workload.workflow.generate_data=True', + '++workload.train.computation_time=0.01', + '++workload.evaluation.eval_time=0.005', + '++workload.train.epochs=8', + '++workload.workflow.checkpoint=True', + '++workload.checkpoint.checkpoint_mechanism=pt_adls_save', + '++workload.checkpoint.steps_between_checkpoints=2' + ]) + ConfigArguments.reset() + benchmark = DLIOBenchmark(cfg['workload']) + benchmark.initialize() + benchmark.run() + benchmark.finalize() + + dataset = cfg['workload']['dataset'] + nstep = dataset.num_files_train * dataset.num_samples_per_file // cfg['workload']['reader'].batch_size // benchmark.comm_size + ncheckpoints = nstep // 2 * 8 + checkpoint_keys = list(mock_file_system_client.checkpoint_storage.keys()) + assert (len(checkpoint_keys) == ncheckpoints) + + # Clean up + clean_adls(mock_file_system_client, ["checkpoints/"]) + finalize() + +@pytest.mark.timeout(TEST_TIMEOUT_SECONDS, method="thread") +def test_adls_checkpoint_ksm_config(patch_adls_checkpoint) -> None: + storage_root, storage_type, mock_file_system_client, adls_overrides = patch_adls_checkpoint + + with patch("dlio_benchmark.storage.adls_gen2_storage.DataLakeServiceClient") as mock_service, \ + patch("dlio_benchmark.storage.adls_gen2_storage.DefaultAzureCredential") as mock_cred: + mock_instance = MagicMock() + mock_instance.get_file_system_client.return_value = mock_file_system_client + mock_service.return_value = mock_instance + mock_service.from_connection_string.return_value = mock_instance + mock_cred.return_value = MagicMock() + + if comm.rank == 0: + logging.info("") + logging.info("=" * 80) + logging.info(" DLIO test for KSM config on ADLS Gen2") + logging.info("=" * 80) + + # Test Case 1: KSM enabled with defaults + logging.info("Testing KSM enabled with defaults...") + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', + overrides=adls_overrides + [ + '++workload.workflow.checkpoint=True', + '++workload.checkpoint.checkpoint_mechanism=pt_adls_save', + '++workload.checkpoint.ksm={}', + '++workload.workflow.generate_data=False', + '++workload.workflow.train=False', + '++workload.checkpoint.num_checkpoints_write=1', + '++workload.checkpoint.num_checkpoints_read=1', + '++workload.checkpoint.randomize_tensor=False' + ]) + ConfigArguments.reset() + benchmark = DLIOBenchmark(cfg['workload']) + benchmark.initialize() + + args = ConfigArguments.get_instance() + assert args.ksm_init is True, "[Test Case 1 Failed] ksm_init should be True when ksm section is present" + assert args.ksm_madv_mergeable_id == 12 + assert args.ksm_high_ram_trigger == 30.0 + assert args.ksm_low_ram_exit == 15.0 + assert args.ksm_await_time == 200 + logging.info("[Test Case 1 Passed]") + + # Test Case 2: KSM enabled with overrides + logging.info("Testing KSM enabled with overrides...") + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', + overrides=adls_overrides + [ + '++workload.workflow.checkpoint=True', + '++workload.checkpoint.checkpoint_mechanism=pt_adls_save', + '++workload.checkpoint.ksm.high_ram_trigger=25.5', + '++workload.checkpoint.ksm.await_time=100', + '++workload.workflow.generate_data=False', + '++workload.workflow.train=False', + '++workload.checkpoint.num_checkpoints_write=1', + '++workload.checkpoint.num_checkpoints_read=1', + '++workload.checkpoint.randomize_tensor=False' + ]) + ConfigArguments.reset() + benchmark = DLIOBenchmark(cfg['workload']) + benchmark.initialize() + + args = ConfigArguments.get_instance() + assert args.ksm_init is True + assert args.ksm_high_ram_trigger == 25.5 + assert args.ksm_await_time == 100 + logging.info("[Test Case 2 Passed]") + + # Test Case 3: KSM disabled + logging.info("Testing KSM disabled...") + with initialize_config_dir(version_base=None, config_dir=config_dir): + cfg = compose(config_name='config', + overrides=adls_overrides + [ + '++workload.workflow.checkpoint=True', + '++workload.checkpoint.checkpoint_mechanism=pt_adls_save', + '++workload.workflow.generate_data=False', + '++workload.workflow.train=False', + '++workload.checkpoint.num_checkpoints_write=1', + '++workload.checkpoint.num_checkpoints_read=1', + '++workload.checkpoint.randomize_tensor=False' + ]) + ConfigArguments.reset() + benchmark = DLIOBenchmark(cfg['workload']) + benchmark.initialize() + + args = ConfigArguments.get_instance() + assert args.ksm_init is False + logging.info("[Test Case 3 Passed]") + + # Clean up + clean_adls(mock_file_system_client, ["checkpoints/"]) + finalize() + +if __name__ == '__main__': + unittest.main() + + +def test_adls_checkpoint_uses_cached_container_sas(monkeypatch): + from dlio_benchmark.checkpointing import pytorch_adls_checkpointing as module + from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing + + mock_generate_container_sas = MagicMock(return_value="sig=container-token") + monkeypatch.setattr(module, "generate_container_sas", mock_generate_container_sas) + + checkpoint = PyTorchADLSCheckpointing.__new__(PyTorchADLSCheckpointing) + checkpoint._account_name = "testacct" + checkpoint._account_key = "testkey" + checkpoint._shared_access_signature = None + checkpoint._container_sas_tokens = {} + checkpoint._container_sas_refresh_margin = timedelta(minutes=5) + checkpoint._container_sas_ttl = timedelta(hours=1) + + first = checkpoint._to_blob_url( + "abfs://cont@testacct.dfs.core.windows.net/checkpoints/model.pt", for_write=True + ) + second = checkpoint._to_blob_url( + "abfs://cont@testacct.dfs.core.windows.net/checkpoints/optim.pt", for_write=True + ) + third = checkpoint._to_blob_url( + "abfs://cont@testacct.dfs.core.windows.net/checkpoints/model.pt", for_write=False + ) + + assert "sig=container-token" in first + assert "sig=container-token" in second + assert "sig=container-token" in third + assert mock_generate_container_sas.call_count == 1 + + +def test_adls_checkpoint_container_sas_cached_per_container(monkeypatch): + from dlio_benchmark.checkpointing import pytorch_adls_checkpointing as module + from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing + + mock_generate_container_sas = MagicMock(return_value="sig=container-token") + monkeypatch.setattr(module, "generate_container_sas", mock_generate_container_sas) + + checkpoint = PyTorchADLSCheckpointing.__new__(PyTorchADLSCheckpointing) + checkpoint._account_name = "testacct" + checkpoint._account_key = "testkey" + checkpoint._shared_access_signature = None + checkpoint._container_sas_tokens = {} + checkpoint._container_sas_refresh_margin = timedelta(minutes=5) + checkpoint._container_sas_ttl = timedelta(hours=1) + + checkpoint._to_blob_url("abfs://conta@testacct.dfs.core.windows.net/checkpoints/model.pt", for_write=True) + checkpoint._to_blob_url("abfs://contb@testacct.dfs.core.windows.net/checkpoints/model.pt", for_write=True) + + assert mock_generate_container_sas.call_count == 2 + + +def test_adls_checkpoint_container_sas_refreshes_near_expiry(monkeypatch): + from dlio_benchmark.checkpointing import pytorch_adls_checkpointing as module + from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing + + mock_generate_container_sas = MagicMock(return_value="sig=fresh-token") + monkeypatch.setattr(module, "generate_container_sas", mock_generate_container_sas) + + checkpoint = PyTorchADLSCheckpointing.__new__(PyTorchADLSCheckpointing) + checkpoint._account_name = "testacct" + checkpoint._account_key = "testkey" + checkpoint._shared_access_signature = None + checkpoint._container_sas_refresh_margin = timedelta(minutes=5) + checkpoint._container_sas_ttl = timedelta(hours=1) + checkpoint._container_sas_tokens = { + "cont": { + "token": "sig=stale-token", + "expires_at": datetime.now(timezone.utc) + timedelta(minutes=1), + } + } + + url = checkpoint._to_blob_url("abfs://cont@testacct.dfs.core.windows.net/checkpoints/model.pt", for_write=True) + + assert "sig=fresh-token" in url + assert mock_generate_container_sas.call_count == 1 + + +def test_adls_checkpoint_duration_options_are_configurable(): + from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing + + checkpoint = PyTorchADLSCheckpointing.__new__(PyTorchADLSCheckpointing) + + assert checkpoint._get_duration_option({}, "container_sas_ttl", timedelta(hours=1)) == timedelta(hours=1) + assert checkpoint._get_duration_option({"container_sas_ttl": 90}, "container_sas_ttl", timedelta(hours=1)) == timedelta(seconds=90) + assert checkpoint._get_duration_option({"container_sas_ttl": "2h"}, "container_sas_ttl", timedelta(hours=1)) == timedelta(hours=2) + assert checkpoint._get_duration_option({"sas_refresh_margin": "7.5m"}, "sas_refresh_margin", timedelta(minutes=5)) == timedelta(minutes=7.5) + + +def test_adls_get_config_reads_storage_options(): + ConfigArguments.reset() + args = ConfigArguments.get_instance() + + cfg = OmegaConf.create({ + "storage": { + "storage_type": "adls_gen2", + "storage_options": { + "account_name": "testacct", + }, + }, + }) + LoadConfig(args, cfg) + + assert GetConfig(args, "storage.storage_options.account_name") == "testacct" + assert GetConfig(args, "storage.storage_options.container_sas_ttl") == "1:00:00" + assert GetConfig(args, "storage.storage_options.sas_refresh_margin") == "0:05:00" + + +def test_adls_checkpoint_uses_none_blobio_credential(monkeypatch): + from dlio_benchmark.checkpointing import pytorch_adls_checkpointing as module + from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing + + blobio_calls = [] + + class MockBlobIO: + def __init__(self, blob_url, mode, credential=None): + blobio_calls.append({"blob_url": blob_url, "mode": mode, "credential": credential}) + + def __enter__(self): + from io import BytesIO + return BytesIO() + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(module, "BlobIO", MockBlobIO) + monkeypatch.setattr(module.torch, "save", lambda state, writer: None) + + checkpoint = PyTorchADLSCheckpointing.__new__(PyTorchADLSCheckpointing) + checkpoint._shared_access_signature = None + checkpoint._account_key = None + checkpoint.get_name = lambda suffix: "https://testacct.blob.core.windows.net/cont/checkpoints/model.pt?sig=abc" + + checkpoint.save_state("model.pt", {"step": 1}) + + assert len(blobio_calls) == 1 + assert blobio_calls[0]["credential"] is None