Merged
Commits
40 commits
225a196
Feat: Add basic LoRA training support
yoland68 Mar 2, 2025
2cd3c8a
Fix ruff errors
yoland68 Mar 2, 2025
f03ece1
Add remaining patch
yoland68 Mar 2, 2025
bfc2f17
Refactor import statements in nodes_train.py
yoland68 Mar 2, 2025
0edc48a
Remove empty spaces
yoland68 Mar 2, 2025
b87f55e
Move allow batch execution logic to different PR
yoland68 Mar 10, 2025
d58ad2d
Expand supported image file extensions in LoadImageSetNode
yoland68 Mar 10, 2025
6fb4cc0
Weight Adapter Scheme
KohakuBlueleaf Apr 2, 2025
4774c32
Initial impl
KohakuBlueleaf Apr 2, 2025
c40686e
Utilize new weight adapter in lora.py
KohakuBlueleaf Apr 2, 2025
8431747
lint
KohakuBlueleaf Apr 2, 2025
88d9168
Sync (#1)
KohakuBlueleaf Apr 8, 2025
c792fad
Merge branch 'comfyanonymous:master' into kbl-new-lora
KohakuBlueleaf Apr 8, 2025
726fdfc
Fix import error
KohakuBlueleaf Apr 8, 2025
a220e5c
Fix typing syntax error
KohakuBlueleaf Apr 8, 2025
ff05027
Use correct v list
KohakuBlueleaf Apr 8, 2025
889f947
Remove unused import
KohakuBlueleaf Apr 8, 2025
e8f3bc5
Finalize the modularized weight adapter impl
KohakuBlueleaf Apr 9, 2025
14c2085
Add scheme of TrainBase class
KohakuBlueleaf Apr 14, 2025
bffbed8
Basic train base impl of lora
KohakuBlueleaf Apr 14, 2025
68c9e79
linting
KohakuBlueleaf Apr 14, 2025
aadc6c2
Merge branch 'yo-lora-trainer' into weight-adapter-train
KohakuBlueleaf Apr 15, 2025
5098e94
Utilize weight adapter scheme in basic training node
KohakuBlueleaf Apr 22, 2025
4616faa
Merge branch 'master' into weight-adapter-train
KohakuBlueleaf Apr 22, 2025
dc74839
Fix missed import in merging
KohakuBlueleaf Apr 22, 2025
c8bd95a
linting
KohakuBlueleaf Apr 22, 2025
3d3d14d
Merge branch 'master' into weight-adapter-train
KohakuBlueleaf May 31, 2025
9c0cf36
weight adapter fixes for training node
KohakuBlueleaf Jun 1, 2025
5e43ec9
Updates of training logic
KohakuBlueleaf Jun 6, 2025
1870402
Add lora model loader for onfly usage
KohakuBlueleaf Jun 6, 2025
c246a1d
Add gradient checkpointing
KohakuBlueleaf Jun 6, 2025
b3b36e5
Use tqdm for training loop
KohakuBlueleaf Jun 6, 2025
1baa1bd
check if need to disable pbar
KohakuBlueleaf Jun 6, 2025
8cf3b53
Merge branch 'comfyanonymous:master' into weight-adapter-train
KohakuBlueleaf Jun 10, 2025
b8757c5
Update lora.py
KohakuBlueleaf Jun 10, 2025
4dcd698
Update nodes_train.py
KohakuBlueleaf Jun 10, 2025
23523f5
Fix typo
KohakuBlueleaf Jun 13, 2025
218c3e3
Use encoded latents as input
KohakuBlueleaf Jun 13, 2025
31c8cc9
Correct dtype handling and better default arg
KohakuBlueleaf Jun 13, 2025
bbcc65e
Remove grad ckpt from model impl
KohakuBlueleaf Jun 13, 2025
2 changes: 2 additions & 0 deletions comfy/comfy_types/node_typing.py
@@ -37,6 +37,8 @@ class IO(StrEnum):
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
LORA_MODEL = "LORA_MODEL"
LOSS_MAP = "LOSS_MAP"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"
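The two new IO members register socket types used by the training nodes: LORA_MODEL for a LoRA object passed between nodes and LOSS_MAP for loss values reported by training. Below is a minimal sketch of how a custom node could declare them; the node class, its name, and its body are illustrative assumptions, not code from this PR.

import torch
from comfy.comfy_types.node_typing import IO

class ExampleLoRAConsumerNode:
    """Hypothetical node showing the new socket types in use; not part of this PR."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": (IO.MODEL,),
                "lora": (IO.LORA_MODEL,),      # new type: a LoRA handed around as an object
                "loss_map": (IO.LOSS_MAP,),    # new type: losses produced by a training node
            }
        }

    RETURN_TYPES = (IO.MODEL,)
    FUNCTION = "apply"

    def apply(self, model, lora, loss_map):
        # apply `lora` to `model` here; omitted in this sketch
        return (model,)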
6 changes: 3 additions & 3 deletions comfy/ldm/modules/attention.py
@@ -753,7 +753,7 @@ def forward(self, x, context=None, transformer_options={}):
for p in patch:
n = p(n, extra_options)

x += n
x = n + x
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
@@ -793,12 +793,12 @@ def forward(self, x, context=None, transformer_options={}):
for p in patch:
n = p(n, extra_options)

x += n
x = n + x
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = x_skip + x

return x

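The only functional change in attention.py is replacing in-place accumulation (x += n) with out-of-place addition (x = n + x). A plausible reason, given the training focus of this PR, is that in-place updates can interfere with autograd when gradients must flow through these activations. The snippet below is a standalone illustration of that failure mode, not code from the PR.

import torch

x0 = torch.zeros(3, requires_grad=True)   # stands in for an activation autograd still needs
n = x0 * 2

try:
    x0 += n                               # in-place add on a tensor that requires grad
except RuntimeError as err:
    print("in-place add rejected:", err)

x1 = n + x0                               # out-of-place form used in the patch
x1.sum().backward()
print(x0.grad)                            # tensor([3., 3., 3.])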
19 changes: 11 additions & 8 deletions comfy/model_patcher.py
@@ -17,23 +17,26 @@
"""

from __future__ import annotations
from typing import Optional, Callable
import torch

import collections
import copy
import inspect
import logging
import uuid
import collections
import math
import uuid
from typing import Callable, Optional

import torch

import comfy.utils
import comfy.float
import comfy.model_management
import comfy.lora
import comfy.hooks
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP


def string_to_seed(data):
crc = 0xFFFFFFFF
23 changes: 22 additions & 1 deletion comfy/sd.py
@@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)


def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options={}):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.

Args:
sd (dict): State dictionary containing model weights and configuration
model_options (dict, optional): Additional options for model loading. Supports:
- dtype: Override model data type
- custom_operations: Custom model operations
- fp8_optimizations: Enable FP8 optimizations

Returns:
ModelPatcher: A wrapped model instance that handles device management and weight loading.
Returns None if the model configuration cannot be detected.

The function:
1. Detects and handles different model formats (regular, diffusers, mmdit)
2. Configures model dtype based on parameters and device capabilities
3. Handles weight conversion and device placement
4. Manages model optimization settings
5. Loads weights and returns a device-managed model instance
"""
dtype = model_options.get("dtype", None)

#Allow loading unets from checkpoint files
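For context, a hedged usage sketch of the loader documented above; the file path and dtype are placeholder assumptions, and the comfy package is assumed to be importable.

import torch
import comfy.utils
import comfy.sd

# Placeholder path: any diffusion-model state dict in diffusers or regular format.
state_dict = comfy.utils.load_torch_file("models/diffusion_models/example_unet.safetensors")

model_patcher = comfy.sd.load_diffusion_model_state_dict(
    state_dict,
    model_options={"dtype": torch.float16},  # optional override, per the docstring above
)
if model_patcher is None:
    raise RuntimeError("model configuration could not be detected from this state dict")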
8 changes: 7 additions & 1 deletion comfy/weight_adapter/__init__.py
@@ -1,4 +1,4 @@
from .base import WeightAdapterBase
from .base import WeightAdapterBase, WeightAdapterTrainBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
@@ -15,3 +15,9 @@
OFTAdapter,
BOFTAdapter,
]

__all__ = [
"WeightAdapterBase",
"WeightAdapterTrainBase",
"adapters"
] + [a.__name__ for a in adapters]
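The exported adapters list acts as a registry that callers (for example comfy/lora.py, per the commit "Utilize new weight adapter in lora.py") can iterate to find which adapter class recognizes a given key. The sketch below shows that dispatch pattern using only the load signature from base.py in this diff; the helper name and its inputs are hypothetical.

import torch
from comfy.weight_adapter import adapters

def detect_adapter(key: str, lora_sd: dict[str, torch.Tensor]):
    """Hypothetical helper: return the first adapter that can load `key` from `lora_sd`."""
    for adapter_cls in adapters:
        adapter = adapter_cls.load(key, lora_sd, alpha=1.0, dora_scale=None)
        if adapter is not None:
            return adapter
    return None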
35 changes: 33 additions & 2 deletions comfy/weight_adapter/base.py
@@ -12,12 +12,20 @@ class WeightAdapterBase:
weights: list[torch.Tensor]

@classmethod
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
raise NotImplementedError

def to_train(self) -> "WeightAdapterTrainBase":
raise NotImplementedError

@classmethod
def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
"""
weight: The original weight tensor to be modified.
*args: Additional arguments for configuration, such as rank, alpha etc.
"""
raise NotImplementedError

def calculate_weight(
self,
weight,
@@ -33,10 +41,22 @@ def calculate_weight(


class WeightAdapterTrainBase(nn.Module):
# We follow the scheme of PR #7032
def __init__(self):
super().__init__()

# [TODO] Collaborate with LoRA training PR #7032
def __call__(self, w):
"""
w: The original weight tensor to be modified.
"""
raise NotImplementedError

def passive_memory_usage(self):
raise NotImplementedError("passive_memory_usage is not implemented")

def move_to(self, device):
self.to(device)
return self.passive_memory_usage()


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
padded_tensor[new_slices] = tensor[orig_slices]

return padded_tensor


def tucker_weight_from_conv(up, down, mid):
up = up.reshape(up.size(0), up.size(1))
down = down.reshape(down.size(0), down.size(1))
return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)


def tucker_weight(wa, wb, t):
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
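A quick shape sanity check for tucker_weight_from_conv: given conv-shaped LoRA factors and a Tucker core that carries the spatial kernel, the reconstructed delta matches the conv weight it patches. The shapes below are arbitrary assumptions chosen for illustration, and comfy is assumed to be importable.

import torch
from comfy.weight_adapter.base import tucker_weight_from_conv

rank, out_ch, in_ch = 4, 8, 16
up = torch.randn(out_ch, rank, 1, 1)    # lora_up weight of a conv layer
down = torch.randn(rank, in_ch, 1, 1)   # lora_down weight of a conv layer
mid = torch.randn(rank, rank, 3, 3)     # Tucker core holding the 3x3 kernel

delta = tucker_weight_from_conv(up, down, mid)
print(delta.shape)                      # torch.Size([8, 16, 3, 3])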
66 changes: 65 additions & 1 deletion comfy/weight_adapter/lora.py
@@ -3,7 +3,56 @@

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
pad_tensor_to_shape,
tucker_weight_from_conv,
)


class LoraDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
mat1, mat2, alpha, mid, dora_scale, reshape = weights
out_dim, rank = mat1.shape[0], mat1.shape[1]
rank, in_dim = mat2.shape[0], mat2.shape[1]
if mid is not None:
convdim = mid.ndim - 2
layer = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d
)[convdim]
else:
layer = torch.nn.Linear
self.lora_up = layer(rank, out_dim, bias=False)
self.lora_down = layer(in_dim, rank, bias=False)
self.lora_up.weight.data.copy_(mat1)
self.lora_down.weight.data.copy_(mat2)
if mid is not None:
self.lora_mid = layer(mid, rank, bias=False)
self.lora_mid.weight.data.copy_(mid)
else:
self.lora_mid = None
self.rank = rank
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

def __call__(self, w):
org_dtype = w.dtype
if self.lora_mid is None:
diff = self.lora_up.weight @ self.lora_down.weight
else:
diff = tucker_weight_from_conv(
self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
)
scale = self.alpha / self.rank
weight = w + scale * diff.reshape(w.shape)
return weight.to(org_dtype)

def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())


class LoRAAdapter(WeightAdapterBase):
@@ -13,6 +62,21 @@ def __init__(self, loaded_keys, weights):
self.loaded_keys = loaded_keys
self.weights = weights

@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff(
(mat1, mat2, alpha, None, None, None)
)

def to_train(self):
return LoraDiff(self.weights)

@classmethod
def load(
cls,
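A small check of the training path defined above, assuming comfy is importable: create_train initializes lora_up with Kaiming noise and lora_down with zeros, so applying a freshly created LoraDiff leaves the weight unchanged, the usual zero-delta starting point for LoRA training. The weight shape, rank, and alpha are illustrative.

import torch
from comfy.weight_adapter.lora import LoRAAdapter

w = torch.randn(32, 64)                     # hypothetical linear weight to adapt

lora = LoRAAdapter.create_train(w, rank=8, alpha=8.0)
patched = lora(w)                           # w + (alpha / rank) * (lora_up @ lora_down)

print(torch.allclose(patched, w))           # True: lora_down is zero-initialized
print(int(lora.passive_memory_usage()))     # bytes held by the adapter's parameters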