
Commit f8e5dfe

Copilot and zsy056 authored and committed

Initial plan

Squashed commit messages:
- Add ADLS Gen2 storage backend support
- Change ADLS Gen2 namespace type to HIERARCHICAL
- Implement ADLS Gen2 API calls using Azure SDK
- Fix code review issues: remove unused import and fix bare except clauses
- Add ADLS Gen2 PyTorch checkpointing and fixes
- Remove unused ctypes import from pytorch_adls_checkpointing
- Add ADLS Gen2 support to CI workflow
- Add workflow_dispatch trigger to ci.yml
- Fix Microsoft repository 403 errors in CI
- Implement ADLS Gen2 test suite
- Fix ADLS Gen2 generator support for NPY/NPZ formats
- Move DataLakeServiceClient import to module level for test patching
- Fix ADLS test configuration and add DefaultAzureCredential patching
- Remove redundant comment from test for consistency
- Fix ADLS Gen2 URI parsing in all storage methods
- Refactor URI parsing to use helper method and reduce duplication
- Add debug logging to investigate test failure
- Add debug logging to create_node as well
- Add comprehensive debug logging to ADLS Gen2 storage
- Fix get_uri to avoid double URI prefixing
- Fix mock get_paths to return full paths matching Azure SDK behavior
- Move AzStorageCheckpoint import to module level for test patching
- Fix test_adls_subset to match S3 pattern - remove incorrect checkpoint mode
- Implement proper AzStorageCheckpoint mock with writer/reader context managers
- Remove redundant AzStorageCheckpoint patches from checkpoint tests
- Fix ADLS checkpoint test: always apply MockAzStorageCheckpoint patch
- fix adlsgen2 tests

Co-authored-by: zsy056 <1074382+zsy056@users.noreply.github.com>

1 parent 8b280cb · commit f8e5dfe

File tree

13 files changed: +1836 −7 lines

.github/workflows/ci.yml

Lines changed: 57 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@ on:
   pull_request:
     branches: [main, dev]
   push:
+  workflow_dispatch:

 jobs:
   build-and-test:
@@ -59,6 +60,8 @@ jobs:
         key: ${{ matrix.venv }}-gcc${{ matrix.gcc }}-python${{ matrix.python }}-${{ hashFiles('requirements.txt', 'setup.py') }}
     - name: Install system dependencies
       run: |
+        sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list
+        sudo rm -f /etc/apt/sources.list.d/azure-cli.list
         sudo apt update
         sudo apt-get install -y $CC $CXX libc6 git
         sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev python3-dev
@@ -381,3 +384,57 @@ jobs:
         mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-0] -v
         mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-1] -v
         mpirun -np 1 pytest -k test_aistore_multi_threads[pytorch-2] -v
+
+    # ADLS Gen2-specific setup and tests
+    - name: Install ADLS Gen2 dependencies
+      run: |
+        source ${VENV_PATH}/bin/activate
+        pip install .[adls]
+    - name: test_adls_gen_data
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_gen_data[npy-pytorch] -v
+        mpirun -np 1 pytest -k test_adls_gen_data[npz-pytorch] -v
+    - name: test_adls_train
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_train[npy-pytorch-pytorch-True] -v
+        mpirun -np 1 pytest -k test_adls_train[npz-pytorch-pytorch-True] -v
+        mpirun -np 1 pytest -k test_adls_train[npy-pytorch-pytorch-False] -v
+        mpirun -np 1 pytest -k test_adls_train[npz-pytorch-pytorch-False] -v
+    - name: test_adls_eval
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_eval -v
+    - name: test_adls_multi_threads
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-0] -v
+        mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-1] -v
+        mpirun -np 1 pytest -k test_adls_multi_threads[pytorch-2] -v
+    - name: test_adls_pytorch_multiprocessing_context
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_pytorch_multiprocessing_context[0-None] -v
+        mpirun -np 1 pytest -k test_adls_pytorch_multiprocessing_context[1-fork] -v
+    - name: test_adls_subset
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_subset -v
+    - name: test_adls_checkpoint_epoch
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers0-2-layer_params0-0-True] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers1-2-layer_params1-3-True] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers2-1-layer_params2-0-True] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers3-2-layer_params3-0-False] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers4-2-layer_params4-3-False] -v
+        mpirun -np 1 pytest -k test_adls_checkpoint_epoch[pytorch-1024-optimizers5-1-layer_params5-0-False] -v
+    - name: test_adls_checkpoint_ksm_config
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_checkpoint_ksm_config -v
+    - name: test_adls_checkpoint_step
+      run: |
+        source ${VENV_PATH}/bin/activate
+        mpirun -np 1 pytest -k test_adls_checkpoint_step -v
```
dlio_benchmark/checkpointing/checkpointing_factory.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -42,5 +42,8 @@ def get_mechanism(checkpoint_mechanism_type):
     elif checkpoint_mechanism_type == CheckpointMechanismType.PT_S3_SAVE:
         from dlio_benchmark.checkpointing.pytorch_s3_checkpointing import PyTorchS3Checkpointing
         return PyTorchS3Checkpointing.get_instance()
+    elif checkpoint_mechanism_type == CheckpointMechanismType.PT_ADLS_SAVE:
+        from dlio_benchmark.checkpointing.pytorch_adls_checkpointing import PyTorchADLSCheckpointing
+        return PyTorchADLSCheckpointing.get_instance()
     else:
         raise Exception(str(ErrorCodes.EC1005))
```
dlio_benchmark/checkpointing/pytorch_adls_checkpointing.py (new file)

Lines changed: 279 additions & 0 deletions

Module setup, singleton accessor, and constructor:

```python
"""
Copyright (c) 2025, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime, timedelta, timezone
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

import torch
from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_benchmark.checkpointing.pytorch_checkpointing import PyTorchCheckpointing
from dlio_benchmark.utils.utility import Profile, dft_ai

from dlio_benchmark.common.constants import MODULE_CHECKPOINT

dlp = Profile(MODULE_CHECKPOINT)

# Import BlobIO at module level to allow test patching
try:
    from azstoragetorch.io import BlobIO
except ImportError:
    BlobIO = None

try:
    from azure.storage.blob import ContainerSasPermissions, generate_container_sas
except ImportError:
    ContainerSasPermissions = None
    generate_container_sas = None


class PyTorchADLSCheckpointing(PyTorchCheckpointing):
    __instance = None

    @staticmethod
    def get_instance():
        """ Static access method. """
        if PyTorchADLSCheckpointing.__instance is None:
            PyTorchADLSCheckpointing.__instance = PyTorchADLSCheckpointing()
        return PyTorchADLSCheckpointing.__instance

    @dft_ai.checkpoint.init
    def __init__(self):
        BaseCheckpointing.__init__(self, "ptadls")

        # Check if BlobIO is available
        if BlobIO is None:
            raise ImportError(
                "azstoragetorch is required for ADLS Gen2 checkpointing support. "
                "Install with: pip install 'azstoragetorch>=0.1.0'"
            )

        # Access config values from self.args (inherited from BaseCheckpointing)
        storage_options = getattr(self.args, "storage_options", {}) or {}
        self._checkpoint_folder = self.args.checkpoint_folder
        self._account_name = None
        self._account_key = None
        self._shared_access_signature = None
        self._container_sas_tokens = {}

        if not isinstance(storage_options, dict):
            storage_options = dict(storage_options)

        self._container_sas_ttl = self._get_duration_option(
            storage_options,
            "container_sas_ttl",
            self.args.adls_container_sas_ttl,
        )
        self._container_sas_refresh_margin = self._get_duration_option(
            storage_options,
            "sas_refresh_margin",
            self.args.adls_sas_refresh_margin,
        )

        # Support both connection string and account URL authentication
        connection_string = storage_options.get("connection_string")
        account_url = storage_options.get("account_url")
        account_name = storage_options.get("account_name")

        if connection_string:
            # Parse connection string and generate SAS-based blob URLs for BlobIO.
            self._load_connection_string(connection_string)
        elif account_url:
            # Use account URL and derive account name for SAS-backed checkpoint URLs.
            self._account_name = self._extract_account_name_from_url(account_url)
        elif account_name:
            # Use explicit account name for SAS-backed checkpoint URLs.
            self._account_name = account_name
        else:
            raise ValueError(
                "ADLS Gen2 checkpointing requires authentication configuration. "
                "Provide 'connection_string', 'account_url', or 'account_name' in storage_options."
            )

        if self._account_name is None:
            self._account_name = self._extract_account_name_from_abfs(self._checkpoint_folder)

        if self._account_name is None:
            raise ValueError(
                "Unable to determine ADLS account name for checkpointing. "
                "Provide storage_options.account_name/account_url or use canonical ABFS checkpoint URI."
            )
```
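The constructor accepts three mutually exclusive authentication styles through `storage_options`. A hypothetical sketch of what each configuration could look like — the account name, key, and TTL values below are made-up placeholders, not values from this commit:

```python
# Hypothetical storage_options dicts; "myacct" and the key are placeholders.
opts_conn = {
    # Connection string: the account key is used to mint container SAS tokens.
    "connection_string": (
        "DefaultEndpointsProtocol=https;AccountName=myacct;"
        "AccountKey=PLACEHOLDER_KEY;EndpointSuffix=core.windows.net"
    ),
    "container_sas_ttl": "1h",      # parsed by _get_duration_option
    "sas_refresh_margin": "5m",
}
# Account URL: the account name is derived from the URL host.
opts_url = {"account_url": "https://myacct.dfs.core.windows.net"}
# Explicit account name.
opts_name = {"account_name": "myacct"}
```

With `account_url` or `account_name` alone, no key or SAS is available, so blob URLs pass through unsigned and BlobIO's own credential handling applies.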
Duration parsing for the SAS TTL and refresh-margin options:

```python
    def _get_duration_option(self, storage_options, option_name, default_value):
        value = storage_options.get(option_name)
        if value is None:
            return default_value

        if isinstance(value, timedelta):
            return value

        if isinstance(value, (int, float)):
            return timedelta(seconds=float(value))

        if isinstance(value, str):
            normalized = value.strip().lower()
            if not normalized:
                return default_value
            suffix_multipliers = {
                "s": 1,
                "m": 60,
                "h": 3600,
                "d": 86400,
            }
            if normalized[-1] in suffix_multipliers:
                amount = float(normalized[:-1])
                return timedelta(seconds=amount * suffix_multipliers[normalized[-1]])
            return timedelta(seconds=float(normalized))

        raise ValueError(
            f"Invalid duration for storage_options.{option_name}: {value!r}. "
            "Use seconds or a string with suffix s, m, h, or d."
        )
```
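The same suffix handling can be exercised standalone; this mirror of `_get_duration_option`'s logic (the name `parse_duration` is ours, for illustration) shows what each input form resolves to:

```python
from datetime import timedelta

def parse_duration(value, default=None):
    """Mirror of _get_duration_option's value handling, minus self/args."""
    if value is None:
        return default
    if isinstance(value, timedelta):
        return value
    if isinstance(value, (int, float)):
        # Bare numbers are interpreted as seconds.
        return timedelta(seconds=float(value))
    normalized = value.strip().lower()
    if not normalized:
        return default
    multipliers = {"s": 1, "m": 60, "h": 3600, "d": 86400}
    if normalized[-1] in multipliers:
        # Trailing s/m/h/d scales the numeric prefix.
        return timedelta(seconds=float(normalized[:-1]) * multipliers[normalized[-1]])
    return timedelta(seconds=float(normalized))
```

So `"90"` means 90 seconds, `"15m"` fifteen minutes, and `"2h"` two hours, while a missing option falls back to the default.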
Connection-string parsing, account-name extraction, and ABFS-to-blob URL translation:

```python
    def _load_connection_string(self, connection_string):
        parts = {}
        for segment in connection_string.split(';'):
            if '=' in segment:
                key, value = segment.split('=', 1)
                parts[key] = value

        self._account_name = parts.get("AccountName")
        self._account_key = parts.get("AccountKey")
        self._shared_access_signature = parts.get("SharedAccessSignature")

    def _extract_account_name_from_url(self, account_url):
        parsed = urlparse(account_url)
        host = parsed.netloc
        if not host:
            return None
        return host.split('.')[0]

    def _extract_account_name_from_abfs(self, uri):
        parsed = urlparse(uri)
        if parsed.scheme != "abfs" or '@' not in parsed.netloc:
            return None
        _, account_fqdn = parsed.netloc.split('@', 1)
        return account_fqdn.split('.')[0]

    def _to_blob_url(self, checkpoint_name, for_write):
        parsed = urlparse(checkpoint_name)

        if parsed.scheme == "https":
            blob_url = checkpoint_name
        elif parsed.scheme == "abfs":
            if '@' not in parsed.netloc:
                raise ValueError(
                    "Invalid ABFS checkpoint path. Expected format: "
                    "abfs://<file_system>@<account>.dfs.core.windows.net/<path>"
                )
            file_system, account_fqdn = parsed.netloc.split('@', 1)
            account_name = account_fqdn.split('.')[0]
            blob_path = parsed.path.lstrip('/')
            blob_url = f"https://{account_name}.blob.core.windows.net/{file_system}/{blob_path}"
        else:
            raise ValueError(
                f"Unsupported checkpoint URI '{checkpoint_name}'. Expected abfs:// or https://"
            )

        if self._shared_access_signature:
            return self._append_query(blob_url, self._shared_access_signature)

        if self._account_key:
            if generate_container_sas is None or ContainerSasPermissions is None:
                raise ImportError(
                    "azure-storage-blob is required for connection-string-based ADLS checkpointing."
                )
            blob_parsed = urlparse(blob_url)
            path_parts = blob_parsed.path.lstrip('/').split('/', 1)
            if len(path_parts) != 2:
                raise ValueError(f"Invalid blob URL for checkpointing: {blob_url}")
            container_name, _ = path_parts
            token = self._get_container_sas(container_name)
            return self._append_query(blob_url, token)

        return blob_url
```
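The URL work above is pure string manipulation and can be shown standalone. These mirrors of `_to_blob_url`'s ABFS branch and `_append_query` (the function names here are ours, for illustration) demonstrate the dfs-to-blob endpoint mapping and SAS query merging on a made-up account:

```python
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

def abfs_to_blob_url(abfs_uri):
    """abfs://<file_system>@<account>.dfs.core.windows.net/<path> -> blob URL."""
    parsed = urlparse(abfs_uri)
    file_system, account_fqdn = parsed.netloc.split('@', 1)
    account_name = account_fqdn.split('.')[0]
    blob_path = parsed.path.lstrip('/')
    # The dfs endpoint maps to the blob endpoint of the same account.
    return f"https://{account_name}.blob.core.windows.net/{file_system}/{blob_path}"

def append_query(url, query_string):
    """Merge a SAS query string into a URL, preserving existing parameters."""
    parsed = urlparse(url)
    existing = parse_qs(parsed.query, keep_blank_values=True)
    for key, values in parse_qs(query_string.lstrip('?'), keep_blank_values=True).items():
        existing[key] = values  # incoming parameters win on conflict
    return urlunparse(parsed._replace(query=urlencode(existing, doseq=True)))

url = abfs_to_blob_url("abfs://ckpts@myacct.dfs.core.windows.net/run1/model.pt")
signed = append_query(url, "?sv=2024-01-01&sig=abc")
```

Merging through `parse_qs`/`urlencode` rather than naive string concatenation is what keeps the result valid when the checkpoint URL already carries query parameters.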
Container SAS minting with a per-container token cache, and query-string merging:

```python
    def _get_container_sas(self, container_name):
        cache_entry = self._container_sas_tokens.get(container_name)
        now = datetime.now(timezone.utc)
        refresh_margin = self._container_sas_refresh_margin

        if isinstance(cache_entry, dict):
            token = cache_entry.get("token")
            expires_at = cache_entry.get("expires_at")
            if token and expires_at and (expires_at - now) > refresh_margin:
                return token

        ttl = self._container_sas_ttl
        expiry = now + ttl

        token = generate_container_sas(
            account_name=self._account_name,
            container_name=container_name,
            account_key=self._account_key,
            permission=ContainerSasPermissions(
                read=True,
                write=True,
                create=True,
                add=True,
                list=True,
            ),
            expiry=expiry,
        )
        self._container_sas_tokens[container_name] = {
            "token": token,
            "expires_at": expiry,
        }
        return token

    def _append_query(self, url, query_string):
        parsed = urlparse(url)
        existing = parse_qs(parsed.query, keep_blank_values=True)
        incoming = parse_qs(query_string.lstrip('?'), keep_blank_values=True)
        for key, values in incoming.items():
            existing[key] = values
        merged_query = urlencode(existing, doseq=True)
        return urlunparse(parsed._replace(query=merged_query))
```
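The caching policy in `_get_container_sas` — reuse a token until it is within the refresh margin of expiry, then mint a fresh one — can be isolated from the Azure SDK. This hypothetical sketch injects a fake token generator so the refresh behavior is observable without any storage account:

```python
from datetime import datetime, timedelta, timezone

class SasCache:
    """Illustrative mirror of the token cache in _get_container_sas."""
    def __init__(self, make_token, ttl, refresh_margin):
        self._make_token = make_token   # callable(container, expiry) -> token string
        self._ttl = ttl
        self._margin = refresh_margin
        self._tokens = {}

    def get(self, container, now=None):
        now = now or datetime.now(timezone.utc)
        entry = self._tokens.get(container)
        # Reuse only while strictly more than the margin remains before expiry.
        if entry and (entry["expires_at"] - now) > self._margin:
            return entry["token"]
        expiry = now + self._ttl
        token = self._make_token(container, expiry)
        self._tokens[container] = {"token": token, "expires_at": expiry}
        return token

calls = []
cache = SasCache(lambda c, e: calls.append(c) or f"sas-{c}-{len(calls)}",
                 ttl=timedelta(hours=1), refresh_margin=timedelta(minutes=5))
t0 = datetime(2025, 1, 1, tzinfo=timezone.utc)
a = cache.get("ckpts", now=t0)                          # mints a token
b = cache.get("ckpts", now=t0 + timedelta(minutes=30))  # served from cache
c = cache.get("ckpts", now=t0 + timedelta(minutes=56))  # 4 min left < 5 min margin: re-mint
```

Re-minting before actual expiry means a checkpoint write never starts with a token that could lapse mid-upload.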
Checkpoint save/load through azstoragetorch's BlobIO:

```python
    @dft_ai.checkpoint.capture
    def save_state(self, suffix, state, fsync = False):
        name = self.get_name(suffix)
        blob_url = self._to_blob_url(name, for_write=True)
        # Save checkpoint to ADLS using azstoragetorch BlobIO
        with BlobIO(blob_url, "wb", credential=None) as writer:
            torch.save(state, writer)

    @dft_ai.checkpoint.restart
    def load_state(self, suffix, state):
        name = self.get_name(suffix)
        blob_url = self._to_blob_url(name, for_write=False)
        state = dict()  # clear up
        # Load checkpoint from ADLS using azstoragetorch BlobIO
        with BlobIO(blob_url, "rb", credential=None) as reader:
            state = torch.load(reader)
        self.logger.debug(f"checkpoint state loaded: {state}")
        assert(len(state.keys())>0)

    @dlp.log
    def save_checkpoint(self, epoch, step_number):
        super().save_checkpoint(epoch, step_number)

    @dlp.log
    def load_checkpoint(self, epoch, step_number):
        super().load_checkpoint(epoch, step_number)

    @dlp.log
    def finalize(self):
        super().finalize()
```
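The save/load path relies only on BlobIO behaving as a file-like object handed to `torch.save`/`torch.load`. That round trip can be sketched with an in-memory stand-in — `MemoryBlobIO` is hypothetical, and `pickle` stands in for torch serialization so the sketch needs no Azure or PyTorch dependency:

```python
import io
import pickle

class MemoryBlobIO(io.BytesIO):
    """Hypothetical in-memory stand-in for azstoragetorch's BlobIO:
    a file-like object whose close() persists the written buffer."""
    def __init__(self, store, key, mode):
        super().__init__(store.get(key, b"") if "r" in mode else b"")
        self._store, self._key, self._mode = store, key, mode

    def close(self):
        if "w" in self._mode and not self.closed:
            # Flush the buffer to the backing store before discarding it.
            self._store[self._key] = self.getvalue()
        super().close()

store = {}
state = {"epoch": 3, "loss": 0.25}
# pickle.dump/load stand in for torch.save/torch.load; both only need a file object.
with MemoryBlobIO(store, "checkpoints/model.pt", "wb") as writer:
    pickle.dump(state, writer)
with MemoryBlobIO(store, "checkpoints/model.pt", "rb") as reader:
    loaded = pickle.load(reader)
```

In the real class the same `with` pattern streams checkpoint bytes directly to and from the blob endpoint instead of a local dict.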

dlio_benchmark/common/enumerations.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -27,6 +27,7 @@ class CheckpointMechanismType(Enum):
     TF_SAVE = 'tf_save'
     PT_SAVE = 'pt_save'
     PT_S3_SAVE = 'pt_s3_save'
+    PT_ADLS_SAVE = 'pt_adls_save'

     def __str__(self):
         return self.value
@@ -59,6 +60,7 @@ class StorageType(Enum):
     PARALLEL_FS = 'parallel_fs'
     S3 = 's3'
     AISTORE = 'aistore'
+    ADLS_GEN2 = 'adls_gen2'

     def __str__(self):
         return self.value
@@ -70,6 +72,7 @@ class MetadataType(Enum):
     FILE = 'file'
     DIRECTORY = 'directory'
     S3_OBJECT = 's3_object'
+    ADLS_OBJECT = 'adls_object'

     def __str__(self):
         return self.value
```
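The new members participate in the standard Enum value lookup, which is presumably how a config string like `adls_gen2` is mapped to a member. A minimal mirror of the extended `StorageType` (only this enum is reproduced, for brevity):

```python
from enum import Enum

class StorageType(Enum):
    PARALLEL_FS = 'parallel_fs'
    S3 = 's3'
    AISTORE = 'aistore'
    ADLS_GEN2 = 'adls_gen2'    # new member added by this commit

    def __str__(self):
        return self.value

# Calling the Enum with a value performs the reverse lookup.
member = StorageType('adls_gen2')
```

An unknown string such as `StorageType('adls')` would raise `ValueError`, so typos in a config surface immediately.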

0 commit comments
