
Commit c6d61fa

update
1 parent 15f98db commit c6d61fa

1 file changed: 10 additions, 12 deletions

src/diffusers/hooks/group_offloading.py

Lines changed: 10 additions & 12 deletions
@@ -159,27 +159,27 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
 
-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)
 
         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)
 
         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)
 
     def _onload_from_disk(self):
         if self.stream is not None:
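
The pattern this hunk settles on, an asynchronous copy followed by record_stream() against the stream queried at call time, can be sketched outside the diffusers class roughly as follows (plain torch.cuda calls and illustrative names such as transfer_tensor_to_device; this is not the group_offloading implementation itself):

import torch

def transfer_tensor_to_device(tensor, source, non_blocking=True, record_stream=True):
    # Asynchronous host-to-device copy; `source` should sit in pinned memory
    # for non_blocking=True to actually overlap with other work.
    tensor.data = source.to("cuda", non_blocking=non_blocking)
    if record_stream:
        # Query the current stream at call time and tell the caching allocator
        # that this stream uses the freshly copied tensor.
        tensor.data.record_stream(torch.cuda.current_stream())

if torch.cuda.is_available():
    param = torch.nn.Parameter(torch.empty(1024))
    pinned = torch.randn(1024).pin_memory()
    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):  # issue the copy on a side stream
        transfer_tensor_to_device(param, pinned)
    torch.cuda.current_stream().wait_stream(copy_stream)  # compute waits for the copy

Dropping the current_stream parameter keeps the stream lookup next to the record_stream() call, so _process_tensors_from_modules and its callers no longer have to capture and forward a stream handle themselves.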
@@ -214,14 +214,12 @@ def _onload_from_memory(self):
             self.stream.synchronize()
 
         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
-        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
-
         with context:
             if self.stream is not None:
                 with self._pinned_memory_tensors() as pinned_memory:
-                    self._process_tensors_from_modules(pinned_memory, current_stream)
+                    self._process_tensors_from_modules(pinned_memory)
             else:
-                self._process_tensors_from_modules(None, current_stream)
+                self._process_tensors_from_modules(None)
 
     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
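
At a usage level, this record_stream branch is only exercised when group offloading runs with streams enabled. A hedged sketch, assuming the apply_group_offloading entry point with use_stream / record_stream flags as in recent diffusers releases (the exact signature may differ between versions), with a toy Sequential standing in for a real diffusers model:

import torch
from diffusers.hooks import apply_group_offloading

# A toy module stands in for a real diffusers transformer/UNet.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.SiLU(),
    torch.nn.Linear(64, 64),
)

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",  # per-leaf parameter groups
    use_stream=True,            # transfers run on a dedicated stream
    record_stream=True,         # enables the record_stream branch touched above
)

with torch.no_grad():
    out = model(torch.randn(1, 64, device="cuda"))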
