Skip to content

Commit 080cc30

Browse files
authored
Merge pull request #64 from jkoelker/jk/refactor
refactor(pipelines): generate() templates
2 parents ab38df4 + f0a16f0 commit 080cc30

File tree

11 files changed

+1275
-247
lines changed

11 files changed

+1275
-247
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,4 @@ uv.lock
9595

9696
.beads
9797
.gitattributes
98+
.beads/artifacts/

src/oneiro/pipelines/base.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None
4444
full_config: Full configuration dict (for accessing global sections like embeddings)
4545
"""
4646

47-
@abstractmethod
4847
def generate(
4948
self,
5049
prompt: str,
@@ -56,7 +55,142 @@ def generate(
5655
guidance_scale: float = 0.0,
5756
**kwargs: Any,
5857
) -> GenerationResult:
59-
"""Generate an image."""
58+
"""Generate an image using Template Method pattern.
59+
60+
Subclasses should override hooks (validate_pipeline, pre_generate,
61+
build_generation_kwargs, run_inference, build_result, post_generate),
62+
not this method.
63+
"""
64+
self.validate_pipeline()
65+
self.pre_generate(**kwargs)
66+
try:
67+
actual_seed, generator = self._prepare_seed(seed)
68+
# Pop init_image and strength from kwargs to avoid passing twice
69+
init_image = self._load_init_image(kwargs.pop("init_image", None))
70+
strength = kwargs.pop("strength", 0.75)
71+
72+
gen_kwargs = self.build_generation_kwargs(
73+
prompt=prompt,
74+
negative_prompt=negative_prompt,
75+
width=width,
76+
height=height,
77+
steps=steps,
78+
guidance_scale=guidance_scale,
79+
generator=generator,
80+
init_image=init_image,
81+
strength=strength,
82+
**kwargs,
83+
)
84+
is_img2img = init_image is not None
85+
result = self.run_inference(gen_kwargs, is_img2img)
86+
87+
DevicePolicy.clear_cache()
88+
return self.build_result(
89+
result=result,
90+
seed=actual_seed,
91+
prompt=prompt,
92+
negative_prompt=negative_prompt,
93+
steps=steps,
94+
guidance_scale=guidance_scale,
95+
)
96+
finally:
97+
self.post_generate(**kwargs)
98+
99+
def validate_pipeline(self) -> None:
    """Check that the pipeline can be used for generation.

    generate() invokes this first, ahead of pre_generate() and before it
    consumes any kwargs. Subclasses may extend it with further checks
    (for example, validating configuration state).

    Raises:
        RuntimeError: If no pipeline object has been loaded.
    """
    pipeline_missing = self.pipe is None
    if pipeline_missing:
        raise RuntimeError("Pipeline not loaded")
107+
108+
def pre_generate(self, **kwargs: Any) -> None:  # noqa: B027
    """Pre-generation hook called before generation kwargs are built.

    Override for scheduler/LoRA setup or other pre-processing.
    This is an optional hook with a no-op default; it is intentionally
    not abstract so subclasses can choose whether to implement it.

    Args:
        **kwargs: The extra keyword arguments given to generate(). At this
            point 'init_image' and 'strength' have not been consumed yet,
            so they are still present if the caller supplied them.

    Note: generate() invokes this hook as ``self.pre_generate(**kwargs)``,
    so the hook receives its own dict. Popping or mutating keys here does
    NOT remove them from the kwargs that generate() later forwards to
    build_generation_kwargs() and post_generate(). Anything that must
    outlive this call (e.g. for cleanup in post_generate()) should be
    stored on ``self``.
    """
    pass
120+
121+
@abstractmethod
def build_generation_kwargs(
    self,
    prompt: str,
    negative_prompt: str | None,
    width: int,
    height: int,
    steps: int,
    guidance_scale: float,
    generator: torch.Generator,
    init_image: Image.Image | None,
    strength: float,
    **kwargs: Any,
) -> dict[str, Any]:
    """Assemble the keyword arguments for the underlying pipeline call.

    This is the one hook every subclass is REQUIRED to implement.
    generate() passes it the resolved prompt, size, step and guidance
    settings together with the seeded generator, the loaded init image
    (None when no init image was supplied) and the img2img strength.

    Returns:
        A dict that is handed to self.pipe() as keyword arguments.
    """
140+
141+
def run_inference(self, gen_kwargs: dict[str, Any], is_img2img: bool) -> Any:
    """Invoke the diffusion pipeline with the prepared arguments.

    Args:
        gen_kwargs: Keyword arguments forwarded to the underlying pipeline.
        is_img2img: True for image-to-image generation. The base
            implementation ignores this flag; it exists so subclasses can
            branch between img2img and txt2img behavior.

    Returns:
        Whatever the pipeline call returns (typically an object with an
        .images attribute).

    Override when the pipeline's call signature or behavior differs.
    """
    pipeline = self.pipe
    return pipeline(**gen_kwargs)
156+
157+
def build_result(
    self,
    result: Any,
    seed: int,
    prompt: str,
    negative_prompt: str | None,
    steps: int,
    guidance_scale: float,
) -> GenerationResult:
    """Wrap the pipeline output in a GenerationResult.

    The first entry of ``result.images`` becomes the result image, and the
    reported width/height are read from that image rather than from the
    requested size. Override when the pipeline output has another shape.
    """
    first_image = result.images[0]
    final_width, final_height = first_image.width, first_image.height
    return GenerationResult(
        image=first_image,
        seed=seed,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=final_width,
        height=final_height,
        steps=steps,
        guidance_scale=guidance_scale,
    )
181+
182+
def post_generate(self, **kwargs: Any) -> None:  # noqa: B027
    """Cleanup hook that generate() runs once generation has finished.

    Override for LoRA restore or other cleanup. The default does nothing;
    the hook is deliberately not abstract, so implementing it is optional.

    Note: generate() pops 'init_image' and 'strength' from its kwargs
    before this hook runs, so neither key is present here. A subclass that
    needs those values should record them in pre_generate(), which still
    sees them.
    """
    return None
60194

61195
def unload(self) -> None:
62196
"""Free GPU memory."""

src/oneiro/pipelines/civitai_checkpoint.py

Lines changed: 55 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from pathlib import Path
1111
from typing import TYPE_CHECKING, Any
1212

13+
import torch
14+
from PIL import Image
15+
1316
from oneiro.device import DevicePolicy, OffloadMode
1417
from oneiro.pipelines.base import BasePipeline, GenerationResult
1518
from oneiro.pipelines.embedding import EmbeddingLoaderMixin, parse_embeddings_from_config
@@ -519,6 +522,7 @@ def __init__(self) -> None:
519522
self._current_scheduler: str | None = None
520523
self._static_lora_configs: list[LoraConfig] = []
521524
self._cpu_offload: bool = False
525+
self._has_dynamic_loras: bool = False
522526

523527
def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None = None) -> None:
524528
"""Load checkpoint from config (synchronous, requires checkpoint_path).
@@ -692,6 +696,12 @@ def configure_scheduler(self, scheduler_name: str | None) -> None:
692696
self._current_scheduler = scheduler_name
693697
print(f" Scheduler: {scheduler_name}")
694698

