diff --git a/daos_pytorch.md b/daos_pytorch.md new file mode 100644 index 00000000..2cee2033 --- /dev/null +++ b/daos_pytorch.md @@ -0,0 +1,48 @@ +# Running DLIO benchmark with DAOS PyTorch DataLoader + +## Prerequisites + + - Python 3.10 or higher + - openmpi and openmpi-devel packages + - DAOS client libraries built and installed from the [master branch](https://github.com/daos-stack/daos/tree/master/src/client/pydaos/torch) + - a configured and working DAOS agent on the compute nodes + + +## Getting started + + +Since the `DAOS PyTorch` client has not been released outside the `master` branch, you need to build and install the `pydaos3-package` on the compute nodes manually (the `torch` integration ships with the `pydaos` package): + + +```bash + +$: pip install ${DAOS_BUILD_OUTPUT}/install/lib/daos/python + +``` + + + +## Example of running the benchmark with the `DAOS` PyTorch client + + +```bash +# LD_LIBRARY_PATH is needed to load DAOS libraries from the build directory +export LD_LIBRARY_PATH=/lus/flare/projects/DAOS_Testing/daos/install/lib64/:$LD_LIBRARY_PATH + +mpiexec --np ${NTOTRANKS} -ppn ${NRANKS} --cpu-bind depth -d ${NDEPTH} --no-vni \ + dlio_benchmark workload=daos_pytorch \ + ++workload.workflow.generate_data=True \ + ++workload.dataset.daos_pool=DAOS_Testing \ + ++workload.dataset.daos_cont=defaults \ + ++workload.workflow.checkpoint=True \ + ++workload.checkpoint.checkpoint_daos_pool=DAOS_Testing \ + ++workload.checkpoint.checkpoint_daos_cont=defaults \ + ++workload.checkpoint.checkpoint_folder=/checkpoints \ + ++workload.dataset.data_folder=/datasets/small-08 \ + ++workload.dataset.num_files_train=80000 \ + ++workload.dataset.num_files_eval=10000 \ + ++workload.reader.batch_size=32 \ + ++workload.reader.read_threads=4 \ + ++workload.dataset.record_length_bytes=1048576 \ + ++workload.train.epochs=5 +``` diff --git a/dlio_benchmark/checkpointing/pytorch_daos_checkpointing.py b/dlio_benchmark/checkpointing/pytorch_daos_checkpointing.py new file mode 100644 index 00000000..65d1b3da 
class PyTorchDaosCheckpointing(PyTorchCheckpointing):
    """PyTorch checkpointing backend that stores checkpoints in a DAOS container.

    Checkpoint files are written and read through ``pydaos.torch.Checkpoint``
    (imported in this module as ``DaosCheckpoint``). The pool, container and
    path prefix are taken from the ``checkpoint_daos_pool``,
    ``checkpoint_daos_cont`` and ``checkpoint_folder`` configuration values;
    ``checkpoint_daos_chunk_size`` and ``checkpoint_daos_chunks_limit`` are
    forwarded to the DAOS client as ``transfer_chunk_size`` and
    ``chunks_limit``.
    """
    __instance = None

    @staticmethod
    def get_instance():
        """Return the shared instance, creating it on first access."""
        if PyTorchDaosCheckpointing.__instance is None:
            logging.basicConfig(level=logging.INFO)
            PyTorchDaosCheckpointing.__instance = PyTorchDaosCheckpointing()
        return PyTorchDaosCheckpointing.__instance

    @dft_ai.checkpoint.init
    def __init__(self):
        """Read the DAOS checkpoint settings and open the checkpoint handle.

        Raises:
            ValueError: if the checkpoint pool or container is not configured.
        """
        # BaseCheckpointing.__init__ is invoked directly, so
        # PyTorchCheckpointing.__init__ is not run for this class.
        # NOTE(review): self.args is read right after this call, so it is
        # presumably populated by BaseCheckpointing.__init__ - confirm.
        BaseCheckpointing.__init__(self, "pt")

        prefix = self.args.checkpoint_folder
        pool = self.args.checkpoint_daos_pool
        cont = self.args.checkpoint_daos_cont
        chunk_size = self.args.checkpoint_daos_chunk_size
        chunks_limit = self.args.checkpoint_daos_chunks_limit

        # Both settings default to None in the configuration; fail with an
        # actionable message instead of handing None to the DAOS client.
        if not pool or not cont:
            raise ValueError(
                "DAOS checkpointing requires checkpoint.checkpoint_daos_pool and "
                "checkpoint.checkpoint_daos_cont to be set "
                f"(got pool={pool!r}, container={cont!r})"
            )

        logging.info(f"Checkpointing is set to DAOS pool: {pool}, container: {cont}, prefix: {prefix}, chunk_size: {chunk_size} and chunks_limit: {chunks_limit}")
        self.ckpt = DaosCheckpoint(pool, cont, prefix, transfer_chunk_size=chunk_size, chunks_limit=chunks_limit)

    @dft_ai.checkpoint.capture
    def save_state(self, suffix, state, fsync=False):
        """Serialize ``state`` with ``torch.save`` into the DAOS checkpoint file for ``suffix``.

        Args:
            suffix: passed to ``self.get_name`` to build the checkpoint file name.
            state: object to serialize.
            fsync: accepted for signature compatibility; this implementation
                does not use it.
        """
        name = self.get_name(suffix)
        with self.ckpt.writer(name) as f:
            torch.save(state, f)

    @dft_ai.checkpoint.restart
    def load_state(self, suffix, state):
        """Load the checkpoint for ``suffix`` from DAOS and check it is non-empty.

        Args:
            suffix: passed to ``self.get_name`` to build the checkpoint file name.
            state: accepted for signature compatibility; the passed object is
                neither read nor updated, the loaded state replaces the local
                name only.

        Raises:
            AssertionError: if the loaded state has no keys.
        """
        name = self.get_name(suffix)
        with self.ckpt.reader(name) as f:
            state = torch.load(f)
        self.logger.debug(f"checkpoint state loaded: {state}")
        assert len(state.keys()) > 0, f"checkpoint {name} contains no state"

    @dft_ai.checkpoint.capture
    def save_checkpoint(self, epoch, step_number):
        """Delegate to the parent implementation under the checkpoint-capture tracer."""
        super().save_checkpoint(epoch, step_number)

    @dlp.log
    def load_checkpoint(self, epoch, step_number):
        """Delegate to the parent implementation with profiler logging."""
        super().load_checkpoint(epoch, step_number)

    @dlp.log
    def finalize(self):
        """Delegate to the parent implementation with profiler logging."""
        super().finalize()
