Skip to content

Commit a57a352

Browse files
Denis Barakhtanov (0xE0F)
authored and committed
DAOS-16362 pydaos: ensure checkpoint path is created
Some uses of Checkpoint assume that the path to the checkpoint file will be created along with all missing parent directories. For instance, the DLIO benchmark writes checkpoints as `/prefix/global_epochX_stepY/layer-Z.pt`. This commit adds an `ensure_path` parameter that calls `mkdirall` before writing the checkpoint file. Features: pytorch Signed-off-by: Denis Barakhtanov <dbarahtanov@enakta.com>
1 parent f59cba3 commit a57a352

File tree

3 files changed

+92
-4
lines changed

3 files changed

+92
-4
lines changed

src/client/pydaos/torch/torch_api.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#
22
# (C) Copyright 2024-2025 Google LLC
3-
# (C) Copyright 2024-2025 Enakta Labs Ltd
3+
# (C) Copyright 2024-2026 Enakta Labs Ltd
44
#
55
# SPDX-License-Identifier: BSD-2-Clause-Patent
66
#
@@ -11,11 +11,13 @@
1111
In addition, it provides Checkpoint class to save and load PyTorch model checkpoints.
1212
"""
1313

14+
import errno
1415
import io
1516
import math
1617
import os
1718
import stat
1819
from multiprocessing import Process, Queue
20+
from pathlib import Path
1921

2022
from torch.utils.data import Dataset as TorchDataset
2123
from torch.utils.data import IterableDataset as TorchIterableDataset
@@ -619,13 +621,16 @@ def reader(self, file, stream=None):
619621
stream.seek(0)
620622
return stream
621623

622-
def writer(self, file):
624+
def writer(self, file, ensure_path=True):
    """Return a write buffer for saving a checkpoint file.

    Parameters
    ----------
    file : str
        Checkpoint file path, joined onto the checkpoint prefix.
    ensure_path : bool
        When True (the default), create any missing parent
        directories before opening the buffer.

    Raises
    ------
    ValueError
        If file is None.
    """
    if file is None:
        raise ValueError("file is required")

    full_path = os.path.join(self._prefix, file)
    if ensure_path:
        # Make sure every parent directory of the target file exists.
        self._dfs.mkdirall(os.path.dirname(full_path))

    return WriteBuffer(self._dfs, full_path, self._mode, self._oflags,
                       self._class_name, self._file_chunk_size,
                       self._transfer_chunk_size, self._chunks_limit,
                       self._workers)
@@ -810,3 +815,18 @@ def get_file_size(self, path):
810815
if ret != 0:
811816
raise OSError(ret, os.strerror(ret), path)
812817
return size
818+
819+
def mkdirall(self, path, mode=0o755):
    """Create a directory, making any missing parent directories.

    Behaves like os.makedirs(): every component of `path` is created
    in turn and components that already exist are silently skipped.

    Parameters
    ----------
    path : str
        Directory path inside the container, absolute or relative.
    mode : int
        POSIX permission bits applied to each created directory.

    Raises
    ------
    ValueError
        If the path has no components.
    OSError
        If a component cannot be created for any reason other than
        already existing.
    """
    parts = list(Path(os.path.normpath(path)).parts)
    if not parts:
        raise ValueError(f"invalid path: {path}")

    # Only an absolute path's root component ('/') pre-exists and must
    # be skipped.  The previous code unconditionally popped the first
    # component, so for a relative path like "a/b" the directory "a"
    # was never created.
    if parts[0] in (os.sep, os.curdir):
        parent = parts.pop(0)
    else:
        parent = ""

    for name in parts:
        parent = os.path.join(parent, name)
        ret = torch_shim.torch_mkdir(DAOS_MAGIC, self._dfs, parent, mode)
        # EEXIST is expected when part of the path is already present.
        if ret not in (0, errno.EEXIST):
            raise OSError(ret, os.strerror(ret), parent)

src/client/pydaos/torch/torch_shim.c

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
* (C) Copyright 2019-2024 Intel Corporation.
33
* (C) Copyright 2024-2025 Google LLC
4-
* (C) Copyright 2024-2025 Enakta Labs Ltd
4+
* (C) Copyright 2024-2026 Enakta Labs Ltd
55
*
66
* SPDX-License-Identifier: BSD-2-Clause-Patent
77
*/
@@ -1061,6 +1061,39 @@ __shim_handle__torch_get_fsize(PyObject *self, PyObject *args)
10611061
return Py_BuildValue("iK", rc, st.st_size);
10621062
}
10631063

1064+
/**
 * torch_mkdir(handle, path, mode) -> int rc
 *
 * Create a single directory at `path` inside the DFS container.
 * Returns 0 on success or a positive errno value on failure
 * (EEXIST when the directory is already present).
 */
static PyObject *
__shim_handle__torch_mkdir(PyObject *self, PyObject *args)
{
	struct dfs_handle *hdl    = NULL;
	char              *path   = NULL;
	char              *dir    = NULL;
	char              *name   = NULL;
	dfs_obj_t         *parent = NULL;
	mode_t             mode;
	int                rc;

	RETURN_NULL_IF_FAILED_TO_PARSE(args, "LsI", &hdl, &path, &mode);

	assert(hdl->dfs != NULL);

	/* Split the path into parent directory and final component. */
	rc = split_path(path, &dir, &name);
	if (rc != 0)
		return PyLong_FromLong(rc);

	/* Resolve (and cache) the parent directory object. */
	rc = lookup_or_insert_dir_obj(hdl, dir, &parent);
	if (rc != 0) {
		D_ERROR("Could not lookup '%s': %s (rc=%d)", dir, strerror(rc), rc);
		goto out;
	}

	rc = dfs_mkdir(hdl->dfs, parent, name, mode, 0);

out:
	D_FREE(dir);
	D_FREE(name);
	return PyLong_FromLong(rc);
}
1096+
10641097
/**
10651098
* Python shim module
10661099
*/
@@ -1080,6 +1113,7 @@ static PyMethodDef torchMethods[] = {
10801113
EXPORT_PYTHON_METHOD(torch_recommended_dir_split),
10811114
EXPORT_PYTHON_METHOD(torch_list_with_anchor),
10821115
EXPORT_PYTHON_METHOD(torch_get_fsize),
1116+
EXPORT_PYTHON_METHOD(torch_mkdir),
10831117

10841118
EXPORT_PYTHON_METHOD(module_init),
10851119
EXPORT_PYTHON_METHOD(module_fini),

src/tests/ftest/pytorch/checkpoint.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""
22
(C) Copyright 2025 Google LLC
3-
(C) Copyright 2025 Enakta Labs Ltd
3+
(C) Copyright 2025-2026 Enakta Labs Ltd
44
55
SPDX-License-Identifier: BSD-2-Clause-Patent
66
"""
7+
import errno
78
import os
89
import uuid
910