699+
def validate_pipeline(self) -> None:
    """Ensure both the pipeline and its config are ready for generation.

    Runs the base-class check first, then verifies that the pipeline
    config has been set.

    Raises:
        RuntimeError: If the pipeline config is still None.
    """
    super().validate_pipeline()
    config_ready = self._pipeline_config is not None
    if not config_ready:
        raise RuntimeError("Pipeline config not initialized")
704+
695705
def generate(
696706
self,
697707
prompt: str,
@@ -762,84 +772,70 @@ def generate(
762772
as the actual seed used, prompts, final image size, number of
763773
steps, and guidance scale.
764774
"""
765-
if self.pipe is None:
766-
raise RuntimeError("Pipeline not loaded")
775+
# Apply defaults from pipeline config (validation happens in super().generate())
776+
# Note: We need to check _pipeline_config here before applying defaults,
777+
# but full validation happens in validate_pipeline() called by super()
778+
if self._pipeline_config is not None:
779+
width = width if width is not None else self._pipeline_config.default_width
780+
height = height if height is not None else self._pipeline_config.default_height
781+
steps = steps if steps is not None else self._pipeline_config.default_steps
782+
guidance_scale = (
783+
guidance_scale
784+
if guidance_scale is not None
785+
else self._pipeline_config.default_guidance_scale
786+
)
767787

768-
if self._pipeline_config is None:
769-
raise RuntimeError("Pipeline config not initialized")
788+
return super().generate(
789+
prompt=prompt,
790+
negative_prompt=negative_prompt,
791+
width=width or 1024,
792+
height=height or 1024,
793+
seed=seed,
794+
steps=steps or 20,
795+
guidance_scale=guidance_scale if guidance_scale is not None else 7.0,
796+
**kwargs,
797+
)
770798

799+
def pre_generate(self, **kwargs: Any) -> None:
800+
"""Pre-generation setup: scheduler override and dynamic LoRA loading."""
771801
scheduler_override = kwargs.pop("scheduler", None)
772802
if scheduler_override:
773803
self.configure_scheduler(scheduler_override)
774804

775805
dynamic_loras = kwargs.pop("loras", None)
776-
has_dynamic_loras = False
806+
self._has_dynamic_loras = False
777807
if dynamic_loras:
778-
# Mark that we are entering a dynamic LoRA context before loading,
779-
# so that failures during loading can be properly rolled back.
780-
has_dynamic_loras = True
808+
self._has_dynamic_loras = True
781809
try:
782810
self._load_dynamic_loras(dynamic_loras)
783811
except Exception:
784-
# If loading dynamic LoRAs fails after modifying the pipeline
785-
# state (for example, after unloading static LoRAs), attempt
786-
# to restore the original static LoRAs before propagating
787-
# the error.
788812
self._restore_static_loras()
789-
has_dynamic_loras = False
813+
self._has_dynamic_loras = False
790814
raise
791815

792-
try:
793-
return self._run_generation(
794-
prompt=prompt,
795-
negative_prompt=negative_prompt,
796-
width=width,
797-
height=height,
798-
seed=seed,
799-
steps=steps,
800-
guidance_scale=guidance_scale,
801-
**kwargs,
802-
)
803-
finally:
804-
if has_dynamic_loras:
805-
self._restore_static_loras()
806-
807-
def _run_generation(
816+
def build_generation_kwargs(
808817
self,
809818
prompt: str,
810819
negative_prompt: str | None,
811-
width: int | None,
812-
height: int | None,
813-
seed: int,
814-
steps: int | None,
815-
guidance_scale: float | None,
820+
width: int,
821+
height: int,
822+
steps: int,
823+
guidance_scale: float,
824+
generator: "torch.Generator",
825+
init_image: "Image.Image | None",
826+
strength: float,
816827
**kwargs: Any,
817-
) -> GenerationResult:
828+
) -> dict[str, Any]:
829+
"""Build generation kwargs with embedding support."""
818830
assert self._pipeline_config is not None
819-
width = width or self._pipeline_config.default_width
820-
height = height or self._pipeline_config.default_height
821-
steps = steps or self._pipeline_config.default_steps
822-
guidance_scale = (
823-
guidance_scale
824-
if guidance_scale is not None
825-
else self._pipeline_config.default_guidance_scale
826-
)
827-
828-
actual_seed, generator = self._prepare_seed(seed)
829831

830-
# Handle img2img
831-
init_image = self._load_init_image(kwargs.get("init_image"))
832-
strength = kwargs.get("strength", 0.75)
833-
834-
# Build generation kwargs
835832
gen_kwargs: dict[str, Any] = {
836833
"num_inference_steps": steps,
837834
"guidance_scale": guidance_scale,
838835
"generator": generator,
839836
}
840837

841838
# Use embedding-based prompt handling for pipelines that support it
842-
# (SD 1.x, SD 2.x, SDXL, Flux, SD3) - enables weight syntax like (word:1.5)
843839
if self._supports_prompt_embeddings():
844840
gen_kwargs.update(self._encode_prompts_to_embeddings(prompt, negative_prompt))
845841
else:
@@ -849,29 +845,21 @@ def _run_generation(
849845
gen_kwargs["negative_prompt"] = negative_prompt
850846

851847
if init_image:
852-
print(f"CivitAI img2img: '{prompt[:50]}...' seed={actual_seed} strength={strength}")
848+
print(f"CivitAI img2img: '{prompt[:50]}...' strength={strength}")
853849
gen_kwargs["image"] = init_image
854850
gen_kwargs["strength"] = strength
855851
else:
856-
print(f"CivitAI generating: '{prompt[:50]}...' seed={actual_seed}")
852+
print(f"CivitAI generating: '{prompt[:50]}...'")
857853
gen_kwargs["height"] = height
858854
gen_kwargs["width"] = width
859855

860-
result = self.pipe(**gen_kwargs)
861-
862-
DevicePolicy.clear_cache()
856+
return gen_kwargs
863857

864-
output_image = result.images[0]
865-
return GenerationResult(
866-
image=output_image,
867-
seed=actual_seed,
868-
prompt=prompt,
869-
negative_prompt=negative_prompt,
870-
width=output_image.width,
871-
height=output_image.height,
872-
steps=steps,
873-
guidance_scale=guidance_scale,
874-
)
858+
def post_generate(self, **kwargs: Any) -> None:
    """Put the static LoRAs back if this generation used dynamic ones.

    Does nothing when no dynamic LoRAs were flagged; otherwise restores the
    static LoRAs and then clears the flag.
    """
    if not self._has_dynamic_loras:
        return
    self._restore_static_loras()
    self._has_dynamic_loras = False
875863

876864
def _load_dynamic_loras(self, loras: list[LoraConfig]) -> None:
877865
if self.pipe is None or not loras:

0 commit comments

Comments
 (0)