SYNTHETIC='synthetic' + DAOS_PYTORCH="daos_pytorch" def __str__(self): return self.value diff --git a/dlio_benchmark/configs/workload/daos_pytorch.yaml b/dlio_benchmark/configs/workload/daos_pytorch.yaml new file mode 100644 index 00000000..e8620c1a --- /dev/null +++ b/dlio_benchmark/configs/workload/daos_pytorch.yaml @@ -0,0 +1,37 @@ +model: + name: default + +framework: pytorch +storage: + storage_type: "daos_pytorch" + storage_root: "/" + +workflow: + generate_data: False + train: True + evaluation: True + profiling: False + +dataset: + data_folder: /data/default + format: npz + num_files_train: 1024 + num_files_eval: 64 + num_samples_per_file: 1 + record_length: 4096 + num_subfolders_train: 2 + num_subfolders_eval: 2 + daos_pool: default-pool + daos_cont: samples + + +reader: + data_loader: daos_pytorch + batch_size: 32 + read_threads: 4 + +checkpoint: + checkpoint_folder: /checkpoints + checkpoint_daos_pool: default-pool + checkpoint_daos_cont: checkpoints + checkpoint_mechanism_classname: dlio_benchmark.checkpointing.pytorch_daos_checkpointing.PyTorchDaosCheckpointing diff --git a/dlio_benchmark/data_loader/base_data_loader.py b/dlio_benchmark/data_loader/base_data_loader.py index 97f15e6a..39c4e490 100644 --- a/dlio_benchmark/data_loader/base_data_loader.py +++ b/dlio_benchmark/data_loader/base_data_loader.py @@ -35,6 +35,7 @@ def __init__(self, format_type, dataset_type, epoch_number, data_loader_type): self.data_loader_type = data_loader_type self.num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + self.read_threads = self._args.read_threads self.logger = self._args.logger @abstractmethod diff --git a/dlio_benchmark/data_loader/daos_torch_data_loader.py b/dlio_benchmark/data_loader/daos_torch_data_loader.py new file mode 100644 index 00000000..f96dfdc7 --- /dev/null +++ 
def _read_npz(raw):
    """Decode the raw bytes of an ``.npz`` sample and return its ``"x"`` array."""
    # NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which
    # is only safe for trusted data - confirm the container holds
    # benchmark-generated samples only.
    # The archive is closed once the array has been read into memory.
    with np.load(io.BytesIO(raw), allow_pickle=True) as archive:
        return archive["x"]


