Skip to content

Commit 62aa064

Browse files
committed
Handle seamless in modular denoise
1 parent 7c975f0 commit 62aa064

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
6363
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
6464
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
65+
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
6566
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
6667
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
6768
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -833,6 +834,10 @@ def step_callback(state: PipelineIntermediateState) -> None:
833834
if self.unet.freeu_config:
834835
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
835836

837+
### seamless
838+
if self.unet.seamless_axes:
839+
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
840+
836841
# context for loading additional models
837842
with ExitStack() as exit_stack:
838843
# later should be smth like:
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from typing import Callable, Dict, List, Optional, Tuple
5+
6+
import torch
7+
import torch.nn as nn
8+
from diffusers import UNet2DConditionModel
9+
from diffusers.models.lora import LoRACompatibleConv
10+
11+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
12+
13+
14+
class SeamlessExt(ExtensionBase):
15+
def __init__(
16+
self,
17+
seamless_axes: List[str],
18+
):
19+
super().__init__()
20+
self._seamless_axes = seamless_axes
21+
22+
@contextmanager
23+
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
24+
with self.static_patch_model(
25+
model=unet,
26+
seamless_axes=self._seamless_axes,
27+
):
28+
yield
29+
30+
@staticmethod
31+
@contextmanager
32+
def static_patch_model(
33+
model: torch.nn.Module,
34+
seamless_axes: List[str],
35+
):
36+
if not seamless_axes:
37+
yield
38+
return
39+
40+
# override conv_forward
41+
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
42+
def _conv_forward_asymmetric(
43+
self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
44+
):
45+
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
46+
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
47+
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
48+
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
49+
return torch.nn.functional.conv2d(
50+
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
51+
)
52+
53+
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
54+
55+
try:
56+
x_mode = "circular" if "x" in seamless_axes else "constant"
57+
y_mode = "circular" if "y" in seamless_axes else "constant"
58+
59+
conv_layers: List[torch.nn.Conv2d] = []
60+
61+
for module in model.modules():
62+
if isinstance(module, torch.nn.Conv2d):
63+
conv_layers.append(module)
64+
65+
for layer in conv_layers:
66+
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
67+
layer.lora_layer = lambda *x: 0
68+
original_layers.append((layer, layer._conv_forward))
69+
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
70+
71+
yield
72+
73+
finally:
74+
for layer, orig_conv_forward in original_layers:
75+
layer._conv_forward = orig_conv_forward

0 commit comments

Comments
 (0)