@@ -1926,14 +1926,16 @@ def forward(
         ctx.chunk_size = chunk_size
         ctx.fp8_dyn = fp8_dyn
         ctx.clamp_acc_to_dl16 = clamp_acc_to_dl16
+        ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max
+        ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max
         ctx.dl8_min = 0.0087890625

+        x_scale = torch.tensor(1.0, device=x.device, dtype=org_dtype)
+        w_scale = x_scale.clone()
         if fp8_dyn:
             # use Q/dQ simulation for now, meaning still compute in fp16/bf16
             # if choosing per_token for input, use per_channel for W
             # (W saved as [out, in], reduce inCh-dim, => reduce_dim=1)
-            ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max
-            ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max
             reduce_dim = None if fp8_dyn == "per_tensor" else 1
             x_scale = (
                 x.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max
@@ -1942,22 +1944,30 @@ def forward(
                 weight.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max
             ).clamp(min=1e-5)

-            x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale
-            weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale
+            x = (x / x_scale).to(torch.float8_e4m3fn).to(torch.float32)
+            weight = (weight / w_scale).to(torch.float8_e4m3fn).to(torch.float32)
             if clamp_acc_to_dl16:
-                # NOTE For DL8@DL8 acc in DL16, as DL8 doesn't support subnorm numbers like PyTorch
-                # (whose real min for e4m3fn is 2^-9), need to flush subnorm numbers to 0
-                x.masked_fill_(x < ctx.dl8_min, 0)
-                weight.masked_fill_(weight < ctx.dl8_min, 0)
+                # at this point, x and W are clamped to PT's FP8 range (2^-9 to 448). But since DL8
+                # doesn't support subnorms like PyTorch, need to flush subnorms to 0 BEFORE descaling
+                x.masked_fill_(x.abs() < ctx.dl8_min, 0)
+                weight.masked_fill_(weight.abs() < ctx.dl8_min, 0)

         # triton kernel assumes 2D inputs and casts the return to input.dtype
-        output = tl_matmul(
-            x,
-            weight.t().to(org_dtype),
-            chunk_trun_bits=trun_bits,
-            chunk_size=chunk_size,
-            clamp_acc_to_dl16=clamp_acc_to_dl16,
-        ).reshape(target_shape_output)
+        output = (
+            (
+                tl_matmul(
+                    x,
+                    weight.t(),
+                    chunk_trun_bits=trun_bits,
+                    chunk_size=chunk_size,
+                    clamp_acc_to_dl16=clamp_acc_to_dl16,
+                )
+                * x_scale
+                * w_scale.t()
+            )
+            .to(org_dtype)
+            .reshape(target_shape_output)
+        )

         if bias is not None:
             output = output + bias.to(org_dtype)
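
For orientation, here is a minimal self-contained sketch of the per-tensor Q/dQ pattern the new forward path follows. It is an illustrative assumption, not this repository's code: fake_fp8_linear is a made-up name, a plain @ matmul stands in for tl_matmul, and dl8_min simply reuses the constant from the diff. The point it shows is the change above: quantize to float8_e4m3fn, flush subnormals the DL8 format cannot represent, run the matmul on the still-scaled values, and apply both scales only after the accumulation.

import torch

def fake_fp8_linear(x, weight, dl8_min=0.0087890625, flush_subnorm=True):
    # hypothetical helper, per-tensor scales only; mirrors the Q/dQ steps above
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    x_scale = (x.abs().amax() / e4m3_max).clamp(min=1e-5)
    w_scale = (weight.abs().amax() / e4m3_max).clamp(min=1e-5)

    # quantize to FP8 and cast back to float32, but do NOT descale yet
    xq = (x / x_scale).to(torch.float8_e4m3fn).to(torch.float32)
    wq = (weight / w_scale).to(torch.float8_e4m3fn).to(torch.float32)

    if flush_subnorm:
        # flush values below the (assumed) DL8 minimum normal, as in the diff
        xq.masked_fill_(xq.abs() < dl8_min, 0)
        wq.masked_fill_(wq.abs() < dl8_min, 0)

    # accumulate on the quantized values, then apply both scales afterwards
    return (xq @ wq.t()) * x_scale * w_scale

x = torch.randn(4, 16)
w = torch.randn(8, 16)
print(fake_fp8_linear(x, w).shape)  # torch.Size([4, 8])
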
@@ -1977,44 +1987,54 @@ def backward(ctx, grad_output):
         target_shape_grad_input = grad_output.shape[:-1] + (in_dim,)
         grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_input)

+        x_scale = torch.tensor(1.0, device=x.device, dtype=dtype_input)
+        w_scale = x_scale.clone()
         if ctx.fp8_dyn:
             reduce_dim = None if ctx.fp8_dyn == "per_tensor" else 1
             x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max
             w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max
             # always assume perT in this case
             grad_out_scale = grad_output_2D.abs().amax(dim=None) / ctx.fp8_e5m2_max

-            x = (x / x_scale).to(torch.float8_e5m2).to(dtype_input) * x_scale
-            weight = (weight / w_scale).to(torch.float8_e5m2).to(weight.dtype) * w_scale
-            grad_output_2D = (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to(
-                grad_output.dtype
-            ) * grad_out_scale
+            x = (x / x_scale).to(torch.float8_e5m2).to(torch.float)
+            weight = (weight / w_scale).to(torch.float8_e5m2).to(torch.float)
+            grad_output_2D = (
+                (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to(torch.float)
+            )
             if ctx.clamp_acc_to_dl16:
                 # flush subnorm numbers to 0 as DL8 doesn't support them
-                x.masked_fill_(x < ctx.dl8_min, 0)
-                weight.masked_fill_(weight < ctx.dl8_min, 0)
-                grad_output_2D.masked_fill_(grad_output_2D < ctx.dl8_min, 0)
+                x.masked_fill_(x.abs() < ctx.dl8_min, 0)
+                weight.masked_fill_(weight.abs() < ctx.dl8_min, 0)
+                grad_output_2D.masked_fill_(grad_output_2D.abs() < ctx.dl8_min, 0)

         # Compute grad_weight, shape = [out, in]
         # NOTE: this triton kernel requires A matrix to be contiguous
-        grad_weight = tl_matmul(
-            grad_output_2D.transpose(0, 1).contiguous(),
-            x,
-            chunk_trun_bits=trun_bits,
-            chunk_size=chunk_size,
-            clamp_acc_to_dl16=ctx.clamp_acc_to_dl16,
-        ).to(weight.dtype)
-        # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D
-        grad_input = (
+        grad_weight = (
             tl_matmul(
-                grad_output_2D,
-                weight.to(dtype_input),
+                grad_output_2D.transpose(0, 1).contiguous(),
+                x,
                 chunk_trun_bits=trun_bits,
                 chunk_size=chunk_size,
                 clamp_acc_to_dl16=ctx.clamp_acc_to_dl16,
             )
-            .reshape(target_shape_grad_input)
+            * grad_out_scale.t()
+            * x_scale
+        ).to(weight.dtype)
+        # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D
+        grad_input = (
+            (
+                tl_matmul(
+                    grad_output_2D,
+                    weight,
+                    chunk_trun_bits=trun_bits,
+                    chunk_size=chunk_size,
+                    clamp_acc_to_dl16=ctx.clamp_acc_to_dl16,
+                )
+                * grad_out_scale
+                * w_scale
+            )
             .to(dtype_input)
+            .reshape(target_shape_grad_input)
         )

         if not ctx.has_bias:
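
A matching sketch for the backward math under the same assumptions (per-tensor float8_e5m2 Q/dQ, plain matmuls standing in for tl_matmul, hypothetical function name): for y = x @ W.t(), grad_weight = grad_output.t() @ x and grad_input = grad_output @ W, with the scales again applied only after each accumulation, as in the updated backward above.

import torch

def fake_fp8_linear_backward(grad_output, x, weight):
    # hypothetical helper: per-tensor float8_e5m2 Q/dQ for all three operands
    e5m2_max = torch.finfo(torch.float8_e5m2).max  # 57344.0
    g_scale = (grad_output.abs().amax() / e5m2_max).clamp(min=1e-12)
    x_scale = (x.abs().amax() / e5m2_max).clamp(min=1e-12)
    w_scale = (weight.abs().amax() / e5m2_max).clamp(min=1e-12)

    gq = (grad_output / g_scale).to(torch.float8_e5m2).to(torch.float32)
    xq = (x / x_scale).to(torch.float8_e5m2).to(torch.float32)
    wq = (weight / w_scale).to(torch.float8_e5m2).to(torch.float32)

    # grad_weight is [out, in]; grad_input matches x; descale after each matmul
    grad_weight = (gq.t() @ xq) * g_scale * x_scale
    grad_input = (gq @ wq) * g_scale * w_scale
    return grad_input, grad_weight
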