Skip to content

Commit af9b398

Browse files
pin memory bug fix (#141)
1 parent aeeeb5b commit af9b398

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

diffsynth_engine/utils/offload.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn as nn
33
from typing import Dict
4-
4+
import platform
55

66
def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda"):
77
module = module.to("cpu")
@@ -26,13 +26,14 @@ def _forward_pre_hook(module: nn.Module, input_):
2626
for name, buffer in module.named_buffers(recurse=recurse):
2727
buffer.data = buffer.data.to(device=device)
2828
return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input_)
29-
30-
for name, param in module.named_parameters(recurse=recurse):
31-
param.data = param.data.pin_memory()
29+
for name, param in module.named_parameters(recurse=recurse):
30+
if platform.system() == 'Linux':
31+
param.data = param.data.pin_memory()
3232
offload_param_dict[name] = param.data
3333
param.data = param.data.to(device=device)
3434
for name, buffer in module.named_buffers(recurse=recurse):
35-
buffer.data = buffer.data.pin_memory()
35+
if platform.system() == 'Linux':
36+
buffer.data = buffer.data.pin_memory()
3637
offload_param_dict[name] = buffer.data
3738
buffer.data = buffer.data.to(device=device)
3839
setattr(module, "_offload_param_dict", offload_param_dict)
@@ -58,10 +59,12 @@ def offload_model_to_dict(module: nn.Module) -> Dict[str, torch.Tensor]:
5859
module = module.to("cpu")
5960
offload_param_dict = {}
6061
for name, param in module.named_parameters(recurse=True):
61-
param.data = param.data.pin_memory()
62+
if platform.system() == 'Linux':
63+
param.data = param.data.pin_memory()
6264
offload_param_dict[name] = param.data
6365
for name, buffer in module.named_buffers(recurse=True):
64-
buffer.data = buffer.data.pin_memory()
66+
if platform.system() == 'Linux':
67+
buffer.data = buffer.data.pin_memory()
6568
offload_param_dict[name] = buffer.data
6669
return offload_param_dict
6770

0 commit comments

Comments
 (0)