
Commit b535b99

nvme support

1 parent 9e94656 commit b535b99

File tree

4 files changed: +104 -9 lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 27 additions & 8 deletions
@@ -20,13 +20,16 @@
 import torch
 
 from ..utils import get_logger, is_accelerate_available
+from ..utils.import_utils import is_deepspeed_available, is_deepspeed_version
 from .hooks import HookRegistry, ModelHook
 
 
 if is_accelerate_available():
     from accelerate.hooks import AlignDevicesHook, CpuOffload
     from accelerate.utils import send_to_device
 
+if is_deepspeed_available() and is_deepspeed_version(">", "0.16"):
+    from ..utils.state_dict_utils import _fast_aio_save
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 

@@ -62,6 +65,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
+        _enable_deepnvme_disk_offloading: Optional[bool] = False,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -80,7 +84,9 @@
         self._is_offloaded_to_disk = False
 
         if self.offload_to_disk_path:
-            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+            self._enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading
+            ext = ".pt" if _enable_deepnvme_disk_offloading else ".safetensors"
+            self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}{ext}")
 
         all_tensors = []
         for module in self.modules:
@@ -153,8 +159,8 @@ def onload_(self):
 
         with context:
             if self.stream is not None:
-                # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+                # Load to CPU from disk, pin, and async copy to device for overlapping transfer and compute
+                loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
                 for key, tensor_obj in self.key_to_tensor.items():
                     pinned_tensor = loaded_cpu_tensors[key].pin_memory()
                     tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
@@ -165,7 +171,7 @@ def onload_(self):
                 onload_device = (
                     self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                 )
-                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
                 for key, tensor_obj in self.key_to_tensor.items():
                     tensor_obj.data = loaded_tensors[key]
                 return
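
The pin-then-copy pattern in this hunk generalizes beyond the hook. A minimal sketch, not from the commit (path and device are illustrative): page-locked host memory lets the non-blocking `.to()` overlap the host-to-device copy with compute on the transfer stream.

import safetensors.torch
import torch

loaded = safetensors.torch.load_file("/mnt/nvme/offload/group_demo.safetensors", device="cpu")  # illustrative path
for key, tensor in loaded.items():
    pinned = tensor.pin_memory()  # page-locked so the copy below can be truly asynchronous
    on_device = pinned.to(torch.device("cuda"), non_blocking=True)  # overlaps with compute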
@@ -218,15 +224,18 @@ def offload_(self):
         if self.offload_to_disk_path:
             # TODO: we can potentially optimize this code path by checking if _all_ the desired
             # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+            # overhead. Currently, we just check if the given `param_file_path` exists and if not
             # we perform a write.
             # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            if not self._is_offloaded_to_disk and not os.path.exists(self.param_file_path):
+                os.makedirs(os.path.dirname(self.param_file_path), exist_ok=True)
                 tensors_to_save = {
                     key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                 }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+                if not self._enable_deepnvme_disk_offloading:
+                    safetensors.torch.save_file(tensors_to_save, self.param_file_path)
+                else:
+                    _fast_aio_save(tensors_to_save, self.param_file_path)
 
             # The group is now considered offloaded to disk for the rest of the session.
             self._is_offloaded_to_disk = True
@@ -426,6 +435,7 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -531,6 +541,7 @@ def apply_group_offloading(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
@@ -542,6 +553,7 @@ def apply_group_offloading(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -557,6 +569,7 @@ def _apply_group_offloading_block_level(
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -617,6 +630,7 @@ def _apply_group_offloading_block_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -651,6 +665,7 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
+        _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
     )
     if stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None)
@@ -667,6 +682,7 @@ def _apply_group_offloading_leaf_level(
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -717,6 +733,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
@@ -764,6 +781,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         _apply_group_offloading_hook(parent_module, group, None)
 
@@ -785,6 +803,7 @@ def _apply_group_offloading_leaf_level(
         record_stream=False,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
+        _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
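
Taken together, the hunks above reduce to the following standalone sketch of the new write path (not part of the commit; the directory, group name, and payload are illustrative): with the flag unset the group is serialized with safetensors, and with it set the write goes through DeepNVMe's async I/O writer, using a `.pt` extension because the payload is torch.save-serialized.

import os

import safetensors.torch
import torch

enable_deepnvme = True  # the new opt-in flag
offload_dir = "/mnt/nvme/offload"  # assumption: an NVMe-backed directory
ext = ".pt" if enable_deepnvme else ".safetensors"
param_file_path = os.path.join(offload_dir, f"group_demo{ext}")

tensors_to_save = {"weight": torch.randn(4, 4)}  # stand-in for a group's CPU-resident parameters
os.makedirs(os.path.dirname(param_file_path), exist_ok=True)
if not enable_deepnvme:
    safetensors.torch.save_file(tensors_to_save, param_file_path)
else:
    from diffusers.utils.state_dict_utils import _fast_aio_save  # requires DeepSpeed > 0.16
    _fast_aio_save(tensors_to_save, param_file_path)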
src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 0 deletions
@@ -549,6 +549,7 @@ def enable_group_offload(
         record_stream: bool = False,
         low_cpu_mem_usage=False,
         offload_to_disk_path: Optional[str] = None,
+        _enable_deepnvme_disk_offloading: Optional[bool] = False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -599,6 +600,7 @@ def enable_group_offload(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             offload_to_disk_path=offload_to_disk_path,
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
 
     def save_pretrained(
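
A hedged usage sketch for the public entry point (the model id and offload path are illustrative; the remaining kwargs follow the existing `enable_group_offload` signature):

import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="/mnt/nvme/offload",  # assumption: NVMe-backed path
    _enable_deepnvme_disk_offloading=True,  # private flag added in this commit
)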

src/diffusers/utils/import_utils.py

Lines changed: 18 additions & 0 deletions
@@ -220,6 +220,11 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
 _nltk_available, _nltk_version = _is_package_available("nltk")
 _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_deepspeed_available, _deepspeed_version = _is_package_available("deepspeed")
+
+
+def is_deepspeed_available():
+    return _deepspeed_available
 
 
 def is_torch_available():
@@ -655,6 +660,19 @@ def is_torch_version(operation: str, version: str):
     return compare_versions(parse(_torch_version), operation, version)
 
 
+def is_deepspeed_version(operation: str, version: str):
+    """
+    Compares the current DeepSpeed version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A string version of DeepSpeed
+    """
+    return compare_versions(parse(_deepspeed_version), operation, version)
+
+
 def is_torch_xla_version(operation: str, version: str):
     """
     Compares the current torch_xla version to a given reference with an operation.
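
These two helpers support the guarded-import pattern used throughout this commit. A minimal sketch of the same pattern at a call site (the `save_fn` alias is hypothetical, not part of the commit):

import safetensors.torch
from diffusers.utils.import_utils import is_deepspeed_available, is_deepspeed_version

if is_deepspeed_available() and is_deepspeed_version(">", "0.16"):
    # DeepNVMe fast path, only importable with a new-enough DeepSpeed
    from diffusers.utils.state_dict_utils import _fast_aio_save as save_fn
else:
    # Default path: plain safetensors serialization
    save_fn = safetensors.torch.save_file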

src/diffusers/utils/state_dict_utils.py

Lines changed: 57 additions & 1 deletion
@@ -18,13 +18,19 @@
 import enum
 import json
 
-from .import_utils import is_torch_available
+from .import_utils import is_deepspeed_available, is_deepspeed_version, is_torch_available
 from .logging import get_logger
 
 
 if is_torch_available():
     import torch
 
+if is_deepspeed_available() and is_deepspeed_version(">", "0.16"):
+    from deepspeed.io import FastFileWriter, FastFileWriterConfig
+    from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder
+
+    from .deep_nvme_utils import save as _nvme_save
+
 
 logger = get_logger(__name__)
 

@@ -364,3 +370,53 @@ def _load_sft_state_dict_metadata(model_file: str):
         return json.loads(raw) if raw else None
     else:
         return None
+
+
+# Utilities below are taken from
+# https://github.com/deepspeedai/DeepSpeedExamples/blob/28a984e77b8d096dadc6389b6d1440b823587e28/deepnvme/model_checkpoint/torch_save_utils.py#L16
+def _load_io_ops(args):
+    if AsyncIOBuilder().is_compatible():
+        AsyncIOBuilder().load(verbose=False)
+    if args.gpu and GDSBuilder().is_compatible():
+        GDSBuilder().load(verbose=False)
+
+
+def _get_aio_handle():
+    AIO_QUEUE_DEPTH = 8
+    AIO_BLOCK_SIZE = 8 * (1024**2)
+    AIO_INTRA_OP_PARALLEL = 1
+    AIO_SINGLE_SUBMIT = False
+
+    h = (
+        AsyncIOBuilder()
+        .load(verbose=False)
+        .aio_handle(
+            block_size=AIO_BLOCK_SIZE,
+            queue_depth=AIO_QUEUE_DEPTH,
+            single_submit=AIO_SINGLE_SUBMIT,
+            overlap_events=AIO_SINGLE_SUBMIT,
+            intra_op_parallelism=AIO_INTRA_OP_PARALLEL,
+        )
+    )
+    return h
+
+
+def _get_aio_components():
+    PINNED_BUFFER_MB = 64
+    h = _get_aio_handle()
+    pinned_memory = torch.zeros(PINNED_BUFFER_MB * (1024**2), dtype=torch.uint8, device="cpu").pin_memory()
+    return h, pinned_memory
+
+
+def _fast_aio_save(buffer, file, single_io_buffer=False):
+    h, pinned_memory = _get_aio_components()
+    fast_writer_config = FastFileWriterConfig(
+        dnvme_handle=h,
+        pinned_tensor=pinned_memory,
+        double_buffer=not single_io_buffer,
+        num_parallel_writers=1,
+        writer_rank=0,
+    )
+
+    ds_fast_writer = FastFileWriter(file_path=file, config=fast_writer_config)
+    _nvme_save(f=ds_fast_writer, obj=buffer, _use_new_zipfile_serialization=False)
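
Because `_nvme_save` is a torch.save-style call with `_use_new_zipfile_serialization=False`, files written by `_fast_aio_save` are legacy torch checkpoints, not safetensors files. A round-trip sketch under stated assumptions (DeepSpeed > 0.16 with a buildable async_io op; the path is illustrative):

import torch
from diffusers.utils.state_dict_utils import _fast_aio_save

state = {"weight": torch.randn(8, 8)}
_fast_aio_save(state, "/mnt/nvme/offload/group_demo.pt")
restored = torch.load("/mnt/nvme/offload/group_demo.pt", map_location="cpu")  # torch.load, not safetensors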
