@@ -22,6 +22,10 @@ def modulate(x, scale):
 #                            Core NextDiT Model                            #
 #############################################################################
 
+def clamp_fp16(x):
+    if x.dtype == torch.float16:
+        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+    return x
 
 class JointAttention(nn.Module):
     """Multi-head attention module."""
@@ -169,7 +173,7 @@ def __init__(
 
     # @torch.compile
     def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
+        return clamp_fp16(F.silu(x1) * x3)
 
     def forward(self, x):
         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@@ -273,27 +277,27 @@ def forward(
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
 
             x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
-                self.feed_forward(
+                clamp_fp16(self.feed_forward(
                     modulate(self.ffn_norm1(x), scale_mlp),
-                )
+                ))
             )
         else:
             assert adaln_input is None
             x = x + self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     self.attention_norm1(x),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + self.ffn_norm2(
                 self.feed_forward(
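For context, a minimal standalone sketch (not part of the patch) of what the added clamp_fp16 helper does: float16 cannot represent magnitudes above 65504, so activations that overflow to inf (or become NaN) are mapped back to finite values before they propagate through later layers. The example tensor values below are illustrative only.

import torch

def clamp_fp16(x):
    # Same helper as added in the diff above: only touches float16 tensors.
    if x.dtype == torch.float16:
        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
    return x

x = torch.tensor([70000.0, -70000.0, 1.0]).to(torch.float16)
print(x)              # tensor([inf, -inf, 1.], dtype=torch.float16) -- overflow on cast
print(clamp_fp16(x))  # tensor([ 65504., -65504., 1.], dtype=torch.float16)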