def _read_npy(raw):
    """Decode the raw bytes of an ``.npy`` sample and return the array."""
    # NOTE(review): same allow_pickle caveat as for the npz reader.
    return np.load(io.BytesIO(raw), allow_pickle=True)


def get_format_reader(format):
    """Return a callable that turns the raw bytes of one sample into an array.

    Named module-level functions are returned rather than lambdas so that the
    callable can be pickled.
    NOTE(review): this matters when the dataset is sent to workers started
    with a non-fork multiprocessing context - confirm DaosDataset itself
    supports that.

    Args:
        format: ``FormatType.NPZ`` or ``FormatType.NPY``.

    Raises:
        ValueError: for any other format.
    """
    if format == FormatType.NPZ:
        return _read_npz
    if format == FormatType.NPY:
        return _read_npy
    raise ValueError(f"TorchDaosDataset does not support {format} format")


class TorchDaosDataset(Dataset):
    """
    Wrapper over DaosDataset to log calls for the profiler
    """
    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch, num_samples, num_workers, batch_size):
        """Open the DAOS dataset located at ``<data_folder>/<dataset_type>``.

        Args:
            format_type: sample format, used to pick the bytes-to-array reader.
            dataset_type: dataset split; its string form is the sub-folder name.
            epoch: stored as ``epoch_number``.
            num_samples: value reported by ``__len__``.
            num_workers: when 0, ``worker_init`` is run immediately in this process.
            batch_size: used to derive the step number reported to the profilers.
        """
        # Assigned first so that __del__ stays safe even if a later statement
        # in this constructor raises.
        self.dlp_logger = None
        self.format_type = format_type
        self.dataset_type = dataset_type
        self.epoch_number = epoch
        self.num_samples = num_samples
        self.num_images_read = 0
        self.batch_size = batch_size
        args = ConfigArguments.get_instance()
        self.serial_args = pickle.dumps(args)
        self.logger = args.logger

        # to avoid loading pydaos.torch at the top level if not needed or not installed
        from pydaos.torch import Dataset as DaosDataset

        prefix = os.path.join(args.data_folder, f"{self.dataset_type}")
        self.dataset = DaosDataset(pool=args.daos_pool,
                                   cont=args.daos_cont,
                                   path=prefix,
                                   transform_fn=get_format_reader(self.format_type))

        # NOTE: num_samples is the value passed in by the caller, not len(self.dataset).
        if num_workers == 0:
            # No worker processes were requested, so run the per-worker setup here.
            self.worker_init(-1)

    @dlp.log
    def worker_init(self, worker_id):
        """Configure logging and tracing for a worker, then initialize the DAOS dataset in it.

        Args:
            worker_id: forwarded to ``DaosDataset.worker_init``; ``-1`` is
                used for the in-process call made by ``__init__``.
        """
        # The unpickled object is discarded; the configuration is then fetched
        # through the singleton accessor.
        # NOTE(review): presumably unpickling restores the ConfigArguments
        # singleton inside a fresh worker process - confirm.
        pickle.loads(self.serial_args)
        _args = ConfigArguments.get_instance()
        _args.configure_dlio_logging(is_child=True)
        self.dlp_logger = _args.configure_dftracer(is_child=True, use_pid=True)
        self.logger.debug(f"{utcnow()} worker initialized {worker_id} with format {self.format_type}")
        self.dataset.worker_init(worker_id)

    def __del__(self):
        """Finalize the tracer created by ``worker_init``, if there is one."""
        # getattr guards against a partially constructed instance.
        tracer = getattr(self, "dlp_logger", None)
        if tracer:
            tracer.finalize()

    @dlp.log
    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    @dlp.log
    def __getitem__(self, image_idx):
        """Read one sample and update the profilers' step counters."""
        self.num_images_read += 1
        step = int(math.ceil(self.num_images_read / self.batch_size))
        self.logger.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {image_idx} sample")
        dlp.update(step=step)
        dft_ai.update(step=step)
        return self.dataset.__getitem__(image_idx)

    @dlp.log
    def __getitems__(self, indices):
        """Read a batch of samples in one call and update the profilers' step counters."""
        self.num_images_read += len(indices)
        step = int(math.ceil(self.num_images_read / self.batch_size))
        self.logger.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {len(indices)} samples")
        dlp.update(step=step)
        dft_ai.update(step=step)
        return self.dataset.__getitems__(indices)


