Skip to content

Commit 2ee7879

Browse files
Fix lowvram issues with hunyuan3d 2.1 (#9735)
1 parent 3493b9c commit 2ee7879

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

comfy/ldm/hunyuan3dv2_1/hunyuandit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.nn as nn
44
import torch.nn.functional as F
55
from comfy.ldm.modules.attention import optimized_attention
6+
import comfy.model_management
67

78
class GELU(nn.Module):
89

@@ -88,7 +89,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
8889
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
8990

9091
# get logits and pass it to softmax
91-
logits = F.linear(hidden_states, self.weight, bias = None)
92+
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
9293
scores = logits.softmax(dim = -1)
9394

9495
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
@@ -255,7 +256,7 @@ def forward(self, timesteps, condition):
255256
cond_embed = self.cond_proj(condition)
256257
timestep_embed = timestep_embed + cond_embed
257258

258-
time_conditioned = self.mlp(timestep_embed.to(self.mlp[0].weight.device))
259+
time_conditioned = self.mlp(timestep_embed)
259260

260261
# for broadcasting with image tokens
261262
return time_conditioned.unsqueeze(1)

0 commit comments

Comments
 (0)