8 changes: 4 additions & 4 deletions invokeai/app/invocations/compel.py
@@ -80,12 +80,12 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
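The `cached_weights` rename mirrors what `text_encoder_info.model_on_device()` now yields: presumably a cached copy of the model's weights that the patcher can restore from, instead of cloning every tensor it is about to modify. InvokeAI's actual `ModelPatcher` internals are not part of this diff, so the snippet below is only a minimal sketch of that restore-from-cache pattern; `apply_patches_with_cache` and the `patches` argument are hypothetical names, not the project's API.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, Optional

import torch


@contextmanager
def apply_patches_with_cache(
    model: torch.nn.Module,
    patches: Dict[str, torch.Tensor],  # hypothetical: parameter name -> additive delta
    cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[None]:
    """Apply additive deltas in place, then restore the originals on exit.

    If cached_weights (e.g. a RAM copy yielded by model_on_device()) is
    provided, restoration copies from that cache; otherwise each original
    tensor is cloned before it is modified.
    """
    saved: Dict[str, torch.Tensor] = {}
    try:
        with torch.no_grad():
            for name, delta in patches.items():
                param = model.get_parameter(name)
                if cached_weights is None:
                    # No cache available: snapshot the original before patching it.
                    saved[name] = param.detach().clone()
                param += delta.to(device=param.device, dtype=param.dtype)
        yield
    finally:
        with torch.no_grad():
            for name in patches:
                param = model.get_parameter(name)
                if cached_weights is not None:
                    param.copy_(cached_weights[name])
                elif name in saved:
                    param.copy_(saved[name])
```

Under that reading, `model_state_dict=state_dict` becoming `cached_weights=cached_weights` is more than a rename: the argument is no longer just a snapshot, it is the cache the patcher unwinds to.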
16 changes: 14 additions & 2 deletions invokeai/app/invocations/denoise_latents.py
@@ -60,6 +60,7 @@
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -833,6 +834,17 @@ def step_callback(state: PipelineIntermediateState) -> None:
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

### lora
if self.unet.loras:
for lora_field in self.unet.loras:
ext_manager.add_extension(
LoRAExt(
node_context=context,
model_id=lora_field.lora,
weight=lora_field.weight,
)
)

# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
@@ -913,14 +925,14 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
):
assert isinstance(unet, UNet2DConditionModel)
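The new block in `invoke()` registers one `LoRAExt` per LoRA field with the `ExtensionsManager`, moving LoRA handling onto the extensions path rather than the monolithic `with`-block path (the legacy `ModelPatcher.apply_lora_unet` call shown further down). The internals of `LoRAExt` and `ExtensionsManager` are not in this diff, so the sketch below only illustrates the registration pattern with toy stand-ins; `MiniExtensionsManager`, `ToyLoRAExt`, and the `"patch_unet"` event name are assumptions, not InvokeAI's API.

```python
from typing import Callable, Dict, List


class MiniExtensionsManager:
    """Toy registry: extensions expose named callbacks, the manager collects
    them and fires them at well-defined points in the denoise pipeline."""

    def __init__(self) -> None:
        self._callbacks: Dict[str, List[Callable[..., None]]] = {}

    def add_extension(self, ext: "ToyLoRAExt") -> None:
        for event, handler in ext.callbacks().items():
            self._callbacks.setdefault(event, []).append(handler)

    def run_callback(self, event: str, *args, **kwargs) -> None:
        for handler in self._callbacks.get(event, []):
            handler(*args, **kwargs)


class ToyLoRAExt:
    """Stand-in for LoRAExt: remembers which LoRA to apply and how strongly,
    and does its work when the manager reaches the patching step."""

    def __init__(self, model_id: str, weight: float) -> None:
        self.model_id = model_id
        self.weight = weight

    def callbacks(self) -> Dict[str, Callable[..., None]]:
        return {"patch_unet": self._patch}

    def _patch(self, unet: object) -> None:
        # A real extension would load the LoRA through the node context's model
        # manager and patch the UNet layer by layer; here we only record intent.
        print(f"patch {type(unet).__name__} with {self.model_id} @ {self.weight}")


# Registration mirrors the added node code: one extension per LoRA field.
manager = MiniExtensionsManager()
for model_id, weight in [("lora-a", 0.75), ("lora-b", 1.0)]:
    manager.add_extension(ToyLoRAExt(model_id=model_id, weight=weight))
manager.run_callback("patch_unet", object())
```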
132 changes: 82 additions & 50 deletions invokeai/backend/lora.py
@@ -3,12 +3,13 @@

import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from safetensors.torch import load_file
from typing_extensions import Self

import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel

@@ -46,9 +47,19 @@ def __init__(
self.rank = None # set in layer implementation
self.layer_key = layer_key

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()

def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias

def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params

def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
@@ -60,6 +71,17 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)

def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)


# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
Expand All @@ -76,14 +98,19 @@ def __init__(

self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
else:
self.mid = None
self.mid = values.get("lora_mid.weight", None)

self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -125,20 +152,23 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]

if "hada_t1" in values:
self.t1: Optional[torch.Tensor] = values["hada_t1"]
else:
self.t1 = None

if "hada_t2" in values:
self.t2: Optional[torch.Tensor] = values["hada_t2"]
else:
self.t2 = None
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)

self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -186,37 +216,39 @@ def __init__(
):
super().__init__(layer_key, values)

if "lokr_w1" in values:
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]

if "lokr_w2" in values:
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]

if "lokr_t2" in values:
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
else:
self.t2 = None
self.t2 = values.get("lokr_t2", None)

if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)

def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
@@ -272,7 +304,9 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]


class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]

def __init__(
self,
@@ -282,15 +316,12 @@ def __init__(
super().__init__(layer_key, values)

self.weight = values["diff"]

if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.bias = values.get("diff_b", None)

self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight

def calc_size(self) -> int:
@@ -319,8 +350,9 @@ def __init__(
self.on_input = values["on_input"]

self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
@@ -459,23 +491,23 @@ def from_checkpoint(

for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)

# loha
elif "hada_w1_b" in values:
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)

# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)

# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)

# ia3
elif "weight" in values and "on_input" in values:
elif "on_input" in values:
layer = IA3Layer(layer_key, values)

else:
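Two of the additions to `LoRALayerBase` set up the rest of the refactor: `get_parameters()` returns the layer's effective deltas keyed like the target module's parameters (`"weight"`, plus `"bias"` when present), and `check_keys()` logs a warning for any state-dict keys the layer class did not consume. The patcher that consumes `get_parameters()` is not shown in this diff, so the helper below is a hypothetical consumer under stated assumptions; `apply_layer_to_module` and its `scale` argument (which would typically fold in the user weight and an alpha/rank factor) are not InvokeAI's API.

```python
from typing import Dict

import torch


@torch.no_grad()
def apply_layer_to_module(layer: "LoRALayerBase", module: torch.nn.Module, scale: float) -> None:
    """Add a LoRA layer's deltas onto an nn.Linear/nn.Conv2d-style module in place."""
    params: Dict[str, torch.Tensor] = layer.get_parameters(module)

    # Weight delta, reshaped to the target parameter's shape (e.g. for conv LoRAs).
    delta_w = params["weight"].reshape(module.weight.shape)
    module.weight += scale * delta_w.to(device=module.weight.device, dtype=module.weight.dtype)

    delta_b = params.get("bias")
    if delta_b is not None:
        if module.bias is None:
            # FullLayer's `diff_b` (and the sparse bias handled in LoRALayerBase)
            # can introduce a bias on a module that previously had none.
            module.bias = torch.nn.Parameter(
                torch.zeros(
                    delta_b.reshape(-1).shape,
                    device=module.weight.device,
                    dtype=module.weight.dtype,
                )
            )
        module.bias += scale * delta_b.reshape(module.bias.shape).to(
            device=module.bias.device, dtype=module.bias.dtype
        )
```

The design choice this supports is visible in `check_keys()`: instead of raising on unexpected keys (as the old `FullLayer` did), each layer declares what it handled and anything left over is surfaced as a warning, so unfamiliar LoRA variants degrade gracefully rather than failing to load.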