@@ -73,6 +74,39 @@ def test_checkpoint_chunking(self):
7374
chunk_size=chunk_size, chunks_limit=chunks_limit,
7475
workers=worker)
7576

77+
def test_checkpoint_nested_directories(self):
    """ Test Pytorch Checkpoint interface with nested directories
    Test Description: Ensure that parent directories are created for the checkpoint path

    :avocado: tags=all,full_regression
    :avocado: tags=vm
    :avocado: tags=pytorch
    :avocado: tags=PytorchCheckpointTest,test_checkpoint_nested_directories
    """

    pool = self.get_pool()
    container = self.get_container(pool)
    data = os.urandom(4096)

    files = ["/file.pt", "/one/file.pt", "/one/two/file.pt"]

    with Checkpoint(pool, container) as pt:
        # By default the writer should create missing parent directories.
        for name in files:
            with pt.writer(name) as w:
                w.write(data)

        # With ensure_path=False a missing parent must surface as ENOENT.
        # NOTE: the failure is signalled from the `else` clause, not from
        # inside the `try`; the previous version raised its "expected
        # OSError" RuntimeError inside the try-block, where it was caught
        # by the broad `except Exception` and re-wrapped as a misleading
        # "unexpected error".  Any genuinely unexpected exception now
        # propagates with its original traceback.
        fname = f"/{uuid.uuid4()}/file.pt"
        try:
            with pt.writer(fname, ensure_path=False) as w:
                w.write(data)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise RuntimeError(
                    f"expected errno.ENOENT, got {os.strerror(e.errno)}") from e
        else:
            raise RuntimeError("expected OSError with errno.ENOENT")
109+
76110
def _test_checkpoint(self, pool, cont, writes, chunk_size=0, chunks_limit=0, workers=0):
77111
"""Creates a checkpoint with the given parameters, writes the given data to it,
78112
then reads written data back from it and compares it with the expected writes.

0 commit comments

Comments
 (0)