@@ -1797,7 +1797,7 @@ class LinearFuncFPxFwdBwd(torch.autograd.Function):
17971797 """
17981798
17991799 @staticmethod
1800- def forward (ctx , x , weight , bias = None , trun_bits = 0 , chunk_size = 16 ):
1800+ def forward (ctx , x , weight , bias = None , trun_bits = 0 , chunk_size = 16 , fp8_dyn = False ):
18011801 assert x .dtype in [torch .float , torch .bfloat16 , torch .float16 ]
18021802 # input can be 2D or 3D, need to reshape before tl_matmul
18031803 org_dtype = x .dtype
@@ -1813,6 +1813,20 @@ def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16):
         ctx.save_for_backward(x, weight)  # x, W are saved in their original dtype
         ctx.trun_bits = trun_bits
         ctx.chunk_size = chunk_size
+        ctx.fp8_dyn = fp8_dyn
+
+        if fp8_dyn:
+            # use Q/dQ simulation for now, i.e. still compute in fp16/bf16
+            # if "per_token" is chosen for the input, use per-channel for W
+            # (W is saved as [out, in], reduce the inCh dim => reduce_dim=1)
+            ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max
+            ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max
+            reduce_dim = None if fp8_dyn == "per_tensor" else 1
+            x_scale = x.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max  # keepdim so scales broadcast
+            w_scale = weight.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max
+
+            x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale
+            weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale

         # triton kernel assumes 2D inputs and cast the return to input.dtype
         output = tl_matmul(
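The block above fake-quantizes (Q/dQ) the activations and weights through FP8 E4M3 while the matmul itself still runs in bf16/fp16. A minimal standalone sketch of that step follows; the helper name fp8_fake_quant and the toy shapes are illustrative, not part of this patch:

import torch

def fp8_fake_quant(t, fp8_dtype=torch.float8_e4m3fn, reduce_dim=None):
    # reduce_dim=None -> a single scale for the whole tensor ("per_tensor");
    # reduce_dim=-1   -> one scale per token / per channel (reduce over the last dim).
    fp8_max = torch.finfo(fp8_dtype).max
    if reduce_dim is None:
        scale = t.abs().max() / fp8_max
    else:
        scale = t.abs().amax(dim=reduce_dim, keepdim=True) / fp8_max
    # round-trip through FP8, then rescale: values are FP8-rounded but stay in t.dtype
    return (t / scale).to(fp8_dtype).to(t.dtype) * scale

x = torch.randn(4, 64, dtype=torch.bfloat16)   # [tokens, in]
w = torch.randn(32, 64, dtype=torch.bfloat16)  # [out, in]
x_q = fp8_fake_quant(x, reduce_dim=-1)         # per-token activation scales
w_q = fp8_fake_quant(w, reduce_dim=-1)         # per-channel weight scales
y = torch.nn.functional.linear(x_q, w_q)       # compute still happens in bf16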
@@ -1840,6 +1854,18 @@ def backward(ctx, grad_output):
         target_shape_grad_input = grad_output.shape[:-1] + (in_dim,)
         grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_input)

+        if ctx.fp8_dyn:
+            reduce_dim = None if ctx.fp8_dyn == "per_tensor" else 1
+            x_scale = x.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e5m2_max  # keepdim so scales broadcast
+            w_scale = weight.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e5m2_max
+            grad_out_scale = grad_output_2D.abs().amax(dim=None) / ctx.fp8_e5m2_max  # always per-tensor
+
+            x = (x / x_scale).to(torch.float8_e5m2).to(dtype_input) * x_scale
+            weight = (weight / w_scale).to(torch.float8_e5m2).to(weight.dtype) * w_scale
+            grad_output_2D = (grad_output_2D / grad_out_scale).to(
+                torch.float8_e5m2
+            ).to(grad_output.dtype) * grad_out_scale
+
         # Compute grad_weight, shape = [out, in]
         # NOTE: this triton kernel requires A matrix to be contiguous
         grad_weight = tl_matmul(
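The backward hunk applies the same Q/dQ step, but through torch.float8_e5m2, whose wider exponent range is the usual choice for gradients; grad_output always gets a single per-tensor scale. Continuing the illustrative sketch above (same fp8_fake_quant, x and w; g is a stand-in for grad_output, not a name from the patch), the quantized operands feed the two matmuls that the tl_matmul calls in this function perform:

g = torch.randn(4, 32, dtype=torch.bfloat16)               # [tokens, out]
x_b = fp8_fake_quant(x, torch.float8_e5m2, reduce_dim=-1)  # re-quantize saved x for backward
w_b = fp8_fake_quant(w, torch.float8_e5m2, reduce_dim=-1)  # re-quantize saved W for backward
g_q = fp8_fake_quant(g, torch.float8_e5m2)                 # per-tensor, as in the patch

grad_weight = g_q.t() @ x_b   # [out, in]
grad_input = g_q @ w_b        # [tokens, in]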
@@ -1865,7 +1891,7 @@ def backward(ctx, grad_output):
         else:
             grad_bias = grad_output_2D.sum(0).to(ctx.bias_dtype)

-        return grad_input, grad_weight, grad_bias, None
+        return grad_input, grad_weight, grad_bias, None, None, None


 class LinearFPxAcc(torch.nn.Linear):
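One note on the changed return value: torch.autograd.Function.backward must hand back exactly one value per argument passed to apply, with None for non-differentiable arguments, so adding chunk_size and fp8_dyn to the forward call forces the two extra Nones above. A toy illustration (not from the patch):

import torch

class ScaleByGamma(torch.autograd.Function):
    # y = x * gamma, where gamma is a plain Python float

    @staticmethod
    def forward(ctx, x, gamma):
        ctx.gamma = gamma
        return x * gamma

    @staticmethod
    def backward(ctx, grad_out):
        # one return per forward argument: a gradient for x, None for gamma
        return grad_out * ctx.gamma, None

inp = torch.randn(3, requires_grad=True)
ScaleByGamma.apply(inp, 2.5).sum().backward()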
@@ -1906,20 +1932,23 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs):

         lin24acc.weight = nnlin.weight
         lin24acc.trun_bits = trun_bits
+        lin24acc.chunk_size = kwargs.get("chunk_size", 16)
+        lin24acc.fp8_dyn = kwargs.get("dynamic_fp8", False)  # False, "per_tensor", or "per_token"

         if nnlin.bias is not None:
             lin24acc.bias = nnlin.bias
         return lin24acc.to(target_device)

     def forward(self, inputs):
         # This Linear Class will cast to BF16 before matmul and return FP32
-        return LinearFuncFPxFwdBwd.apply(inputs, self.weight, self.bias, self.trun_bits)
+        return LinearFuncFPxFwdBwd.apply(inputs, self.weight, self.bias, self.trun_bits,
+                                         self.chunk_size, self.fp8_dyn)

     def extra_repr(self) -> str:
         """
         Returns an alternative string representation of the object.
         """
         return (
             f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
-            f"trun_bits={self.trun_bits}"
+            f"trun_bits={self.trun_bits}, fp8_dyn={self.fp8_dyn}, chunk_size={self.chunk_size}"
         )
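Assuming the patched module is importable and a Triton-capable GPU is available (tl_matmul is a Triton kernel), exercising the new flags end to end would look roughly like this; the shapes and variable names are illustrative:

import torch
# LinearFPxAcc comes from the patched module above; the exact import path
# depends on where this file lives in the repo, so it is omitted here.

lin = torch.nn.Linear(64, 32, bias=True, dtype=torch.bfloat16, device="cuda")
lin_fp8 = LinearFPxAcc.from_nn(lin, trun_bits=0, dynamic_fp8="per_token", chunk_size=16)

x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = lin_fp8(x)          # forward: x and W fake-quantized to e4m3 before the matmul
y.sum().backward()      # backward: x, W and grad_output fake-quantized to e5m2
print(lin_fp8)          # extra_repr now reports trun_bits, fp8_dyn and chunk_size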