Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
) -> None:
if stream is None and record_stream:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this completely, no? It's not supposed to be user facing and now apply_group_offloading handles the error case too

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can but since ModuleGroup is a public class I thought of keeping it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, it's not a public class - no user is expected to import it or that we document it that way. We don't necessarily need to mark every internal data structure private to imply it shouldn't be used by users

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

raise ValueError("`record_stream` cannot be True when `stream` is None.")

self.modules = modules
self.offload_device = offload_device
self.onload_device = onload_device
Expand Down Expand Up @@ -96,9 +99,6 @@ def __init__(
else:
self.cpu_param_dict = self._init_cpu_param_dict()

if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")

def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
Expand Down Expand Up @@ -513,6 +513,9 @@ def apply_group_offloading(
else:
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

if not use_stream and record_stream:
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")

_raise_error_if_accelerate_model_or_sequential_hook_present(module)

if offload_type == "block_level":
Expand Down