class DaosTorchDataLoader(BaseTorchDataLoader):
    """Torch data loader whose dataset reads samples through the DAOS PyTorch client."""

    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch_number):
        super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.DAOS_PYTORCH)

    def get_dataset(self) -> Dataset:
        """Build the ``TorchDaosDataset`` for this loader's format, split and epoch."""
        return TorchDaosDataset(self.format_type, self.dataset_type, self.epoch_number, self.num_samples,
                                self.read_threads, self.batch_size)
""" +import os import math import pickle import torch +from abc import abstractmethod from torch.utils.data import Dataset, DataLoader from torch.utils.data.sampler import Sampler from dlio_benchmark.common.constants import MODULE_DATA_LOADER -from dlio_benchmark.common.enumerations import DatasetType, DataLoaderType +from dlio_benchmark.common.enumerations import DatasetType, DataLoaderType, FormatType from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader from dlio_benchmark.reader.reader_factory import ReaderFactory from dlio_benchmark.utils.utility import utcnow, DLIOMPI, Profile, dft_ai @@ -87,7 +90,7 @@ def __init__(self, rank, size, num_samples, epochs): self.rank = rank self.num_samples = num_samples self.epochs = epochs - samples_per_proc = int(math.ceil(num_samples/size)) + samples_per_proc = int(math.ceil(num_samples/size)) start_sample = self.rank * samples_per_proc end_sample = (self.rank + 1) * samples_per_proc - 1 if end_sample > num_samples - 1: @@ -103,15 +106,18 @@ def __iter__(self): yield sample -class TorchDataLoader(BaseDataLoader): +class BaseTorchDataLoader(BaseDataLoader): @dlp.log_init - def __init__(self, format_type, dataset_type, epoch_number): - super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.PYTORCH) + def __init__(self, format_type, dataset_type, epoch_number, data_loader_type): + super().__init__(format_type, dataset_type, epoch_number, data_loader_type) + + @abstractmethod + def get_dataset(self) -> Dataset: + return None @dlp.log def read(self): - dataset = TorchDataset(self.format_type, self.dataset_type, self.epoch_number, self.num_samples, - self._args.read_threads, self.batch_size) + dataset = self.get_dataset() sampler = dlio_sampler(self._args.my_rank, self._args.comm_size, self.num_samples, self._args.epochs) if self._args.read_threads >= 1: prefetch_factor = math.ceil(self._args.prefetch_size / self._args.read_threads) @@ -132,7 +138,7 @@ def read(self): else: 
class TorchDataLoader(BaseTorchDataLoader):
    """Data loader of type ``DataLoaderType.PYTORCH`` that supplies ``TorchDataset`` to the shared torch loader base."""

    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch_number):
        loader_kind = DataLoaderType.PYTORCH
        super().__init__(format_type, dataset_type, epoch_number, loader_kind)

    def get_dataset(self) -> Dataset:
        """Create the ``TorchDataset`` for this loader's format, split and epoch."""
        dataset_args = (
            self.format_type,
            self.dataset_type,
            self.epoch_number,
            self.num_samples,
            self.read_threads,
            self.batch_size,
        )
        return TorchDataset(*dataset_args)
