
Commit 211fcba

Hack to make zimage work in fp16.
1 parent 33d6aec

2 files changed: 13 additions & 7 deletions


comfy/ldm/lumina/model.py

Lines changed: 11 additions & 7 deletions
@@ -22,6 +22,10 @@ def modulate(x, scale):
 # Core NextDiT Model #
 #############################################################################
 
+def clamp_fp16(x):
+    if x.dtype == torch.float16:
+        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+    return x
 
 class JointAttention(nn.Module):
     """Multi-head attention module."""
@@ -169,7 +173,7 @@ def __init__(
 
     # @torch.compile
    def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
+        return clamp_fp16(F.silu(x1) * x3)
 
    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@@ -273,27 +277,27 @@ def forward(
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
 
             x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
-                self.feed_forward(
+                clamp_fp16(self.feed_forward(
                     modulate(self.ffn_norm1(x), scale_mlp),
-                )
+                ))
             )
         else:
             assert adaln_input is None
             x = x + self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     self.attention_norm1(x),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + self.ffn_norm2(
                 self.feed_forward(
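The helper leaves bf16 and fp32 tensors untouched and only sanitizes float16, whose largest finite value is 65504: anything that overflowed to ±inf is pulled back to the fp16 limits and NaNs are zeroed before they can propagate through later layers. A minimal standalone demonstration (the tensor values are illustrative, not taken from the model):

import torch

def clamp_fp16(x):
    # Same helper as in the diff above: only fp16 needs the clamp.
    if x.dtype == torch.float16:
        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
    return x

x = torch.tensor([60000.0, 2.0], dtype=torch.float16)
y = x * x                 # 60000**2 exceeds 65504 -> tensor([inf, 4.])
print(clamp_fp16(y))      # tensor([65504., 4.], dtype=torch.float16)
print(clamp_fp16(y - y))  # inf - inf is nan -> zeroed: tensor([0., 0.])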

comfy/supported_models.py

Lines changed: 2 additions & 0 deletions
@@ -1027,6 +1027,8 @@ class ZImage(Lumina2):
 
     memory_usage_factor = 1.7
 
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
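Listing torch.float16 in supported_inference_dtypes is what allows the loader to pick fp16 for ZImage at all; the clamping added above is what keeps that choice numerically stable. A rough sketch of how an ordered preference list like this is typically consumed (pick_dtype and device_supports are hypothetical stand-ins; the real selection logic lives in comfy/model_management.py and also weighs available VRAM and command-line overrides):

import torch

def pick_dtype(supported, device_supports):
    # Hypothetical: walk the list in preference order and keep the
    # first dtype the device can run; fp32 is the always-safe fallback.
    for dt in supported:
        if device_supports(dt):
            return dt
    return torch.float32

supported = [torch.bfloat16, torch.float16, torch.float32]
# On a GPU without bf16 support, fp16 is chosen -- the path this
# commit makes numerically safe.
print(pick_dtype(supported, lambda dt: dt != torch.bfloat16))  # torch.float16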
