Skip to content

Commit 1359348

Browse files
committed
update
1 parent ace698a commit 1359348

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,7 @@ def _pinned_memory_tensors(self):
135135
finally:
136136
pinned_dict = None
137137

138-
def _transfer_tensor_to_device(self, tensor, source_tensor=None, current_stream=None):
139-
if source_tensor is None:
140-
source_tensor = tensor
138+
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
141139
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
142140
if self.record_stream and current_stream is not None:
143141
tensor.data.record_stream(current_stream)
@@ -159,26 +157,6 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
159157
source = pinned_memory[buffer] if pinned_memory else buffer.data
160158
self._transfer_tensor_to_device(buffer, source, current_stream)
161159

162-
@torch.compiler.disable()
163-
def onload_(self):
164-
torch_accelerator_module = (
165-
getattr(torch, torch.accelerator.current_accelerator().type)
166-
if hasattr(torch, "accelerator")
167-
else torch.cuda
168-
)
169-
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
170-
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
171-
172-
if self.stream is not None:
173-
# Wait for previous Host->Device transfer to complete
174-
self.stream.synchronize()
175-
176-
with context:
177-
if self.offload_to_disk_path:
178-
self._onload_from_disk(current_stream)
179-
else:
180-
self._onload_from_memory(current_stream)
181-
182160
def _onload_from_disk(self, current_stream):
183161
if self.stream is not None:
184162
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
@@ -207,6 +185,26 @@ def _onload_from_memory(self, current_stream):
207185
else:
208186
self._process_tensors_from_modules(None, current_stream)
209187

188+
@torch.compiler.disable()
189+
def onload_(self):
190+
torch_accelerator_module = (
191+
getattr(torch, torch.accelerator.current_accelerator().type)
192+
if hasattr(torch, "accelerator")
193+
else torch.cuda
194+
)
195+
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
196+
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
197+
198+
if self.stream is not None:
199+
# Wait for previous Host->Device transfer to complete
200+
self.stream.synchronize()
201+
202+
with context:
203+
if self.offload_to_disk_path:
204+
self._onload_from_disk(current_stream)
205+
else:
206+
self._onload_from_memory(current_stream)
207+
210208
@torch.compiler.disable()
211209
def _offload_to_disk(self):
212210
# TODO: we can potentially optimize this code path by checking if the _all_ the desired

0 commit comments

Comments (0)