class DaosTorchStorage(DataStorage):
    """
    Implementation of the DataStorage interface on top of the DAOS PyTorch client integration.
    There's no generic, POSIX-like Python interface yet, so this implementation relies only on what
    Dataset provides: a list of file names. That list is converted into a list of files and a set of
    directories, and get_node and walk_node operate on these two collections.
    The pydaos.torch.Checkpoint interface is used to implement put_data.
    """

    @dlp.log_init
    def __init__(self, namespace="/", framework=None):
        """Read the DAOS pool, container and data folder from the configuration.

        In data-generation mode a synchronous Checkpoint writer is created;
        otherwise the file listing is loaded immediately.
        """
        super().__init__(framework)
        self.namespace = Namespace(namespace, NamespaceType.HIERARCHICAL)

        args = ConfigArguments.get_instance()

        self.pool = args.daos_pool
        self.cont = args.daos_cont
        self.prefix = args.data_folder

        self._dirs = None
        self._files = None
        self._checkpoint = None

        # Should initialize DAOS early in the parent process for the sake of atfork call in the module init
        if args.generate_data or args.generate_only:
            self._dirs = []
            self._files = []
            self._checkpoint = self._new_checkpoint()
        else:
            self.ensure_cache()

    def _new_checkpoint(self):
        """Create the Checkpoint handle used by put_data to write sample files."""
        # setting transfer_chunk_size to zero forces Checkpoint interface to use sync write call
        # otherwise it would spawn a queue with several threads, which is overkill for writing sample files
        return Checkpoint(pool=self.pool, cont=self.cont, transfer_chunk_size=0)

    @dlp.log
    def get_uri(self, id):
        """Join the namespace name, the data folder and ``id`` into one path."""
        return os.path.join(self.namespace.name, self.prefix, id)

    @dlp.log
    def create_namespace(self, exist_ok=False):
        """No-op that always reports success."""
        return True

    @dlp.log
    def get_namespace(self):
        """Return the namespace name."""
        return self.namespace.name

    @dlp.log
    def create_node(self, id, exist_ok=False):
        """
        This will only work for checkpoints: DAOS Checkpoint interface ensures that path exists
        before writing the checkpoint file.
        """
        return True

    @dlp.log
    def get_node(self, id=""):
        """Classify ``id`` as a directory or a file using the cached listing.

        Returns:
            MetadataType.DIRECTORY when the normalized path equals a cached
            directory or is an ancestor of one; MetadataType.FILE otherwise
            (including for paths that are not in the listing at all).
        """
        self.ensure_cache()

        path = os.path.normpath(self.get_uri(id))
        # Compare on a path-component boundary so that "/a/train" does not
        # match the unrelated directory "/a/training".
        subtree = path if path.endswith(os.sep) else path + os.sep

        if any(d == path or d.startswith(subtree) for d in self._dirs):
            return MetadataType.DIRECTORY
        return MetadataType.FILE

    @dlp.log
    def walk_node(self, id, use_pattern=False):
        """List cached entries under ``id``.

        Args:
            id: path (or pattern, when ``use_pattern`` is set) relative to the data folder.
            use_pattern: when True, everything from the first ``*`` onwards is
                dropped and every cached file starting with the remaining
                prefix is returned; the rest of the pattern is not evaluated.

        Returns:
            With ``use_pattern``: the matching file names. Otherwise: the
            cached files and directories that are direct children of the path.
        """
        self.ensure_cache()

        path = self.get_uri(id)

        if use_pattern:
            star = path.find("*")
            # find() returns -1 when there is no wildcard; slicing with it
            # would silently drop the last character of the path.
            if star >= 0:
                path = path[:star]
            return [f for f in self._files if f.startswith(path)]

        if not path.endswith(os.sep):
            path += os.sep

        pref_len = len(path)
        # A direct child has no further separator after the prefix.
        files = [f for f in self._files if f.startswith(path) and f.find(os.sep, pref_len) < 0]
        dirs = [d for d in self._dirs if d.startswith(path) and d.find(os.sep, pref_len) < 0 and len(d) > pref_len]

        return files + dirs

    @dlp.log
    def delete_node(self, id):
        """Not supported by this storage."""
        raise NotImplementedError

    @dlp.log
    def put_data(self, id, data, offset=None, length=None):
        """Write ``data`` to ``id`` through the Checkpoint writer.

        ``offset`` and ``length`` are accepted for interface compatibility and
        are not used.
        """
        # in case when a caller wants to list files after writing new ones
        # cache needs to be invalidated to force it to re-read directories later
        self.invalidate_cache()

        # The writer is only created up front in data-generation mode; create
        # it on demand otherwise instead of failing on a None handle.
        if self._checkpoint is None:
            self._checkpoint = self._new_checkpoint()

        with self._checkpoint.writer(id) as w:
            w.write(data)

    @dlp.log
    def get_data(self, id, data, offset=None, length=None):
        """Not supported by this storage."""
        raise NotImplementedError

    def get_basename(self, id):
        """Return the final path component of ``id``."""
        return os.path.basename(id)

    def invalidate_cache(self):
        """Drop the cached listing so the next lookup reloads it."""
        self._dirs = None
        self._files = None

    def ensure_cache(self):
        """Load the file listing from DAOS unless a non-empty one is already cached."""
        # An empty or missing directory collection counts as "not cached",
        # so an empty listing is re-read on the next call.
        if self._dirs:
            return

        with Dataset(self.pool, self.cont, self.prefix) as dataset:
            files = [name for (name, _size) in dataset.objects]

        # Extracting parent directories is cheap string work, so a single pass
        # is enough; it also copes with an empty listing.
        self._dirs = {os.path.normpath(os.path.dirname(fname)) for fname in files}
        self._files = files
