
Commit 1e9eda4

Authored by Copilot, committed by zsy056

Initial plan
Squashed commit messages (each co-authored by zsy056 <1074382+zsy056@users.noreply.github.com>):

- Add ADLS Gen2 storage backend support
- Change ADLS Gen2 namespace type to HIERARCHICAL
- Implement ADLS Gen2 API calls using Azure SDK
- Fix code review issues: remove unused import and fix bare except clauses
- Add ADLS Gen2 PyTorch checkpointing and fixes
- Remove unused ctypes import from pytorch_adls_checkpointing
- Add ADLS Gen2 support to CI workflow
- Add workflow_dispatch trigger to ci.yml
- Fix Microsoft repository 403 errors in CI
- Implement ADLS Gen2 test suite
- Fix ADLS Gen2 generator support for NPY/NPZ formats
- Move DataLakeServiceClient import to module level for test patching
- Fix ADLS test configuration and add DefaultAzureCredential patching
- Remove redundant comment from test for consistency
- Fix ADLS Gen2 URI parsing in all storage methods
- Refactor URI parsing to use helper method and reduce duplication
- Add debug logging to investigate test failure
- Add debug logging to create_node as well
- Add comprehensive debug logging to ADLS Gen2 storage
- Fix get_uri to avoid double URI prefixing
- Fix mock get_paths to return full paths matching Azure SDK behavior
- Move AzStorageCheckpoint import to module level for test patching
- Fix test_adls_subset to match S3 pattern - remove incorrect checkpoint mode
- Implement proper AzStorageCheckpoint mock with writer/reader context managers
- Remove redundant AzStorageCheckpoint patches from checkpoint tests
- Fix ADLS checkpoint test: always apply MockAzStorageCheckpoint patch
- fix adlsgen2 tests
1 parent 8b280cb, commit 1e9eda4

13 files changed: +1799, -4 lines
.github/workflows/ci.yml  (57 additions, 0 deletions)

@@ -4,6 +4,7 @@ on:
   pull_request:
     branches: [main, dev]
   push:
+  workflow_dispatch:

 jobs:
   build-and-test:
@@ -59,6 +60,8 @@ jobs:
         key: ${{ matrix.venv }}-gcc${{ matrix.gcc }}-python${{ matrix.python }}-${{ hashFiles('requirements.txt', 'setup.py') }}
     - name: Install system dependencies
       run: |
+        sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list
+        sudo rm -f /etc/apt/sources.list.d/azure-cli.list
         sudo apt update
         sudo apt-get install -y $CC $CXX libc6 git
         sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev python3-dev
@@ -381,3 +384,57 @@ jobs:
         mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-0] -v
         mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-1] -v
         mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-2] -v
+
+    # ADLS Gen2-specific setup and tests
+    - name: Install ADLS Gen2 dependencies
+      run: |
+        source ${VENV_PATH}/bin/activate
+        pip install .[adls]
+    - name: test_adls_gen_data
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_gen_data[npy-pytorch] -v
+        mpirun -np 1 pytest -k test_adls_gen_data[npz-pytorch] -v
+    - name: test_adls_train
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_train[npy-pytorch-pytorch-True] -v
+        mpirun -np 1 pytest -k test_adls_train[npz-pytorch-pytorch-True] -v
+        mpirun -np 1 pytest -k test_adls_train[npy-pytorch-pytorch-False] -v
+        mpirun -np 1 pytest -k test_adls_train[npz-pytorch-pytorch-False] -v
+    - name: test_adls_eval
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_eval -v
+    - name: test_adls_multi_threads
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-0] -v
+        mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-1] -v
+        mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-2] -v
+    - name: test_adls_pytorch_multiprocessing_context
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_pytorch_multiprocessing_context[0-None] -v
+        mpirun -np 1 pytest -k test_adls_pytorch_multiprocessing_context[1-fork] -v
+    - name: test_adls_subset
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_subset -v
+    - name: test_adls_checkpoint_epoch
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers0-2-layer_params0-0-True] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers1-2-layer_params1-3-True] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers2-1-layer_params2-0-True] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers3-2-layer_params3-0-False] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers4-2-layer_params4-3-False] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers5-1-layer_params5-0-False] -v
+    - name: test_adls_checkpoint_ksm_config
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_checkpoint_ksm_config -v
+    - name: test_adls_checkpoint_step
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_checkpoint_step -v

dlio_benchmark/checkpointing/checkpointing_factory.py  (3 additions, 0 deletions)

@@ -42,5 +42,8 @@ def get_mechanism(checkpoint_mechanism_type):
     elif checkpoint_mechanism_type == CheckpointMechanismType.PT_S3_SAVE:
         from dlio_benchmark.checkpointing.pytorch_s3_checkpointing import PyTorchS3Checkpointing
         return PyTorchS3Checkpointing.get_instance()
+    elif checkpoint_mechanism_type == CheckpointMechanismType.PT_ADLS_SAVE:
+        from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing
+        return PyTorchADLSCheckpointing.get_instance()
     else:
         raise Exception(str(ErrorCodes.EC1005))
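
The new branch follows the factory's existing dispatch pattern: the backend module is imported lazily inside its branch (so the optional Azure dependencies are only loaded when this mechanism is actually selected), and a process-wide singleton is returned. A stripped-down sketch of that pattern (the real checkpointing class and module imports are stubbed out here):

```python
from enum import Enum

class CheckpointMechanismType(Enum):
    PT_S3_SAVE = 'pt_s3_save'
    PT_ADLS_SAVE = 'pt_adls_save'

class PyTorchADLSCheckpointing:
    """Stand-in for the real class; only the singleton plumbing is shown."""
    __instance = None

    @staticmethod
    def get_instance():
        if PyTorchADLSCheckpointing.__instance is None:
            PyTorchADLSCheckpointing.__instance = PyTorchADLSCheckpointing()
        return PyTorchADLSCheckpointing.__instance

def get_mechanism(mechanism_type):
    # In the real factory, the import statement sits here inside the branch.
    if mechanism_type == CheckpointMechanismType.PT_ADLS_SAVE:
        return PyTorchADLSCheckpointing.get_instance()
    raise Exception(f"unsupported mechanism: {mechanism_type}")
```

Because every call returns the same instance, state held on the object (such as the SAS token cache in the ADLS implementation below) is shared across checkpoint calls in the process.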
dlio_benchmark/checkpointing/pytorch_adls_checkpointing.py  (new file, 286 additions, 0 deletions)

"""
Copyright (c) 2025, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime, timedelta, timezone
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

import torch
from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_benchmark.checkpointing.pytorch_checkpointing import PyTorchCheckpointing
from dlio_benchmark.utils.utility import Profile, dft_ai

from dlio_benchmark.common.constants import MODULE_CHECKPOINT

dlp = Profile(MODULE_CHECKPOINT)

DEFAULT_CONTAINER_SAS_TTL = timedelta(hours=1)
DEFAULT_CONTAINER_SAS_REFRESH_MARGIN = timedelta(minutes=5)

# Import BlobIO at module level to allow test patching
try:
    from azstoragetorch.io import BlobIO
except ImportError:
    BlobIO = None

try:
    from azure.storage.blob import ContainerSasPermissions, generate_container_sas
except ImportError:
    ContainerSasPermissions = None
    generate_container_sas = None


class PyTorchADLSCheckpointing(PyTorchCheckpointing):
    __instance = None

    @staticmethod
    def get_instance():
        """ Static access method. """
        if PyTorchADLSCheckpointing.__instance is None:
            PyTorchADLSCheckpointing.__instance = PyTorchADLSCheckpointing()
        return PyTorchADLSCheckpointing.__instance

    @dft_ai.checkpoint.init
    def __init__(self):
        BaseCheckpointing.__init__(self, "ptadls")

        # Check if BlobIO is available
        if BlobIO is None:
            raise ImportError(
                "azstoragetorch is required for ADLS Gen2 checkpointing support. "
                "Install with: pip install 'azstoragetorch>=0.1.0'"
            )

        # Access config values from self.args (inherited from BaseCheckpointing)
        storage_options = getattr(self.args, "storage_options", {}) or {}
        self._checkpoint_folder = self.args.checkpoint_folder
        self._account_name = None
        self._account_key = None
        self._shared_access_signature = None
        self._container_sas_tokens = {}

        if not isinstance(storage_options, dict):
            storage_options = dict(storage_options)

        self._container_sas_ttl = self._get_duration_option(
            storage_options,
            "container_sas_ttl",
            DEFAULT_CONTAINER_SAS_TTL,
        )
        self._container_sas_refresh_margin = self._get_duration_option(
            storage_options,
            "sas_refresh_margin",
            DEFAULT_CONTAINER_SAS_REFRESH_MARGIN,
        )

        # Support both connection string and account URL authentication
        connection_string = storage_options.get("connection_string")
        account_url = storage_options.get("account_url")
        account_name = storage_options.get("account_name")

        if connection_string:
            # Parse connection string and generate SAS-based blob URLs for BlobIO.
            self._load_connection_string(connection_string)
        elif account_url:
            # Use account URL and derive account name for SAS-backed checkpoint URLs.
            self._account_name = self._extract_account_name_from_url(account_url)
        elif account_name:
            # Use explicit account name for SAS-backed checkpoint URLs.
            self._account_name = account_name
        else:
            raise ValueError(
                "ADLS Gen2 checkpointing requires authentication configuration. "
                "Provide 'connection_string', 'account_url', or 'account_name' in storage_options."
            )

        if self._account_name is None:
            self._account_name = self._extract_account_name_from_abfs(self._checkpoint_folder)

        if self._account_name is None:
            raise ValueError(
                "Unable to determine ADLS account name for checkpointing. "
                "Provide storage_options.account_name/account_url or use canonical ABFS checkpoint URI."
            )

    def _get_duration_option(self, storage_options, option_name, default_value):
        value = storage_options.get(option_name)
        if value is None:
            return default_value

        if isinstance(value, timedelta):
            return value

        if isinstance(value, (int, float)):
            return timedelta(seconds=float(value))

        if isinstance(value, str):
            normalized = value.strip().lower()
            if not normalized:
                return default_value
            suffix_multipliers = {
                "s": 1,
                "m": 60,
                "h": 3600,
                "d": 86400,
            }
            if normalized[-1] in suffix_multipliers:
                amount = float(normalized[:-1])
                return timedelta(seconds=amount * suffix_multipliers[normalized[-1]])
            return timedelta(seconds=float(normalized))

        raise ValueError(
            f"Invalid duration for storage_options.{option_name}: {value!r}. "
            "Use seconds or a string with suffix s, m, h, or d."
        )

    def _load_connection_string(self, connection_string):
        parts = {}
        for segment in connection_string.split(';'):
            if '=' in segment:
                key, value = segment.split('=', 1)
                parts[key] = value

        self._account_name = parts.get("AccountName")
        self._account_key = parts.get("AccountKey")
        self._shared_access_signature = parts.get("SharedAccessSignature")

    def _extract_account_name_from_url(self, account_url):
        parsed = urlparse(account_url)
        host = parsed.netloc
        if not host:
            return None
        return host.split('.')[0]

    def _extract_account_name_from_abfs(self, uri):
        parsed = urlparse(uri)
        if parsed.scheme != "abfs" or '@' not in parsed.netloc:
            return None
        _, account_fqdn = parsed.netloc.split('@', 1)
        return account_fqdn.split('.')[0]

    def _to_blob_url(self, checkpoint_name, for_write):
        parsed = urlparse(checkpoint_name)

        if parsed.scheme == "https":
            blob_url = checkpoint_name
        elif parsed.scheme == "abfs":
            if '@' not in parsed.netloc:
                raise ValueError(
                    "Invalid ABFS checkpoint path. Expected format: "
                    "abfs://<file_system>@<account>.dfs.core.windows.net/<path>"
                )
            file_system, account_fqdn = parsed.netloc.split('@', 1)
            account_name = account_fqdn.split('.')[0]
            blob_path = parsed.path.lstrip('/')
            blob_url = f"https://{account_name}.blob.core.windows.net/{file_system}/{blob_path}"
        else:
            raise ValueError(
                f"Unsupported checkpoint URI '{checkpoint_name}'. Expected abfs:// or https://"
            )

        if self._shared_access_signature:
            return self._append_query(blob_url, self._shared_access_signature)

        if self._account_key:
            if generate_container_sas is None or ContainerSasPermissions is None:
                raise ImportError(
                    "azure-storage-blob is required for connection-string-based ADLS checkpointing."
                )
            blob_parsed = urlparse(blob_url)
            path_parts = blob_parsed.path.lstrip('/').split('/', 1)
            if len(path_parts) != 2:
                raise ValueError(f"Invalid blob URL for checkpointing: {blob_url}")
            container_name, _ = path_parts
            token = self._get_container_sas(container_name)
            return self._append_query(blob_url, token)

        return blob_url

    def _get_container_sas(self, container_name):
        cache_entry = self._container_sas_tokens.get(container_name)
        now = datetime.now(timezone.utc)
        refresh_margin = getattr(
            self,
            "_container_sas_refresh_margin",
            DEFAULT_CONTAINER_SAS_REFRESH_MARGIN,
        )

        if isinstance(cache_entry, dict):
            token = cache_entry.get("token")
            expires_at = cache_entry.get("expires_at")
            if token and expires_at and (expires_at - now) > refresh_margin:
                return token

        ttl = getattr(self, "_container_sas_ttl", DEFAULT_CONTAINER_SAS_TTL)
        expiry = now + ttl

        token = generate_container_sas(
            account_name=self._account_name,
            container_name=container_name,
            account_key=self._account_key,
            permission=ContainerSasPermissions(
                read=True,
                write=True,
                create=True,
                add=True,
                list=True,
            ),
            expiry=expiry,
        )
        self._container_sas_tokens[container_name] = {
            "token": token,
            "expires_at": expiry,
        }
        return token

    def _append_query(self, url, query_string):
        parsed = urlparse(url)
        existing = parse_qs(parsed.query, keep_blank_values=True)
        incoming = parse_qs(query_string.lstrip('?'), keep_blank_values=True)
        for key, values in incoming.items():
            existing[key] = values
        merged_query = urlencode(existing, doseq=True)
        return urlunparse(parsed._replace(query=merged_query))

    @dft_ai.checkpoint.capture
    def save_state(self, suffix, state, fsync=False):
        name = self.get_name(suffix)
        blob_url = self._to_blob_url(name, for_write=True)
        # Save checkpoint to ADLS using azstoragetorch BlobIO
        with BlobIO(blob_url, "wb", credential=None) as writer:
            torch.save(state, writer)

    @dft_ai.checkpoint.restart
    def load_state(self, suffix, state):
        name = self.get_name(suffix)
        blob_url = self._to_blob_url(name, for_write=False)
        state = dict()  # clear up
        # Load checkpoint from ADLS using azstoragetorch BlobIO
        with BlobIO(blob_url, "rb", credential=None) as reader:
            state = torch.load(reader)
        self.logger.debug(f"checkpoint state loaded: {state}")
        assert len(state.keys()) > 0

    @dlp.log
    def save_checkpoint(self, epoch, step_number):
        super().save_checkpoint(epoch, step_number)

    @dlp.log
    def load_checkpoint(self, epoch, step_number):
        super().load_checkpoint(epoch, step_number)

    @dlp.log
    def finalize(self):
        super().finalize()
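The URI handling in `_to_blob_url` and `_append_query` is pure stdlib `urllib.parse` work and can be exercised standalone. A minimal sketch mirroring the happy path of those two methods (the account, container, and SAS values below are made up for illustration):

```python
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

def abfs_to_blob_url(uri):
    # abfs://<file_system>@<account>.dfs.core.windows.net/<path>
    # becomes https://<account>.blob.core.windows.net/<file_system>/<path>
    parsed = urlparse(uri)
    file_system, account_fqdn = parsed.netloc.split('@', 1)
    account = account_fqdn.split('.')[0]
    return f"https://{account}.blob.core.windows.net/{file_system}/{parsed.path.lstrip('/')}"

def append_query(url, query_string):
    # Merge a SAS token into the URL; incoming keys override duplicates,
    # matching the dict-update merge in _append_query.
    parsed = urlparse(url)
    existing = parse_qs(parsed.query, keep_blank_values=True)
    incoming = parse_qs(query_string.lstrip('?'), keep_blank_values=True)
    existing.update(incoming)
    return urlunparse(parsed._replace(query=urlencode(existing, doseq=True)))

blob = abfs_to_blob_url("abfs://ckpts@myacct.dfs.core.windows.net/run1/model.pt")
# -> https://myacct.blob.core.windows.net/ckpts/run1/model.pt
signed = append_query(blob, "?sv=2024-01-01&sig=abc")
```

The `parse_qs`/`urlencode` round trip is what keeps an already-signed `https://` checkpoint URL valid when extra query parameters are layered on.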
dlio_benchmark/common/enumerations.py  (3 additions, 0 deletions)

@@ -27,6 +27,7 @@ class CheckpointMechanismType(Enum):
     TF_SAVE = 'tf_save'
     PT_SAVE = 'pt_save'
     PT_S3_SAVE = 'pt_s3_save'
+    PT_ADLS_SAVE = 'pt_adls_save'

     def __str__(self):
         return self.value
@@ -59,6 +60,7 @@ class StorageType(Enum):
     PARALLEL_FS = 'parallel_fs'
     S3 = 's3'
     AISTORE = 'aistore'
+    ADLS_GEN2 = 'adls_gen2'

     def __str__(self):
         return self.value
@@ -70,6 +72,7 @@ class MetadataType(Enum):
     FILE = 'file'
     DIRECTORY = 'directory'
     S3_OBJECT = 's3_object'
+    ADLS_OBJECT = 'adls_object'

     def __str__(self):
         return self.value

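These enum additions piggyback on the file's existing string-valued pattern: each member's value is the config string, and `__str__` returns it, so YAML values round-trip cleanly to enum members. A minimal sketch using the `CheckpointMechanismType` members from this diff:

```python
from enum import Enum

class CheckpointMechanismType(Enum):
    TF_SAVE = 'tf_save'
    PT_SAVE = 'pt_save'
    PT_S3_SAVE = 'pt_s3_save'
    PT_ADLS_SAVE = 'pt_adls_save'

    def __str__(self):
        return self.value

# A config value like "pt_adls_save" maps straight to the member,
# and printing the member yields the config string again.
mechanism = CheckpointMechanismType('pt_adls_save')
```

This is why the factory can dispatch on the enum while configs and logs deal only in plain strings.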