30 changes: 28 additions & 2 deletions lib_layerdiffusion/attention_sharing.py
@@ -66,8 +66,8 @@ def __init__(self, in_features: int, out_features: int, rank: int = 256, org=Non
def forward(self, h):
org_weight = self.org[0].weight.to(h)
org_bias = self.org[0].bias.to(h) if self.org[0].bias is not None else None
down_weight = self.down.weight
up_weight = self.up.weight
down_weight = self.down.weight.to(h.device)
up_weight = self.up.weight.to(h.device)
final_weight = org_weight + torch.mm(up_weight, down_weight)
return torch.nn.functional.linear(h, final_weight, org_bias)
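For reference, the merged weight built here keeps the original layer's shape: with up of shape [out_features, rank] and down of shape [rank, in_features], torch.mm(up, down) is [out_features, in_features]. A minimal standalone sketch of the pattern with hypothetical shapes, not code from this file:

import torch

out_features, in_features, rank = 320, 320, 256   # hypothetical sizes
org_weight = torch.randn(out_features, in_features)
up = torch.randn(out_features, rank)               # plays the role of self.up.weight
down = torch.randn(rank, in_features)              # plays the role of self.down.weight

# Merge the low-rank update into the frozen weight, then apply a single linear op.
final_weight = org_weight + torch.mm(up, down)     # [out_features, in_features]
h = torch.randn(2, in_features)
y = torch.nn.functional.linear(h, final_weight)    # [2, out_features]

Note the asymmetry the change keeps: org_weight is converted with .to(h) (device and dtype), while the LoRA factors use .to(h.device) (device only), so a dtype mismatch remains possible if the LoRA weights are stored in a different precision than h.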

@@ -143,6 +143,9 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
in_features=hidden_size, out_features=hidden_size
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.to(device)

self.control_convs = None

if use_control:
@@ -155,11 +158,23 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
for _ in range(self.frames)
]
self.control_convs = torch.nn.ModuleList(self.control_convs)
self.control_convs.to(device)

self.control_signals = None

def forward(self, h, context=None, value=None):
transformer_options = self.transformer_options

device = h.device
self.temporal_i.to(device)
self.temporal_q.to(device)
self.temporal_k.to(device)
self.temporal_v.to(device)
self.temporal_o.to(device)
self.to_q_lora.to(device)
self.to_k_lora.to(device)
self.to_v_lora.to(device)
self.to_out_lora.to(device)
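The calls above move each temporal/LoRA submodule onto the incoming tensor's device at the start of every forward pass. Since nn.Module.to() relocates registered parameters and buffers in place, repeated calls are cheap once everything already lives on h.device. A compact equivalent, assuming these attributes are ordinary nn.Module instances (hypothetical helper name, not part of the PR):

def _align_devices(self, device):
    # Move all temporal/LoRA submodules to the given device; nn.Module.to()
    # is in-place for parameters and buffers, so this is a no-op once aligned.
    submodules = (
        self.temporal_i, self.temporal_q, self.temporal_k, self.temporal_v,
        self.temporal_o, self.to_q_lora, self.to_k_lora, self.to_v_lora,
        self.to_out_lora,
    )
    for m in submodules:
        m.to(device)

forward() could then call self._align_devices(h.device) once at the top instead of listing each submodule inline.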

modified_hidden_states = einops.rearrange(
h, "(b f) d c -> f b d c", f=self.frames
@@ -227,6 +242,11 @@ def forward(self, h, context=None, value=None):
)

x = modified_hidden_states
self.temporal_n.to(device)
if self.temporal_n.weight is not None:
self.temporal_n.weight = self.temporal_n.weight.to(device)
if self.temporal_n.bias is not None:
self.temporal_n.bias = self.temporal_n.bias.to(device)
x = self.temporal_n(x)
x = self.temporal_i(x)
d = x.shape[1]
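Because nn.Module.to(device) already moves a module's registered parameters and buffers in place, the explicit weight/bias reassignments above are redundant in the common case where temporal_n is a standard normalization layer whose weight and bias are registered parameters. A minimal sketch of the equivalent under that assumption, not the PR's code:

# Assumes temporal_n registers weight/bias as parameters (e.g. LayerNorm/GroupNorm);
# Module.to() then moves both in place, covering them in a single call.
self.temporal_n.to(device)
x = self.temporal_n(x)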
@@ -345,11 +365,17 @@ def __init__(self, unet: ModelPatcher, frames=2, use_control=True, rank=256):
self.kwargs_encoder = AdditionalAttentionCondsEncoder()
else:
self.kwargs_encoder = None

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.to(device)
if self.kwargs_encoder is not None:
self.kwargs_encoder.to(device)

self.dtype = torch.float32
if model_management.should_use_fp16(model_management.get_torch_device()):
self.dtype = torch.float16
self.hookers.half()
self.hookers.to(device)
return
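Here the dtype is chosen from model_management.should_use_fp16() applied to model_management.get_torch_device(), while the device itself was hard-coded to 'cuda:0' above. A sketch that derives both from the same helper, reusing only calls already present in this file (an assumption about the intended behavior, not the PR's code):

device = model_management.get_torch_device()
dtype = torch.float16 if model_management.should_use_fp16(device) else torch.float32

self.to(device)
if self.kwargs_encoder is not None:
    self.kwargs_encoder.to(device)
# Move and cast the hookers in one call instead of separate half()/to() steps.
self.hookers.to(device=device, dtype=dtype)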

def set_control(self, img):
76 changes: 53 additions & 23 deletions lib_layerdiffusion/models.py
@@ -2,6 +2,7 @@
import torch
import cv2
import numpy as np
import logging

from tqdm import tqdm
from typing import Optional, Tuple
@@ -10,6 +11,8 @@
import importlib.metadata
from packaging.version import parse

DEBUG_ENABLED = False

diffusers_version = importlib.metadata.version('diffusers')

def check_diffusers_version(min_version="0.25.0"):
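The body of check_diffusers_version falls outside this hunk. Purely as an illustration of how the imports added above (importlib.metadata and packaging.version.parse) fit together, and not the file's actual implementation, a minimal version guard could look like:

def check_diffusers_version(min_version="0.25.0"):
    # Illustration only: fail fast if the installed diffusers is too old.
    if parse(diffusers_version) < parse(min_version):
        raise RuntimeError(
            f"diffusers>={min_version} is required, found {diffusers_version}"
        )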
@@ -261,11 +264,17 @@ def __init__(self, sd, device, dtype):

@torch.no_grad()
def estimate_single_pass(self, pixel, latent):
"""Run a single forward pass through the UNet model."""
y = self.model(pixel, latent)
return y

@torch.no_grad()
def estimate_augmented(self, pixel, latent):
"""Apply augmentations (flips and rotations) and aggregate results.

Uses 8 hardcoded augmentations (4 rotations with/without horizontal flip).
Replaced torch.median with torch.mean to avoid empty tensor issues on DirectML.
"""
args = [
[False, 0],
[False, 1],
@@ -275,10 +284,9 @@ def estimate_augmented(self, pixel, latent):
[True, 1],
[True, 2],
[True, 3],
]
] # Hardcoded 8 augmentations as in original implementation

result = []

for flip, rok in tqdm(args):
feed_pixel = pixel.clone()
feed_latent = latent.clone()
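The middle of this loop is collapsed in the next hunk; from the args list above and the flip-back visible below, each pass presumably applies an optional horizontal flip plus a k·90° rotation, runs the model, and then inverts both. A minimal sketch of that apply/invert symmetry on a hypothetical tensor, not the file's exact code:

import torch

x = torch.randn(1, 3, 64, 64)
flip, rok = True, 1

aug = torch.flip(x, dims=(3,)) if flip else x
aug = torch.rot90(aug, k=rok, dims=(2, 3))       # apply rotation
# ... run the model on `aug` here ...
out = torch.rot90(aug, k=-rok, dims=(2, 3))      # undo rotation
out = torch.flip(out, dims=(3,)) if flip else out
assert torch.equal(out, x)                       # the transforms cancel exactly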
@@ -296,35 +304,57 @@ def estimate_augmented(self, pixel, latent):
if flip:
eps = torch.flip(eps, dims=(3,))

result += [eps]

result = torch.stack(result, dim=0)
if self.load_device == torch.device("mps"):
'''
Apple Silicon devices can crash when torch.median() is called on GPU tensors
with more than 4 dimensions, so the tensor is moved to the CPU, the median is
taken there, and the result is moved back to the GPU.
'''
median = torch.median(result.cpu(), dim=0).values
median = median.to(device=self.load_device, dtype=self.dtype)
else:
median = torch.median(result, dim=0).values
return median
result.append(eps)
if DEBUG_ENABLED:
logging.debug(f"estimate_augmented: single_pass eps shape={eps.shape}, dtype={eps.dtype}")

result = torch.stack(result, dim=0) # Shape: [8, B, C, H, W]
if DEBUG_ENABLED:
logging.debug(f"estimate_augmented: stacked result shape={result.shape}, dtype={result.dtype}")

# Check for NaN or inf values to catch data issues
if torch.isnan(result).any() or torch.isinf(result).any():
logging.error("estimate_augmented: stacked tensor contains NaN or inf values")
raise ValueError("Stacked tensor contains NaN or inf values")

# Use mean instead of median for stability, especially on DirectML
y = torch.mean(result, dim=0) # Shape: [B, C, H, W]
if DEBUG_ENABLED:
logging.debug(f"estimate_augmented: y shape={y.shape}, dtype={y.dtype}")

return y
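The eight per-augmentation predictions are stacked along a new leading axis and reduced with torch.mean, where the removed code reduced with torch.median along the same axis. A small sketch of the two reductions, assuming a stacked tensor of shape [8, B, C, H, W]:

import torch

stacked = torch.rand(8, 1, 4, 64, 64)            # [num_augmentations, B, C, H, W]
mean_y = torch.mean(stacked, dim=0)              # [B, C, H, W], behavior after this PR
median_y = torch.median(stacked, dim=0).values   # [B, C, H, W], previous behavior

As the PR's comments note, mean avoids the DirectML issues seen with median, at the cost of being less robust to a single outlier prediction.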

@torch.no_grad()
def decode_pixel(
self, pixel: torch.TensorType, latent: torch.TensorType
) -> torch.TensorType:
# pixel.shape = [B, C=3, H, W]
assert pixel.shape[1] == 3
"""Decode pixel and latent tensors to produce an RGBA image.

Args:
pixel: Input RGB image tensor of shape [B, 3, H, W].
latent: Latent representation tensor of shape [B, 4, H/8, W/8].

Returns:
Tensor of shape [B, 4, H, W] containing RGBA channels.
"""
assert pixel.shape[1] == 3, f"Expected pixel.shape[1] == 3, got {pixel.shape[1]}"
pixel_device = pixel.device
pixel_dtype = pixel.dtype

if DEBUG_ENABLED:
logging.debug(f"decode_pixel: pixel shape={pixel.shape}, dtype={pixel.dtype}")
logging.debug(f"decode_pixel: latent shape={latent.shape}, dtype={latent.dtype}")

pixel = pixel.to(device=self.load_device, dtype=self.dtype)
latent = latent.to(device=self.load_device, dtype=self.dtype)
# y.shape = [B, C=4, H, W]
y = self.estimate_augmented(pixel, latent)
y = y.clip(0, 1)
assert y.shape[1] == 4
# Restore image to original device of input image.
return y.to(pixel_device, dtype=pixel_dtype)
if DEBUG_ENABLED:
logging.debug(f"decode_pixel: y shape={y.shape}, dtype={y.dtype}")

if len(y.shape) < 2:
logging.error(f"decode_pixel: y has insufficient dimensions, shape={y.shape}")
raise ValueError(f"Expected y to have at least 2 dimensions, got {y.shape}")

y = y.clip(0, 1) # Ensure output is in [0, 1] range
assert y.shape[1] == 4, f"Expected y.shape[1] == 4, got {y.shape[1]}"
return y.to(pixel_device, dtype=pixel_dtype)
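A hypothetical caller's view of decode_pixel, matching the shapes in the docstring and assuming `decoder` is an instance of the class defined in this file (illustrative tensors only):

import torch

pixel = torch.rand(1, 3, 512, 512)         # RGB input in [0, 1]
latent = torch.randn(1, 4, 64, 64)         # corresponding latent at H/8 x W/8

rgba = decoder.decode_pixel(pixel, latent)
assert rgba.shape == (1, 4, 512, 512)      # RGB + alpha, returned on pixel's original device/dtype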