Commit ca39552

Merge branch 'master' into asset-management
2 parents 4dd843d + c8d2117

29 files changed: +1932 −80 lines

comfy/ldm/wan/model.py

Lines changed: 4 additions & 1 deletion

@@ -1355,7 +1355,7 @@ def forward(self, x, context, transformer_options={}, **kwargs):

         x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)

-        x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
+        x = x.transpose(1, 2).reshape(b, -1, n * d)
         x = self.o(x)
         return x

@@ -1551,6 +1551,9 @@ def forward_orig(
         context_img_len = None

         if audio_embed is not None:
+            if reference_latent is not None:
+                zero_audio_pad = torch.zeros(audio_embed.shape[0], reference_latent.shape[-3], *audio_embed.shape[2:], device=audio_embed.device, dtype=audio_embed.dtype)
+                audio_embed = torch.cat([audio_embed, zero_audio_pad], dim=1)
             audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
         else:
             audio = None
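
The first hunk collapses a view/flatten chain into a single reshape. A minimal sketch of why that is the safer call on the attention output, assuming it can come back with non-contiguous strides (shapes below are illustrative, not the model's actual ones):

import torch

b, n, s, d = 2, 8, 16, 64          # batch, heads, sequence length, head dim (illustrative)
x = torch.randn(b, n, s, d)

y = x.transpose(1, 2)              # (b, s, n, d) view with non-contiguous strides
# y.view(b, -1, n * d)             # would raise RuntimeError: view() needs compatible strides
out = y.reshape(b, -1, n * d)      # reshape() copies only when the memory layout requires it
assert out.shape == (b, s, n * d)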

comfy/ldm/wan/model_animate.py

Lines changed: 548 additions & 0 deletions
Large diffs are not rendered by default.

comfy/model_base.py

Lines changed: 18 additions & 0 deletions

@@ -39,6 +39,7 @@
 import comfy.ldm.cosmos.predict2
 import comfy.ldm.lumina.model
 import comfy.ldm.wan.model
+import comfy.ldm.wan.model_animate
 import comfy.ldm.hunyuan3d.model
 import comfy.ldm.hidream.model
 import comfy.ldm.chroma.model

@@ -1253,6 +1254,23 @@ def extra_conds(self, **kwargs):

         return out

+class WAN22_Animate(WAN21):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
+        self.image_to_video = image_to_video
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+
+        face_video_pixels = kwargs.get("face_video_pixels", None)
+        if face_video_pixels is not None:
+            out['face_pixel_values'] = comfy.conds.CONDRegular(face_video_pixels)
+
+        pose_latents = kwargs.get("pose_video_latent", None)
+        if pose_latents is not None:
+            out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
+        return out
+
 class WAN22_S2V(WAN21):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
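
The new class reuses the same super(WAN21, self).__init__ trick already used by WAN22_S2V: the grandparent initializer runs with a different unet_model while WAN21's own __init__ is skipped. A minimal sketch of that Python pattern (class names below are illustrative, not the repo's):

class Base:
    def __init__(self, unet_model="default"):
        self.unet_model = unet_model

class Child(Base):
    def __init__(self):
        super().__init__(unet_model="child_net")   # Child-specific wiring

class GrandChild(Child):
    def __init__(self):
        # super(Child, self) starts the MRO lookup *after* Child,
        # so Base.__init__ runs and Child's unet_model choice is bypassed.
        super(Child, self).__init__(unet_model="grandchild_net")

assert GrandChild().unet_model == "grandchild_net"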

comfy/model_detection.py

Lines changed: 2 additions & 0 deletions

@@ -404,6 +404,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["model_type"] = "s2v"
         elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
             dit_config["model_type"] = "humo"
+        elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
+            dit_config["model_type"] = "animate"
         else:
             if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                 dit_config["model_type"] = "i2v"
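
Detection stays purely key-based: each variant is recognized by a parameter name that only its checkpoint contains. A hedged sketch of the idea in isolation (the probe keys come from this diff; the helper itself and its fallback are illustrative):

def detect_wan_variant(state_dict_keys, key_prefix=""):
    # Probe for parameters that exist only in a specific variant's checkpoint.
    if '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
        return "animate"
    if '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
        return "humo"
    if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
        return "i2v"
    return None  # fall through to the remaining checks in the real detector

# A checkpoint containing the face adapter weight is classified as "animate".
assert detect_wan_variant({"face_adapter.fuser_blocks.0.k_norm.weight"}) == "animate"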

comfy/model_management.py

Lines changed: 4 additions & 2 deletions

@@ -348,7 +348,7 @@ def amd_min_version(device=None, min_rdna_version=0):
        # if any((a in arch) for a in ["gfx1201"]):
        #     ENABLE_PYTORCH_ATTENTION = True
        if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
-            if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]):  # TODO: more arches
+            if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]):  # TODO: more arches
                SUPPORT_FP8_OPS = True

 except:

@@ -645,7 +645,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
            if loaded_model.model.is_clone(current_loaded_models[i].model):
                to_unload = [i] + to_unload
    for i in to_unload:
-        current_loaded_models.pop(i).model.detach(unpatch_all=False)
+        model_to_unload = current_loaded_models.pop(i)
+        model_to_unload.model.detach(unpatch_all=False)
+        model_to_unload.model_finalizer.detach()

    total_memory_required = {}
    for loaded_model in models_to_load:
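
The unload path now also detaches the entry's finalizer, so cleanup registered for a model that was unloaded deliberately cannot fire later. Assuming model_finalizer is a weakref.finalize handle (an assumption on my part), a minimal sketch of why .detach() matters:

import weakref

class Resource:
    pass

def cleanup(name):
    print("cleanup ran for", name)

r = Resource()
fin = weakref.finalize(r, cleanup, "model A")

# Detaching the finalizer cancels the registered callback: when the object is
# garbage collected afterwards, cleanup() no longer runs.
fin.detach()
del r   # prints nothing; without detach() the callback would fire here or at exit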

comfy/ops.py

Lines changed: 7 additions & 6 deletions

@@ -365,12 +365,13 @@ def reset_parameters(self):
         return None

     def forward_comfy_cast_weights(self, input):
-        try:
-            out = fp8_linear(self, input)
-            if out is not None:
-                return out
-        except Exception as e:
-            logging.info("Exception during fp8 op: {}".format(e))
+        if not self.training:
+            try:
+                out = fp8_linear(self, input)
+                if out is not None:
+                    return out
+            except Exception as e:
+                logging.info("Exception during fp8 op: {}".format(e))

         weight, bias = cast_bias_weight(self, input)
         return torch.nn.functional.linear(input, weight, bias)
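
The fp8 fast path is now attempted only in eval mode, with the cast-weights fallback kept for everything else. A minimal sketch of that guard pattern, not the repo's actual classes (fast_linear below is a stand-in for fp8_linear):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GuardedLinear(nn.Linear):
    def forward(self, x):
        # Only try the specialized kernel when not training; a quantized fp8
        # matmul generally does not give the gradients training needs.
        if not self.training:
            out = self.fast_linear(x)
            if out is not None:
                return out
        # Fallback: plain linear on the cast weights, always differentiable.
        return F.linear(x, self.weight, self.bias)

    def fast_linear(self, x):
        # Stand-in for an optimized path; None signals "not applicable".
        return None

layer = GuardedLinear(8, 4)
layer.train()                      # training mode -> fast path skipped entirely
y = layer(torch.randn(2, 8))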

comfy/supported_models.py

Lines changed: 16 additions & 3 deletions

@@ -995,7 +995,7 @@ class WAN21_T2V(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.Wan21

-    memory_usage_factor = 1.0
+    memory_usage_factor = 0.9

     supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

@@ -1004,7 +1004,7 @@ class WAN21_T2V(supported_models_base.BASE):

     def __init__(self, unet_config):
         super().__init__(unet_config)
-        self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
+        self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2222

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21(self, device=device)

@@ -1096,6 +1096,19 @@ def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN22_S2V(self, device=device)
         return out

+class WAN22_Animate(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "animate",
+    }
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN22_Animate(self, device=device)
+        return out
+
 class WAN22_T2V(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",

@@ -1361,6 +1374,6 @@ def get_model(self, state_dict, prefix="", device=None):
         out = model_base.HunyuanImage21Refiner(self, device=device)
         return out

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]

 models += [SVD_img2vid]
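
The memory heuristic now divides the transformer width by 2222 instead of 2000, lowering the estimated footprint of every WAN 2.1/2.2 variant that inherits from WAN21_T2V. A quick worked example for a hypothetical config with dim=5120 (the dim value is illustrative):

# Old heuristic vs. new one for an example transformer width.
dim = 5120                      # illustrative; read from the unet config in practice
old_factor = dim / 2000         # 2.56
new_factor = dim / 2222         # ~2.30
print(old_factor, round(new_factor, 2))   # 2.56 2.3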

comfy/text_encoders/llama.py

Lines changed: 10 additions & 6 deletions

@@ -400,21 +400,25 @@ def preprocess_embed(self, embed, device):

     def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
         grid = None
+        position_ids = None
+        offset = 0
         for e in embeds_info:
             if e.get("type") == "image":
                 grid = e.get("extra", None)
-                position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
                 start = e.get("index")
-                position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
+                if position_ids is None:
+                    position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
+                    position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
                 end = e.get("size") + start
                 len_max = int(grid.max()) // 2
                 start_next = len_max + start
-                position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
-                position_ids[0, start:end] = start
+                position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
+                position_ids[0, start:end] = start + offset
                 max_d = int(grid[0][1]) // 2
-                position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
+                position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
                 max_d = int(grid[0][2]) // 2
-                position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
+                position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
+                offset += len_max - (end - start)

         if grid is None:
             position_ids = None
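
The rewrite lets the loop handle more than one image embed: position_ids is allocated once, and a running offset accumulates the gap between the token slots an image occupies (end - start) and the position span it should advance (len_max), shifting every later segment. A small sketch of that bookkeeping in isolation (the token counts are made up):

# Each image spans `span` token slots but should advance the position counter
# by `len_max` (half the max grid dimension in the diff above).
images = [  # (span, len_max) pairs, illustrative values
    (16, 24),
    (16, 30),
]

offset = 0
for span, len_max in images:
    # Positions for this image start at base + offset (base elided here);
    # afterwards the correction grows by the gap between the two counts.
    offset += len_max - span

print(offset)  # 22: later tokens are shifted by the accumulated difference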

comfy/weight_adapter/loha.py

Lines changed: 4 additions & 4 deletions

@@ -130,12 +130,12 @@ def __init__(self, loaded_keys, weights):
     def create_train(cls, weight, rank=1, alpha=1.0):
         out_dim = weight.shape[0]
         in_dim = weight.shape[1:].numel()
-        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.normal_(mat1, 0.1)
         torch.nn.init.constant_(mat2, 0.0)
-        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.normal_(mat3, 0.1)
         torch.nn.init.normal_(mat4, 0.01)
         return LohaDiff(

comfy/weight_adapter/lokr.py

Lines changed: 2 additions & 2 deletions

@@ -89,8 +89,8 @@ def create_train(cls, weight, rank=1, alpha=1.0):
         in_dim = weight.shape[1:].numel()
         out1, out2 = factorization(out_dim, rank)
         in1, in2 = factorization(in_dim, rank)
-        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
         torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
         torch.nn.init.constant_(mat1, 0.0)
         return LokrDiff(
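
Both adapter changes (loha.py and lokr.py) create the trainable matrices in float32 regardless of the base weight's dtype, which keeps small gradient updates from being lost to fp16/bf16 rounding during training. A minimal sketch of the pattern, with the cast back to the weight dtype only when the low-rank delta is applied (the apply step here is illustrative, not the repo's code):

import torch

base_weight = torch.randn(64, 32, dtype=torch.float16)   # frozen, low precision
rank = 4

# Trainable factors live in fp32 so tiny optimizer steps survive accumulation.
mat1 = torch.zeros(base_weight.shape[0], rank, dtype=torch.float32, requires_grad=True)
mat2 = torch.zeros(rank, base_weight.shape[1], dtype=torch.float32, requires_grad=True)

def apply_delta(weight, a, b):
    # Compute the low-rank delta in fp32, cast once when merging into the weight.
    return weight + (a @ b).to(weight.dtype)

patched = apply_delta(base_weight, mat1, mat2)
print(patched.dtype)  # torch.float16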
