diff --git a/lib_layerdiffusion/attention_sharing.py b/lib_layerdiffusion/attention_sharing.py
index cc67986..6a81285 100644
--- a/lib_layerdiffusion/attention_sharing.py
+++ b/lib_layerdiffusion/attention_sharing.py
@@ -66,8 +66,8 @@ def __init__(self, in_features: int, out_features: int, rank: int = 256, org=Non
     def forward(self, h):
         org_weight = self.org[0].weight.to(h)
         org_bias = self.org[0].bias.to(h) if self.org[0].bias is not None else None
-        down_weight = self.down.weight
-        up_weight = self.up.weight
+        down_weight = self.down.weight.to(h.device)
+        up_weight = self.up.weight.to(h.device)
         final_weight = org_weight + torch.mm(up_weight, down_weight)
         return torch.nn.functional.linear(h, final_weight, org_bias)

@@ -143,6 +143,9 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
             in_features=hidden_size, out_features=hidden_size
         )

+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(device)
+
         self.control_convs = None

         if use_control:
@@ -155,11 +158,23 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
                 for _ in range(self.frames)
             ]
             self.control_convs = torch.nn.ModuleList(self.control_convs)
+            self.control_convs.to(device)

         self.control_signals = None

     def forward(self, h, context=None, value=None):
         transformer_options = self.transformer_options
+
+        device = h.device
+        self.temporal_i.to(device)
+        self.temporal_q.to(device)
+        self.temporal_k.to(device)
+        self.temporal_v.to(device)
+        self.temporal_o.to(device)
+        self.to_q_lora.to(device)
+        self.to_k_lora.to(device)
+        self.to_v_lora.to(device)
+        self.to_out_lora.to(device)

         modified_hidden_states = einops.rearrange(
             h, "(b f) d c -> f b d c", f=self.frames
@@ -227,6 +242,11 @@ def forward(self, h, context=None, value=None):
         )
         x = modified_hidden_states

+        self.temporal_n.to(device)
+        if self.temporal_n.weight is not None:
+            self.temporal_n.weight = self.temporal_n.weight.to(device)
+        if self.temporal_n.bias is not None:
+            self.temporal_n.bias = self.temporal_n.bias.to(device)
         x = self.temporal_n(x)
         x = self.temporal_i(x)
         d = x.shape[1]
@@ -345,11 +365,17 @@ def __init__(self, unet: ModelPatcher, frames=2, use_control=True, rank=256):
             self.kwargs_encoder = AdditionalAttentionCondsEncoder()
         else:
             self.kwargs_encoder = None
+
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(device)
+        if self.kwargs_encoder is not None:
+            self.kwargs_encoder.to(device)

         self.dtype = torch.float32
         if model_management.should_use_fp16(model_management.get_torch_device()):
             self.dtype = torch.float16
             self.hookers.half()
+            self.hookers.to(device)
         return

     def set_control(self, img):
diff --git a/lib_layerdiffusion/models.py b/lib_layerdiffusion/models.py
index 19396f7..5c9779a 100644
--- a/lib_layerdiffusion/models.py
+++ b/lib_layerdiffusion/models.py
@@ -2,6 +2,7 @@
 import torch
 import cv2
 import numpy as np
+import logging

 from tqdm import tqdm
 from typing import Optional, Tuple
@@ -10,6 +11,8 @@
 import importlib.metadata
 from packaging.version import parse

+DEBUG_ENABLED = False
+
 diffusers_version = importlib.metadata.version('diffusers')

 def check_diffusers_version(min_version="0.25.0"):
@@ -261,11 +264,17 @@ def __init__(self, sd, device, dtype):

     @torch.no_grad()
     def estimate_single_pass(self, pixel, latent):
+        """Run a single forward pass through the UNet model."""
         y = self.model(pixel, latent)
         return y

     @torch.no_grad()
     def estimate_augmented(self, pixel, latent):
+        """Apply augmentations (flips and rotations) and aggregate results.
+
+        Uses 8 hardcoded augmentations (4 rotations with/without horizontal flip).
+        Replaced torch.median with torch.mean to avoid empty tensor issues on DirectML.
+        """
         args = [
             [False, 0],
             [False, 1],
@@ -275,10 +284,9 @@ def estimate_augmented(self, pixel, latent):
             [True, 1],
             [True, 2],
             [True, 3],
-        ]
+        ]  # Hardcoded 8 augmentations as in original implementation
         result = []
-
         for flip, rok in tqdm(args):
             feed_pixel = pixel.clone()
             feed_latent = latent.clone()

@@ -296,35 +304,57 @@ def estimate_augmented(self, pixel, latent):
             if flip:
                 eps = torch.flip(eps, dims=(3,))

-            result += [eps]
-
-        result = torch.stack(result, dim=0)
-        if self.load_device == torch.device("mps"):
-            '''
-            In case that apple silicon devices would crash when calling torch.median() on tensors
-            in gpu vram with dimensions higher than 4, we move it to cpu, call torch.median()
-            and then move the result back to gpu.
-            '''
-            median = torch.median(result.cpu(), dim=0).values
-            median = median.to(device=self.load_device, dtype=self.dtype)
-        else:
-            median = torch.median(result, dim=0).values
-        return median
+            result.append(eps)
+            if DEBUG_ENABLED:
+                logging.debug(f"estimate_augmented: single_pass eps shape={eps.shape}, dtype={eps.dtype}")
+
+        result = torch.stack(result, dim=0)  # Shape: [8, B, C, H, W]
+        if DEBUG_ENABLED:
+            logging.debug(f"estimate_augmented: stacked result shape={result.shape}, dtype={result.dtype}")
+
+        # Check for NaN or inf values to catch data issues
+        if torch.isnan(result).any() or torch.isinf(result).any():
+            logging.error("estimate_augmented: stacked tensor contains NaN or inf values")
+            raise ValueError("Stacked tensor contains NaN or inf values")
+
+        # Use mean instead of median for stability, especially on DirectML
+        y = torch.mean(result, dim=0)  # Shape: [B, C, H, W]
+        if DEBUG_ENABLED:
+            logging.debug(f"estimate_augmented: y shape={y.shape}, dtype={y.dtype}")
+
+        return y

     @torch.no_grad()
     def decode_pixel(
         self, pixel: torch.TensorType, latent: torch.TensorType
     ) -> torch.TensorType:
-        # pixel.shape = [B, C=3, H, W]
-        assert pixel.shape[1] == 3
+        """Decode pixel and latent tensors to produce an RGBA image.
+
+        Args:
+            pixel: Input RGB image tensor of shape [B, 3, H, W].
+            latent: Latent representation tensor of shape [B, 4, H/8, W/8].
+
+        Returns:
+            Tensor of shape [B, 4, H, W] containing RGBA channels.
+        """
+        assert pixel.shape[1] == 3, f"Expected pixel.shape[1] == 3, got {pixel.shape[1]}"
         pixel_device = pixel.device
         pixel_dtype = pixel.dtype
+
+        if DEBUG_ENABLED:
+            logging.debug(f"decode_pixel: pixel shape={pixel.shape}, dtype={pixel.dtype}")
+            logging.debug(f"decode_pixel: latent shape={latent.shape}, dtype={latent.dtype}")
         pixel = pixel.to(device=self.load_device, dtype=self.dtype)
         latent = latent.to(device=self.load_device, dtype=self.dtype)

-        # y.shape = [B, C=4, H, W]
         y = self.estimate_augmented(pixel, latent)
-        y = y.clip(0, 1)
-        assert y.shape[1] == 4
-        # Restore image to original device of input image.
-        return y.to(pixel_device, dtype=pixel_dtype)
+        if DEBUG_ENABLED:
+            logging.debug(f"decode_pixel: y shape={y.shape}, dtype={y.dtype}")
+
+        if len(y.shape) < 2:
+            logging.error(f"decode_pixel: y has insufficient dimensions, shape={y.shape}")
+            raise ValueError(f"Expected y to have at least 2 dimensions, got {y.shape}")
+
+        y = y.clip(0, 1)  # Ensure output is in [0, 1] range
+        assert y.shape[1] == 4, f"Expected y.shape[1] == 4, got {y.shape[1]}"
+        return y.to(pixel_device, dtype=pixel_dtype)
\ No newline at end of file
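
Reviewer note (not part of the patch): the attention_sharing.py hunks all target the same failure mode, namely LoRA and temporal submodules sitting on a different device than the activations they receive. Below is a minimal, self-contained sketch of the pattern used by the patched forward() in the first hunk, re-deriving the merged weight on the input's device; the MergedLoRALinear class and the tensor sizes are illustrative assumptions, not code from this repository.

    import torch

    class MergedLoRALinear(torch.nn.Module):
        # Illustrative stand-in for the patched low-rank linear layer: a base Linear plus a rank-r update.
        def __init__(self, features: int = 16, rank: int = 4):
            super().__init__()
            self.base = torch.nn.Linear(features, features)
            self.down = torch.nn.Linear(features, rank, bias=False)
            self.up = torch.nn.Linear(rank, features, bias=False)

        def forward(self, h: torch.Tensor) -> torch.Tensor:
            # Align every weight with the device (and dtype) of the incoming activations,
            # mirroring the patch's .to(h) / .to(h.device) calls.
            org_weight = self.base.weight.to(h)
            org_bias = self.base.bias.to(h) if self.base.bias is not None else None
            down_weight = self.down.weight.to(h.device)
            up_weight = self.up.weight.to(h.device)
            final_weight = org_weight + torch.mm(up_weight, down_weight)
            return torch.nn.functional.linear(h, final_weight, org_bias)

    if __name__ == "__main__":
        layer = MergedLoRALinear()
        print(layer(torch.rand(2, 16)).shape)  # torch.Size([2, 16])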
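
The core behavioral change in models.py is that estimate_augmented now reduces the eight augmented passes with torch.mean instead of torch.median, with a NaN/inf guard before the reduction. A minimal, self-contained sketch of that aggregation follows; fake_single_pass and the tensor sizes are illustrative assumptions, not code from this repository.

    import torch

    def fake_single_pass(pixel: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
        # Stand-in for the decoder's estimate_single_pass (illustrative only).
        return torch.rand(pixel.shape[0], 4, pixel.shape[2], pixel.shape[3])

    def aggregate_augmented(pixel: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
        # Eight augmentations: four 90-degree rotations, each with and without a horizontal flip.
        args = [(flip, rot) for flip in (False, True) for rot in range(4)]
        results = []
        for flip, rot in args:
            p, lat = pixel.clone(), latent.clone()
            if flip:
                p, lat = torch.flip(p, dims=(3,)), torch.flip(lat, dims=(3,))
            if rot:
                p, lat = torch.rot90(p, k=rot, dims=(2, 3)), torch.rot90(lat, k=rot, dims=(2, 3))
            eps = fake_single_pass(p, lat).clip(0, 1)
            # Undo the augmentation so all eight estimates are aligned before reduction.
            if rot:
                eps = torch.rot90(eps, k=-rot, dims=(2, 3))
            if flip:
                eps = torch.flip(eps, dims=(3,))
            results.append(eps)
        stacked = torch.stack(results, dim=0)  # [8, B, C, H, W]
        if torch.isnan(stacked).any() or torch.isinf(stacked).any():
            raise ValueError("Stacked tensor contains NaN or inf values")
        return torch.mean(stacked, dim=0)  # [B, C, H, W]; mean replaces median

    if __name__ == "__main__":
        out = aggregate_augmented(torch.rand(1, 3, 64, 64), torch.rand(1, 4, 8, 8))
        print(out.shape)  # torch.Size([1, 4, 64, 64])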