Skip to content

Commit 39005cf

Browse files
authored
Merge pull request #62 from jkoelker/jk/refactor
feat(device): centralize device management with DevicePolicy class
2 parents 950d854 + faaaf2c commit 39005cf

File tree

13 files changed

+584
-262
lines changed

13 files changed

+584
-262
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,6 @@ Thumbs.db
9292

9393
# uv
9494
uv.lock
95+
96+
.beads
97+
.gitattributes

AGENTS.md

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ Discord bot for image generation with Huggingface Diffusers.
66

77
```bash
88
uv pip install -e ".[dev]" # Install dev dependencies
9-
uv run --extra dev pytest -v # Run all tests
10-
uv run --extra dev pytest tests/test_config.py -v # Single file
9+
uv run --extra dev pytest # Run all tests
10+
uv run --extra dev pytest tests/test_config.py # Single file
1111
ruff check src/ --fix # Lint + auto-fix
1212
ruff format src/ # Format
1313
```
@@ -137,3 +137,20 @@ select = ["E", "W", "F", "I", "B", "C4", "UP"]
137137
[tool.ruff.lint.isort]
138138
known-first-party = ["oneiro"]
139139
```
140+
141+
## Git Workflow
142+
143+
- **Commits allowed** on atomic work units (single logical change)
144+
- **Never push** - leave pushing to the user
145+
- **Never `git add .`** - only stage specific files needed for the commit
146+
147+
```bash
148+
# CORRECT: Stage specific files
149+
git add src/oneiro/config.py tests/test_config.py
150+
git commit -m "Add config hot reload support"
151+
152+
# WRONG: Never do this
153+
git add .
154+
git add -A
155+
git push
156+
```

src/oneiro/device.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Device management for pipeline placement."""
2+
3+
from dataclasses import dataclass
4+
from enum import Enum
5+
6+
import torch
7+
8+
9+
class OffloadMode(str, Enum):
    """CPU offload behavior for CUDA pipelines."""

    AUTO = "auto"  # Offload whenever CUDA is available (default)
    ALWAYS = "always"  # Offload is mandatory (error on non-CUDA devices)
    NEVER = "never"  # Skip offload entirely; place via .to(device)


@dataclass(frozen=True)
class DevicePolicy:
    """Immutable device configuration for pipeline placement.

    Attributes:
        device: Target device ("cuda", "mps", "cpu")
        dtype: Torch dtype for model weights
        offload: CPU offload behavior for large models
    """

    device: str
    dtype: torch.dtype
    offload: OffloadMode = OffloadMode.AUTO

    @classmethod
    def auto_detect(cls, cpu_offload: bool = True) -> "DevicePolicy":
        """Create policy with auto-detected device and dtype.

        Args:
            cpu_offload: Enable CPU offloading when available (default: True)

        Returns:
            DevicePolicy configured for the best available device
        """
        if torch.cuda.is_available():
            # Prefer bfloat16 where the GPU supports it, float16 otherwise.
            picked = "cuda"
            chosen = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # Apple silicon: MPS works best with float32.
            picked, chosen = "mps", torch.float32
        else:
            picked, chosen = "cpu", torch.float32

        mode = OffloadMode.AUTO if cpu_offload else OffloadMode.NEVER
        return cls(device=picked, dtype=chosen, offload=mode)

    def apply_to_pipeline(self, pipe) -> None:
        """Apply device policy to a diffusers pipeline.

        Args:
            pipe: A diffusers pipeline instance

        Raises:
            ValueError: If offload=ALWAYS but device is not CUDA
        """
        # Offload when required, or when AUTO resolves on a CUDA device.
        # (== rather than `is`: the str-Enum also matches raw string values.)
        wants_offload = self.offload == OffloadMode.ALWAYS or (
            self.offload == OffloadMode.AUTO and self.device == "cuda"
        )
        if not wants_offload:
            # Explicit placement; a CPU pipeline stays where it is.
            if self.device != "cpu":
                pipe.to(self.device)
            return

        if self.device != "cuda":
            raise ValueError(
                f"CPU offload requires CUDA device, got '{self.device}'. "
                f"Set cpu_offload=false in config or use a CUDA-enabled system."
            )
        pipe.enable_model_cpu_offload()

    @staticmethod
    def clear_cache() -> None:
        """Clear device memory cache if applicable."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            return
        # Only touch torch.mps when the MPS backend is actually available.
        on_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if on_mps and hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()

src/oneiro/pipelines/base.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import torch
1212
from PIL import Image
1313

14+
from oneiro.device import DevicePolicy
15+
1416

1517
@dataclass
1618
class GenerationResult:
@@ -29,10 +31,9 @@ class GenerationResult:
2931
class BasePipeline(ABC):
3032
"""Base class for all pipeline types."""
3133

32-
def __init__(self):
34+
def __init__(self) -> None:
3335
self.pipe: Any = None
34-
self._device = "cuda" if torch.cuda.is_available() else "cpu"
35-
self._dtype = torch.bfloat16 if self._device == "cuda" else torch.float32
36+
self.policy: DevicePolicy = DevicePolicy.auto_detect()
3637

3738
@abstractmethod
3839
def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None = None) -> None:
@@ -68,11 +69,8 @@ def unload(self) -> None:
6869
del self.pipe
6970
self.pipe = None
7071

71-
# Aggressive cleanup
7272
gc.collect()
73-
if torch.cuda.is_available():
74-
torch.cuda.empty_cache()
75-
torch.cuda.synchronize()
73+
DevicePolicy.clear_cache()
7674

7775
def _prepare_seed(self, seed: int) -> tuple[int, torch.Generator]:
7876
"""Prepare seed and generator for generation."""

src/oneiro/pipelines/civitai_checkpoint.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from pathlib import Path
1111
from typing import TYPE_CHECKING, Any
1212

13-
import torch
14-
13+
from oneiro.device import DevicePolicy, OffloadMode
1514
from oneiro.pipelines.base import BasePipeline, GenerationResult
1615
from oneiro.pipelines.embedding import EmbeddingLoaderMixin, parse_embeddings_from_config
1716
from oneiro.pipelines.long_prompt import (
@@ -619,25 +618,26 @@ def _load_from_path(self, checkpoint_path: Path, model_config: dict[str, Any]) -
619618
print(f" Base model: {base_model or 'unknown'}")
620619
print(f" Pipeline: {self._pipeline_config.pipeline_class}")
621620

621+
cpu_offload = model_config.get("cpu_offload", True)
622+
self.policy = DevicePolicy.auto_detect(cpu_offload=cpu_offload)
623+
622624
# Get the pipeline class
623625
pipeline_class = get_diffusers_pipeline_class(self._pipeline_config.pipeline_class)
624626

625627
# Load from single file
626628
self.pipe = pipeline_class.from_single_file(
627629
str(checkpoint_path),
628-
torch_dtype=self._dtype,
630+
torch_dtype=self.policy.dtype,
629631
)
630632

631633
scheduler_override = model_config.get("scheduler")
632634
self.configure_scheduler(scheduler_override)
633635

634-
# Apply optimizations
635-
cpu_offload = model_config.get("cpu_offload", True)
636-
self._cpu_offload = cpu_offload and self._device == "cuda"
637-
if self._cpu_offload:
638-
self.pipe.enable_model_cpu_offload()
639-
elif self._device == "cuda":
640-
self.pipe.to("cuda")
636+
self.policy.apply_to_pipeline(self.pipe)
637+
# Track whether offload was applied (for dynamic LoRA handling)
638+
self._cpu_offload = (
639+
self.policy.offload != OffloadMode.NEVER and self.policy.device == "cuda"
640+
)
641641

642642
# Enable memory optimizations for VAE if available
643643
if hasattr(self.pipe, "vae"):
@@ -859,8 +859,7 @@ def _run_generation(
859859

860860
result = self.pipe(**gen_kwargs)
861861

862-
if torch.cuda.is_available():
863-
torch.cuda.empty_cache()
862+
DevicePolicy.clear_cache()
864863

865864
output_image = result.images[0]
866865
return GenerationResult(
@@ -887,7 +886,7 @@ def _load_dynamic_loras(self, loras: list[LoraConfig]) -> None:
887886
# Only move pipeline to device manually when CPU offload is not enabled.
888887
# With CPU offload, diffusers manages device placement automatically.
889888
if not self._cpu_offload:
890-
self.pipe.to(self._device)
889+
self.pipe.to(self.policy.device)
891890

892891
loaded_names: list[str] = []
893892
loaded_weights: list[float] = []
@@ -911,7 +910,7 @@ def _restore_static_loras(self) -> None:
911910
return
912911

913912
if not self._cpu_offload:
914-
self.pipe.to(self._device)
913+
self.pipe.to(self.policy.device)
915914
self.load_loras_sync(self._static_lora_configs)
916915
print(f"Restored {len(self._static_lora_configs)} static LoRA(s)")
917916

src/oneiro/pipelines/flux1.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from typing import Any
44

5-
import torch
6-
5+
from oneiro.device import DevicePolicy
76
from oneiro.pipelines.base import BasePipeline, GenerationResult
87

98

@@ -49,14 +48,15 @@ def load(self, model_config: dict[str, Any]) -> None:
4948
# Configure CPU threading for text encoder
5049
self._configure_cpu_threads(cpu_utilization)
5150

51+
self.policy = DevicePolicy.auto_detect(cpu_offload=cpu_offload)
52+
5253
print(" Creating pipeline...")
5354
self.pipe = FluxPipeline.from_pretrained(
5455
repo,
55-
torch_dtype=self._dtype,
56+
torch_dtype=self.policy.dtype,
5657
)
5758

58-
if cpu_offload:
59-
self.pipe.enable_model_cpu_offload()
59+
self.policy.apply_to_pipeline(self.pipe)
6060

6161
# Memory optimization for large T5 encoder and high-res VAE decoding
6262
self.pipe.vae.enable_tiling()
@@ -130,8 +130,7 @@ def generate(
130130
max_sequence_length=512,
131131
)
132132

133-
if torch.cuda.is_available():
134-
torch.cuda.empty_cache()
133+
DevicePolicy.clear_cache()
135134

136135
output_image = result.images[0]
137136
return GenerationResult(

src/oneiro/pipelines/flux2.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from typing import Any
44

5-
import torch
6-
5+
from oneiro.device import DevicePolicy
76
from oneiro.pipelines.base import BasePipeline, GenerationResult
87
from oneiro.pipelines.embedding import EmbeddingLoaderMixin, parse_embeddings_from_config
98
from oneiro.pipelines.lora import LoraLoaderMixin, parse_loras_from_model_config
@@ -31,20 +30,22 @@ def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None
3130
# Configure CPU threading for text encoder
3231
self._configure_cpu_threads(cpu_utilization)
3332

33+
self.policy = DevicePolicy.auto_detect(cpu_offload=cpu_offload)
34+
3435
# Load transformer and text encoder on CPU first
3536
print(" Loading transformer on CPU...")
3637
transformer = Flux2Transformer2DModel.from_pretrained(
3738
repo,
3839
subfolder="transformer",
39-
torch_dtype=self._dtype,
40+
torch_dtype=self.policy.dtype,
4041
device_map="cpu",
4142
)
4243

4344
print(" Loading text encoder on CPU...")
4445
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
4546
repo,
4647
subfolder="text_encoder",
47-
torch_dtype=self._dtype,
48+
torch_dtype=self.policy.dtype,
4849
device_map="cpu",
4950
)
5051

@@ -53,11 +54,10 @@ def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None
5354
repo,
5455
transformer=transformer,
5556
text_encoder=text_encoder,
56-
torch_dtype=self._dtype,
57+
torch_dtype=self.policy.dtype,
5758
)
5859

59-
if cpu_offload:
60-
self.pipe.enable_model_cpu_offload()
60+
self.policy.apply_to_pipeline(self.pipe)
6161

6262
loras = parse_loras_from_model_config(model_config)
6363
if loras:
@@ -115,8 +115,7 @@ def generate(
115115
generator=generator,
116116
)
117117

118-
if torch.cuda.is_available():
119-
torch.cuda.empty_cache()
118+
DevicePolicy.clear_cache()
120119

121120
output_image = result.images[0]
122121
return GenerationResult(

0 commit comments

Comments
 (0)