def derive_configurations(self, file_list_train=None, file_list_eval=None): if self.data_loader_sampler is None and self.data_loader_classname is None: if self.data_loader == DataLoaderType.TENSORFLOW: self.data_loader_sampler = DataLoaderSampler.ITERATIVE - elif self.data_loader in [DataLoaderType.PYTORCH, DataLoaderType.DALI]: + elif self.data_loader in [DataLoaderType.PYTORCH, DataLoaderType.DAOS_PYTORCH, DataLoaderType.DALI]: self.data_loader_sampler = DataLoaderSampler.INDEX if self.data_loader_classname is not None: from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader @@ -711,6 +717,10 @@ def GetConfig(args, key): value = args.format elif keys[1] == "keep_files": value = args.keep_files + elif keys[1] == "daos_pool": + value = args.daos_pool + elif keys[1] == "daos_cont": + value = args.daos_cont # data reader reader = None @@ -806,8 +816,16 @@ def GetConfig(args, key): value = args.num_checkpoints_read elif keys[1] == "checkpoint_rank_sync": value = args.checkpoint_rank_sync - elif keys[1] == "recovery_rank_shift": + elif keys[1] == "recovery_rank_shift": value = args.checkpoint_recovery_rank_shift + elif keys[1] == "checkpoint_daos_pool": + value = args.checkpoint_daos_pool + elif keys[1] == "checkpoint_daos_cont": + value = args.checkpoint_daos_cont + elif keys[1] == "checkpoint_daos_chunk_size": + value = args.checkpoint_daos_chunk_size + elif keys[1] == "checkpoint_daos_chunks_limit": + value = args.checkpoint_daos_chunks_limit if len(keys) > 1 and keys[0] == "model": if keys[1] == "name": @@ -941,6 +959,10 @@ def LoadConfig(args, config): args.record_element_type = config['dataset']['record_element_type'] if 'record_dims' in config['dataset']: args.record_dims = list(config['dataset']['record_dims']) + if 'daos_pool' in config['dataset']: + args.daos_pool = config['dataset']['daos_pool'] + if 'daos_cont' in config['dataset']: + args.daos_cont = config['dataset']['daos_cont'] # hdf5 only config if 'hdf5' in config['dataset']: @@ -1102,6 
+1124,14 @@ def LoadConfig(args, config): args.ksm_low_ram_exit = config['checkpoint']['ksm']['low_ram_exit'] if 'await_time' in config['checkpoint']['ksm']: args.ksm_await_time = config['checkpoint']['ksm']['await_time'] + if 'checkpoint_daos_pool' in config['checkpoint']: + args.checkpoint_daos_pool = config['checkpoint']['checkpoint_daos_pool'] + if 'checkpoint_daos_cont' in config['checkpoint']: + args.checkpoint_daos_cont = config['checkpoint']['checkpoint_daos_cont'] + if 'checkpoint_daos_chunk_size' in config['checkpoint']: + args.checkpoint_daos_chunk_size = config['checkpoint']['checkpoint_daos_chunk_size'] + if 'checkpoint_daos_chunks_limit' in config['checkpoint']: + args.checkpoint_daos_chunks_limit = config['checkpoint']['checkpoint_daos_chunks_limit'] if 'model' in config: if 'name' in config['model']: