@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from contextlib import contextmanager
+from diffsynth_engine.utils.platform import DTYPE_FP8


 def enable_fp8_autocast(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False):
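The new DTYPE_FP8 constant is imported from diffsynth_engine.utils.platform, a module that is not part of this diff. A minimal sketch of how such a platform-dependent constant might be selected, assuming ROCm builds map to the e4m3fnuz variant while CUDA/CPU builds keep the OCP e4m3fn format:

import torch

# Hypothetical sketch only; the real diffsynth_engine.utils.platform is not shown.
# ROCm (CDNA3) hardware implements the "fnuz" FP8 variant (no negative zero,
# no inf, max 240); NVIDIA hardware implements OCP e4m3fn (max 448).
if torch.version.hip is not None:
    DTYPE_FP8 = torch.float8_e4m3fnuz
else:
    DTYPE_FP8 = torch.float8_e4m3fn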
@@ -51,7 +52,7 @@ def enable_fp8_linear(module: nn.Module):
 def _enable_fp8_linear(module: nn.Module):
     if isinstance(module, nn.Linear) and torch.is_floating_point(module.weight.data):
         # avoid conversion for int weights like GGUF
-        module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
+        module.weight.data = module.weight.data.to(DTYPE_FP8)
     for submodule in module.children():
         _enable_fp8_linear(submodule)

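The torch.is_floating_point guard is what skips integer weights such as GGUF-quantized checkpoints, which must not be cast to FP8. A quick standalone illustration (the int8 assignment is contrived for demonstration):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
lin.weight.data = torch.zeros(4, 4, dtype=torch.int8)  # stand-in for a GGUF int weight
print(torch.is_floating_point(lin.weight.data))        # False -> conversion skipped
print(torch.is_floating_point(nn.Linear(4, 4).weight.data))  # True -> cast to DTYPE_FP8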
@@ -71,16 +72,24 @@ def fp8_linear(
 ) -> torch.Tensor:
     device = input.device
     origin_dtype = input.dtype
-    input = input.to(torch.float8_e4m3fn)
-    weight = weight.to(torch.float8_e4m3fn)
+    scale_a = 1.0
+    # The maximum representable value of float8_e4m3fnuz (240) is roughly half
+    # that of float8_e4m3fn (448). To avoid overflow and ensure numerical
+    # compatibility during the FP8 matmul, we scale the input down by 2.0 in
+    # advance; the factor is compensated later when scaling the final result.
+    if DTYPE_FP8 == torch.float8_e4m3fnuz:
+        scale_a = 2.0
+        input = input / scale_a
+    input = input.to(DTYPE_FP8)
+    weight = weight.to(DTYPE_FP8)

     if len(input.shape) > 2:
         origin_shape = input.shape
         input = input.reshape(-1, origin_shape[-1])
         result = torch._scaled_mm(
             input,
             weight.T,
-            scale_a=torch.tensor(1.0).to(device=device),
+            scale_a=torch.tensor(scale_a).to(device=device),
             scale_b=torch.tensor(1.0).to(device=device),
             bias=bias,
             out_dtype=origin_dtype,
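Because the scales passed to torch._scaled_mm multiply the accumulated product, passing scale_a = 2.0 exactly undoes the earlier division by 2.0, as the comment in the hunk above describes. A standalone numeric check of that round trip (illustration only; it does not call _scaled_mm and needs a PyTorch build with float8_e4m3fnuz support):

import torch

# 320.0 fits in float8_e4m3fn (max 448) but overflows float8_e4m3fnuz (max 240).
x = torch.tensor([320.0])
halved = (x / 2.0).to(torch.float8_e4m3fnuz)  # 160.0 is exactly representable
restored = 2.0 * halved.to(torch.float32)     # the scale_a compensation
print(restored)                               # tensor([320.])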
@@ -91,7 +100,7 @@ def fp8_linear(
         result = torch._scaled_mm(
             input,
             weight.T,
-            scale_a=torch.tensor(1.0).to(device=device),
+            scale_a=torch.tensor(scale_a).to(device=device),
             scale_b=torch.tensor(1.0).to(device=device),
             bias=bias,
             out_dtype=origin_dtype,
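For context, a hypothetical end-to-end usage sketch. Only the function names and signatures come from the patched file; the import path, the model, and the assumption that enable_fp8_autocast with use_fp8_linear=True converts Linear weights and routes matmuls through fp8_linear are illustrative:

import torch
import torch.nn as nn

# Hypothetical import path; enable_fp8_autocast is defined in the file patched above.
from diffsynth_engine.models.utils.fp8_linear import enable_fp8_autocast

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)).cuda()
enable_fp8_autocast(model, compute_dtype=torch.bfloat16, use_fp8_linear=True)
# Linear weights are now stored as DTYPE_FP8; forward matmuls go through
# fp8_linear, whose torch._scaled_mm call needs FP8-capable hardware
# (e.g. NVIDIA Hopper/Ada or AMD CDNA3).
y = model(torch.randn(2, 64, device="cuda", dtype=torch.bfloat16))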