 from ..triton_kernels.layernorm import te_layernorm_bwd_triton
 from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton

-from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache
-
 __all__ = ["LayerNormMLP"]


@@ -347,8 +345,8 @@ def forward(
 # which handles weight caching etc.
 # FP8 cast to workspace buffer
 update_workspace = is_first_microbatch is None or is_first_microbatch
-fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
-fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
+fc1_weight_quantizer.set_usage(rowwise=True, columnwise=keep_fp8_weight_transpose_cache)
+fc2_weight_quantizer.set_usage(rowwise=True, columnwise=keep_fp8_weight_transpose_cache)
 fc1_weight_final = module.get_weight_workspace(
     tensor=fc1_weight,
     quantizer=fc1_weight_quantizer,
@@ -357,7 +355,6 @@ def forward(
     skip_update_flag=skip_fp8_weight_update,
     fsdp_group=fsdp_group,
     workspace_dtype=activation_dtype,
-    create_transpose_cache=keep_fp8_weight_transpose_cache,
 )
 fc2_weight_final = module.get_weight_workspace(
     tensor=fc2_weight,
@@ -367,7 +364,6 @@ def forward(
     skip_update_flag=skip_fp8_weight_update,
     fsdp_group=fsdp_group,
     workspace_dtype=activation_dtype,
-    create_transpose_cache=keep_fp8_weight_transpose_cache,
 )
 fc1_weight_final.update_usage(rowwise_usage=True)
 fc2_weight_final.update_usage(rowwise_usage=True)
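
Note: the per-call `create_transpose_cache` argument is gone; whether an FP8 weight carries a columnwise (transposed) copy is now driven entirely by the quantizer usage flags set above, and that copy can be rebuilt or dropped later via `update_usage`. A minimal, self-contained sketch of the idea, using a hypothetical `MockQuantizedWeight` rather than Transformer Engine's real classes:

```python
# Illustration only -- a toy stand-in, not Transformer Engine's actual FP8 tensor.
import torch

class MockQuantizedWeight:
    """Weight with an optional columnwise (transpose) cache, gated by a usage flag."""

    def __init__(self, data: torch.Tensor, columnwise: bool):
        self._data = data
        # The transpose cache exists only if columnwise usage was requested up front.
        self._transpose = data.t().contiguous() if columnwise else None

    def update_usage(self, rowwise_usage=None, columnwise_usage=None):
        if columnwise_usage and self._transpose is None:
            self._transpose = self._data.t().contiguous()  # rebuild on demand
        elif columnwise_usage is False:
            self._transpose = None  # free the cache

w = MockQuantizedWeight(torch.randn(16, 32), columnwise=False)
assert w._transpose is None             # mirrors the forward-pass assertions below
w.update_usage(columnwise_usage=True)   # backward dgrad materializes the transpose
w.update_usage(columnwise_usage=False)  # ...and drops it again when caching is off
```
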
@@ -412,6 +408,10 @@ def forward(
 gemm_gelu_fusion = False
 if debug:
     gemm_gelu_fusion = False
+
+if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache:
+    assert fc1_weight_final._transpose is None or fc1_weight_final._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled."
+
 fc1_outputs = general_gemm(
     fc1_weight_final,
     ln_out_total,
@@ -482,6 +482,9 @@ def forward(
 # ------------------------------------------------------
 # FC2 GEMM
 # ------------------------------------------------------
+if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache:
+    assert fc2_weight_final._transpose is None or fc2_weight_final._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled."
+
 gemm_out, *_, reduce_scatter_out = general_gemm(
     fc2_weight_final,
     act_out,
@@ -817,12 +820,9 @@ def backward(
 if isinstance(grad_output, QuantizedTensorBase):
     grad_output.update_usage(rowwise_usage=True)
 if ctx.fc2_weight_quantizer is not None and isinstance(
-    ctx.fc2_weight, QuantizedTensorBase
+    fc2_weight, QuantizedTensorBase
 ):
-    ctx.fc2_weight.update_usage(columnwise_usage=True)
-
-if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-    create_fp8_weight_transpose_cache(fc2_weight)
+    fc2_weight.update_usage(columnwise_usage=True)

 # Perform GEMM
 gemm_output, *_ = general_gemm(
@@ -853,7 +853,7 @@ def backward(
 fc2_dgrad = gemm_output

 if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-    clear_fp8_weight_transpose_cache(fc2_weight)
+    fc2_weight.update_usage(columnwise_usage=False)

 # --------------------------------------------------
 # Finished FC2 DGRAD...
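
With the cache disabled, each weight's columnwise FP8 copy now lives only for the duration of its dgrad GEMM instead of persisting from the forward pass. A rough, illustrative estimate of what that saves per layer (sizes are made up for the example):

```python
# Back-of-the-envelope: FP8 transpose copies for the two MLP weights.
hidden_size, ffn_hidden_size = 4096, 16384     # example sizes only
bytes_per_fp8 = 1
fc1_transpose = ffn_hidden_size * hidden_size * bytes_per_fp8
fc2_transpose = hidden_size * ffn_hidden_size * bytes_per_fp8
print((fc1_transpose + fc2_transpose) / 2**20, "MiB")  # 128.0 MiB per layer
```
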
@@ -1041,18 +1041,16 @@ def fc2_wgrad_gemm(
 ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
 ub_type_fc1_wgrad = tex.CommOverlapType.RS

-if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-    create_fp8_weight_transpose_cache(fc1_weight)

 # --------------------------------------------------
 # FC1 DGRAD
 # --------------------------------------------------

 # Make sure required data is available
 if ctx.fc1_weight_quantizer is not None and isinstance(
-    ctx.fc1_weight_quantizer, QuantizedTensorBase
+    fc1_weight, QuantizedTensorBase
 ):
-    ctx.fc1_weight.update_usage(columnwise_usage=True)
+    fc1_weight.update_usage(columnwise_usage=True)

 # Output buffers for Userbuffers reduce-scatter
 gemm_out = None
@@ -1082,7 +1080,7 @@ def fc2_wgrad_gemm(
 )

 if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-    clear_fp8_weight_transpose_cache(fc1_weight)
+    fc1_weight.update_usage(columnwise_usage=False)

 # Prepare grad input tensor
 # Note: Perform tensor-parallel communication
@@ -1552,7 +1550,7 @@ def __init__(
 self.set_parallel_mode = set_parallel_mode
 self.zero_centered_gamma = zero_centered_gamma
 self.symmetric_ar_type = symmetric_ar_type
-self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
+self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True

 # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
 self.gemm_gelu_fusion = (
@@ -1918,6 +1916,8 @@ def _get_quantizers(self, fp8_output):
 fc1_input_quantizer.internal = True
 fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
 fc1_weight_quantizer.internal = True
+if IS_HIP_EXTENSION:
+    fc1_weight_quantizer.set_usage(columnwise=self.keep_fp8_weight_transpose_cache)
 fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
 fc2_input_quantizer.set_usage(
     rowwise=True,
@@ -1926,6 +1926,8 @@ def _get_quantizers(self, fp8_output):
 fc1_input_quantizer.internal = True
 fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
 fc2_weight_quantizer.internal = True
+if IS_HIP_EXTENSION:
+    fc2_weight_quantizer.set_usage(columnwise=self.keep_fp8_weight_transpose_cache)
 if fp8_output:
     fc2_output_quantizer = self.quantizers["scaling_fwd"][
         tex.FP8FwdTensors.GEMM2_OUTPUT
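
Finally, a hedged usage sketch. It assumes `keep_fp8_weight_transpose_cache` is exposed as a constructor kwarg of `transformer_engine.pytorch.LayerNormMLP` (as the `__init__` hunk above suggests) and that the standard `fp8_autocast` API is available; on non-HIP builds the flag is forced to `True`, so this only changes behavior on ROCm:

```python
# Hedged usage sketch -- kwarg name taken from the __init__ hunk above; treat as illustrative.
import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=4096,
    ffn_hidden_size=16384,
    keep_fp8_weight_transpose_cache=False,  # ROCm-only knob; kept True elsewhere
).cuda()

inp = torch.randn(8, 4096, device="cuda")
with te.fp8_autocast(enabled=True):
    out = mlp(inp)
out.sum().backward()  # weight transposes are rebuilt on demand here, then freed
```
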