Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
pull_request:
branches: [main, dev]
push:
workflow_dispatch:

jobs:
build-and-test:
Expand Down Expand Up @@ -59,6 +60,8 @@ jobs:
key: ${{ matrix.venv }}-gcc${{ matrix.gcc }}-python${{ matrix.python }}-${{ hashFiles('requirements.txt', 'setup.py') }}
- name: Install system dependencies
run: |
sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list
sudo rm -f /etc/apt/sources.list.d/azure-cli.list
sudo apt update
sudo apt-get install -y $CC $CXX libc6 git
sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev python3-dev
Expand Down Expand Up @@ -381,3 +384,57 @@ jobs:
mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-0] -v
mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-1] -v
mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-2] -v

# ADLS Gen2-specific setup and tests
- name: Install ADLS Gen2 dependencies
run: |
source ${VENV_PATH}/bin/activate
pip install .[adls]
- name: test_adls_gen_data
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_gen_data[npy-pytorch] -v
mpirun -np 1 pytest -k test_adls_gen_data[npz-pytorch] -v
- name: test_adls_train
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_train[npy-pytorch-pytorch-True] -v
mpirun -np 1 pytest -k test_adls_train[npz-pytorch-pytorch-True] -v
mpirun -np 1 pytest -k test_adls_train[npy-pytorch-pytorch-False] -v
mpirun -np 1 pytest -k test_adls_train[npz-pytorch-pytorch-False] -v
- name: test_adls_eval
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_eval -v
- name: test_adls_multi_threads
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-0] -v
mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-1] -v
mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-2] -v
- name: test_adls_pytorch_multiprocessing_context
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_pytorch_multiprocessing_context[0-None] -v
mpirun -np 1 pytest -k test_adls_pytorch_multiprocessing_context[1-fork] -v
- name: test_adls_subset
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_subset -v
- name: test_adls_checkpoint_epoch
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers0-2-layer_params0-0-True] -v
mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers1-2-layer_params1-3-True] -v
mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers2-1-layer_params2-0-True] -v
mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers3-2-layer_params3-0-False] -v
mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers4-2-layer_params4-3-False] -v
mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers5-1-layer_params5-0-False] -v
- name: test_adls_checkpoint_ksm_config
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_checkpoint_ksm_config -v
- name: test_adls_checkpoint_step
run: |
source ${VENV_PATH}/bin/activate
mpirun -np 1 pytest -k test_adls_checkpoint_step -v
3 changes: 3 additions & 0 deletions dlio_benchmark/checkpointing/checkpointing_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,8 @@ def get_mechanism(checkpoint_mechanism_type):
elif checkpoint_mechanism_type == CheckpointMechanismType.PT_S3_SAVE:
from dlio_benchmark.checkpointing.pytorch_s3_checkpointing import PyTorchS3Checkpointing
return PyTorchS3Checkpointing.get_instance()
elif checkpoint_mechanism_type == CheckpointMechanismType.PT_ADLS_SAVE:
from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing
return PyTorchADLSCheckpointing.get_instance()
else:
raise Exception(str(ErrorCodes.EC1005))
279 changes: 279 additions & 0 deletions dlio_benchmark/checkpointing/pytorch_adls_checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
"""
Copyright (c) 2025, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime, timedelta, timezone
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

import torch
from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_benchmark.checkpointing.pytorch_checkpointing import PyTorchCheckpointing
from dlio_benchmark.utils.utility import Profile, dft_ai

from dlio_benchmark.common.constants import MODULE_CHECKPOINT

dlp = Profile(MODULE_CHECKPOINT)

# Import BlobIO at module level to allow test patching
try:
from azstoragetorch.io import BlobIO
except ImportError:
BlobIO = None

try:
from azure.storage.blob import ContainerSasPermissions, generate_container_sas
except ImportError:
ContainerSasPermissions = None
generate_container_sas = None

class PyTorchADLSCheckpointing(PyTorchCheckpointing):
__instance = None

@staticmethod
def get_instance():
    """Static access method returning the lazily-created singleton."""
    instance = PyTorchADLSCheckpointing.__instance
    if instance is None:
        # First caller constructs the shared instance for the process.
        instance = PyTorchADLSCheckpointing()
        PyTorchADLSCheckpointing.__instance = instance
    return instance

@dft_ai.checkpoint.init
def __init__(self):
    """Initialize ADLS Gen2 checkpointing from the benchmark configuration.

    Reads authentication settings out of ``self.args.storage_options``
    (inherited from BaseCheckpointing) and resolves the storage account
    name, raising if azstoragetorch is missing or no usable authentication
    configuration can be found.

    Raises:
        ImportError: azstoragetorch (BlobIO) is not installed.
        ValueError: no authentication configuration was supplied, or the
            account name could not be determined from it.
    """
    BaseCheckpointing.__init__(self, "ptadls")

    # BlobIO comes from azstoragetorch; without it ADLS I/O is impossible.
    if BlobIO is None:
        raise ImportError(
            "azstoragetorch is required for ADLS Gen2 checkpointing support. "
            "Install with: pip install 'azstoragetorch>=0.1.0'"
        )

    # Configuration values are exposed through self.args (BaseCheckpointing).
    opts = getattr(self.args, "storage_options", {}) or {}
    self._checkpoint_folder = self.args.checkpoint_folder
    self._account_name = None
    self._account_key = None
    self._shared_access_signature = None
    self._container_sas_tokens = {}  # per-container SAS token cache

    if not isinstance(opts, dict):
        opts = dict(opts)

    # Lifetimes controlling when cached container SAS tokens are (re)issued.
    self._container_sas_ttl = self._get_duration_option(
        opts,
        "container_sas_ttl",
        self.args.adls_container_sas_ttl,
    )
    self._container_sas_refresh_margin = self._get_duration_option(
        opts,
        "sas_refresh_margin",
        self.args.adls_sas_refresh_margin,
    )

    # Three mutually exclusive ways to authenticate, in priority order.
    conn_str = opts.get("connection_string")
    account_url = opts.get("account_url")
    account_name = opts.get("account_name")

    if conn_str:
        # Connection string carries the account name plus a key or SAS.
        self._load_connection_string(conn_str)
    elif account_url:
        # Derive the account name from the account endpoint URL.
        self._account_name = self._extract_account_name_from_url(account_url)
    elif account_name:
        # Explicit account name for SAS-backed checkpoint URLs.
        self._account_name = account_name
    else:
        raise ValueError(
            "ADLS Gen2 checkpointing requires authentication configuration. "
            "Provide 'connection_string', 'account_url', or 'account_name' in storage_options."
        )

    # Last resort: pull the account name out of an abfs:// checkpoint folder.
    if self._account_name is None:
        self._account_name = self._extract_account_name_from_abfs(self._checkpoint_folder)

    if self._account_name is None:
        raise ValueError(
            "Unable to determine ADLS account name for checkpointing. "
            "Provide storage_options.account_name/account_url or use canonical ABFS checkpoint URI."
        )

def _get_duration_option(self, storage_options, option_name, default_value):
    """Resolve a duration setting from ``storage_options``.

    Accepted value forms:
      * None / missing       -> ``default_value``
      * ``datetime.timedelta`` -> used as-is
      * int / float          -> seconds
      * str                  -> seconds, optionally suffixed with
                                's' (seconds), 'm' (minutes), 'h' (hours)
                                or 'd' (days); blank strings fall back to
                                ``default_value``

    Raises:
        ValueError: for any other type, or a string that cannot be parsed.
    """
    value = storage_options.get(option_name)
    if value is None:
        return default_value

    if isinstance(value, timedelta):
        return value

    if isinstance(value, (int, float)):
        return timedelta(seconds=float(value))

    if isinstance(value, str):
        normalized = value.strip().lower()
        if not normalized:
            return default_value
        suffix_multipliers = {
            "s": 1,
            "m": 60,
            "h": 3600,
            "d": 86400,
        }
        if normalized[-1] in suffix_multipliers:
            multiplier = suffix_multipliers[normalized[-1]]
            magnitude = normalized[:-1]
        else:
            multiplier = 1
            magnitude = normalized
        try:
            amount = float(magnitude)
        except ValueError:
            # Surface the descriptive error below instead of leaking
            # float()'s generic message for strings such as "abc" or "10x".
            raise ValueError(
                f"Invalid duration for storage_options.{option_name}: {value!r}. "
                "Use seconds or a string with suffix s, m, h, or d."
            ) from None
        return timedelta(seconds=amount * multiplier)

    raise ValueError(
        f"Invalid duration for storage_options.{option_name}: {value!r}. "
        "Use seconds or a string with suffix s, m, h, or d."
    )

def _load_connection_string(self, connection_string):
    """Extract account credentials from an Azure storage connection string.

    Populates ``_account_name``, ``_account_key`` and
    ``_shared_access_signature`` from the semicolon-separated
    ``Key=Value`` fields; fields without an '=' are ignored and missing
    keys leave the corresponding attribute as None.
    """
    # Only split on the first '=' so values containing '=' (keys, SAS
    # query strings) survive intact; later duplicates win, as before.
    fields = dict(
        segment.split('=', 1)
        for segment in connection_string.split(';')
        if '=' in segment
    )
    self._account_name = fields.get("AccountName")
    self._account_key = fields.get("AccountKey")
    self._shared_access_signature = fields.get("SharedAccessSignature")

def _extract_account_name_from_url(self, account_url):
    """Return the storage account name embedded in an account URL, or None.

    The account name is the first label of the URL's host, e.g.
    "acct" for "https://acct.blob.core.windows.net".
    """
    host = urlparse(account_url).netloc
    return host.split('.')[0] if host else None

def _extract_account_name_from_abfs(self, uri):
    """Return the account name from a canonical abfs:// URI, or None.

    Expects abfs://<file_system>@<account>.<fqdn>/<path>; anything that is
    not an abfs URI with an '@' in its authority yields None.
    """
    parsed = urlparse(uri)
    if parsed.scheme == "abfs" and '@' in parsed.netloc:
        account_fqdn = parsed.netloc.split('@', 1)[1]
        return account_fqdn.split('.')[0]
    return None

def _to_blob_url(self, checkpoint_name, for_write):
    """Convert a checkpoint URI (abfs:// or https://) into a Blob URL.

    When local credentials are configured, a SAS query string is appended
    so BlobIO can authenticate; otherwise the bare URL is returned.
    `for_write` is currently unused but kept so call sites can express
    read/write intent.

    Raises:
        ValueError: unsupported scheme or malformed abfs/blob URL.
        ImportError: account-key auth requested without azure-storage-blob.
    """
    parsed = urlparse(checkpoint_name)

    if parsed.scheme == "https":
        blob_url = checkpoint_name
    elif parsed.scheme == "abfs":
        # abfs://<file_system>@<account>.dfs.core.windows.net/<path>
        if '@' not in parsed.netloc:
            raise ValueError(
                "Invalid ABFS checkpoint path. Expected format: "
                "abfs://<file_system>@<account>.dfs.core.windows.net/<path>"
            )
        file_system, account_fqdn = parsed.netloc.split('@', 1)
        account = account_fqdn.split('.')[0]
        # Rewrite the DFS endpoint onto the Blob endpoint used by BlobIO.
        blob_url = "https://{}.blob.core.windows.net/{}/{}".format(
            account, file_system, parsed.path.lstrip('/'))
    else:
        raise ValueError(
            f"Unsupported checkpoint URI '{checkpoint_name}'. Expected abfs:// or https://"
        )

    # Prefer an explicit SAS taken from the connection string.
    if self._shared_access_signature:
        return self._append_query(blob_url, self._shared_access_signature)

    # With an account key, mint (or reuse) a container-scoped SAS token.
    if self._account_key:
        if generate_container_sas is None or ContainerSasPermissions is None:
            raise ImportError(
                "azure-storage-blob is required for connection-string-based ADLS checkpointing."
            )
        container_and_blob = urlparse(blob_url).path.lstrip('/').split('/', 1)
        if len(container_and_blob) != 2:
            raise ValueError(f"Invalid blob URL for checkpointing: {blob_url}")
        sas_token = self._get_container_sas(container_and_blob[0])
        return self._append_query(blob_url, sas_token)

    # No local credentials: let BlobIO resolve its default credential.
    return blob_url

def _get_container_sas(self, container_name):
    """Return a cached or freshly generated container-level SAS token.

    A cached token is reused while it has more than
    ``_container_sas_refresh_margin`` of lifetime remaining; otherwise a
    new token with read/write/create/add/list permissions is generated
    from the account key and cached with its expiry.
    """
    now = datetime.now(timezone.utc)
    cached = self._container_sas_tokens.get(container_name)

    if isinstance(cached, dict):
        token = cached.get("token")
        expires_at = cached.get("expires_at")
        # Reuse only while comfortably inside the validity window.
        if token and expires_at and (expires_at - now) > self._container_sas_refresh_margin:
            return token

    expiry = now + self._container_sas_ttl
    token = generate_container_sas(
        account_name=self._account_name,
        container_name=container_name,
        account_key=self._account_key,
        permission=ContainerSasPermissions(
            read=True,
            write=True,
            create=True,
            add=True,
            list=True,
        ),
        expiry=expiry,
    )
    self._container_sas_tokens[container_name] = {
        "token": token,
        "expires_at": expiry,
    }
    return token

def _append_query(self, url, query_string):
    """Merge ``query_string`` into ``url``'s query component.

    Parameters already present in ``url`` are overridden by same-named
    parameters from ``query_string``; blank values are preserved.
    """
    parts = urlparse(url)
    params = parse_qs(parts.query, keep_blank_values=True)
    # dict.update mirrors the original per-key overwrite behavior.
    params.update(parse_qs(query_string.lstrip('?'), keep_blank_values=True))
    merged = urlencode(params, doseq=True)
    return urlunparse(parts._replace(query=merged))

@dft_ai.checkpoint.capture
def save_state(self, suffix, state, fsync = False):
    """Serialize ``state`` with torch.save directly into an ADLS blob.

    ``fsync`` is accepted for interface compatibility; the blob write is
    committed when the BlobIO stream is closed.
    """
    blob_url = self._to_blob_url(self.get_name(suffix), for_write=True)
    # BlobIO streams the serialized bytes straight to Azure Blob Storage.
    with BlobIO(blob_url, "wb", credential=None) as blob_writer:
        torch.save(state, blob_writer)

@dft_ai.checkpoint.restart
def load_state(self, suffix, state):
    """Read the checkpoint blob named by ``suffix`` and sanity-check it.

    NOTE(review): the incoming ``state`` argument is discarded and the
    loaded dict is neither returned nor written back to the caller — this
    method only verifies that a non-empty checkpoint deserializes.
    Confirm callers do not expect the passed dict to be populated.
    """
    blob_url = self._to_blob_url(self.get_name(suffix), for_write=False)
    state = dict()  # drop the caller-supplied mapping before loading
    # Stream the blob through BlobIO and deserialize with torch.load.
    with BlobIO(blob_url, "rb", credential=None) as blob_reader:
        state = torch.load(blob_reader)
    self.logger.debug(f"checkpoint state loaded: {state}")
    assert(len(state.keys())>0)

@dlp.log
def save_checkpoint(self, epoch, step_number):
    """Delegate checkpoint saving to the inherited PyTorch implementation.

    Present only to attach the dlp profiling decorator; the actual ADLS
    writes happen via save_state.
    """
    super().save_checkpoint(epoch, step_number)

@dlp.log
def load_checkpoint(self, epoch, step_number):
    """Delegate checkpoint loading to the inherited PyTorch implementation.

    Present only to attach the dlp profiling decorator; the actual ADLS
    reads happen via load_state.
    """
    super().load_checkpoint(epoch, step_number)

@dlp.log
def finalize(self):
    """Delegate teardown to the parent class; no ADLS-specific cleanup."""
    super().finalize()

3 changes: 3 additions & 0 deletions dlio_benchmark/common/enumerations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class CheckpointMechanismType(Enum):
TF_SAVE = 'tf_save'
PT_SAVE = 'pt_save'
PT_S3_SAVE = 'pt_s3_save'
PT_ADLS_SAVE = 'pt_adls_save'

def __str__(self):
return self.value
Expand Down Expand Up @@ -59,6 +60,7 @@ class StorageType(Enum):
PARALLEL_FS = 'parallel_fs'
S3 = 's3'
AISTORE = 'aistore'
ADLS_GEN2 = 'adls_gen2'

def __str__(self):
return self.value
Expand All @@ -70,6 +72,7 @@ class MetadataType(Enum):
FILE = 'file'
DIRECTORY = 'directory'
S3_OBJECT = 's3_object'
ADLS_OBJECT = 'adls_object'

def __str__(self):
return self.value
Expand Down
Loading