Commit a6cf0d8

Authored by sudhu2k and a co-author
13679 TE2.4 keep_fp8_transpose_cache refactor (#328)
* Initial commit
* Removed rocm_utils
* Added comment and bug fixes
* Grouped IS_HIP_EXTENSION with the property assignment
* Reverted transpose.cpp, removed keep_fp8_transpose_cache flag from grouped_linear, removed manual clearing of tensors in modules
* Aligning grouped_linear module with upstream
* Reverted tests to use _test_granular_accuracy_with_fp8 multiple times as needed
* Added comments back
* Moved comment to the test

---------

Co-authored-by: sudhu2k <[email protected]>
1 parent 9a2257b commit a6cf0d8
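
The net effect of the refactor: instead of passing a create_transpose_cache argument into get_weight_workspace() and calling the ROCm-only create/clear helpers from rocm_utils, the modules now express the choice through the quantizer's columnwise usage and release the transposed FP8 weight after the dgrad GEMM via update_usage(). The sketch below illustrates that pattern in isolation; it is not TE source, and the helper function names are invented for illustration.

# Illustrative sketch only (not Transformer Engine source). Assumes a quantizer
# exposing set_usage() and a quantized weight exposing update_usage(), as in the diff.
from torch.utils.cpp_extension import IS_HIP_EXTENSION


def configure_weight_quantizer(quantizer, keep_fp8_weight_transpose_cache):
    # Forward path: only request the columnwise (transposed) FP8 copy when it
    # will be cached; non-HIP builds always keep the transpose cache.
    columnwise = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True
    quantizer.set_usage(rowwise=True, columnwise=columnwise)


def release_weight_transpose(weight, fp8, keep_fp8_weight_transpose_cache):
    # Backward path: once the dgrad GEMM has consumed the transposed weight,
    # drop it again instead of keeping it cached between iterations.
    if fp8 and not keep_fp8_weight_transpose_cache:
        weight.update_usage(columnwise_usage=False)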

7 files changed (+51, -97 lines)

tests/pytorch/test_numerics.py
Lines changed: 12 additions & 3 deletions

@@ -1330,11 +1330,20 @@ def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model
         keep_fp8_weight_transpose_cache=True  # defaults to True
     ).eval()
 
-    outputs = _test_granular_accuracy_with_fp8(layer, bs, dtype, config)
-    ref_outputs = _test_granular_accuracy_with_fp8(ref_layer, bs, dtype, config)
+    # The keep_fp8_transpose_cache flag will be evaluated over two iterations.
+    # Given that the transpose operation's cache is invalidated during the backward pass,
+    # the objective of this test is to observe the subsequent forward pass behavior.
+    num_iterations = 2
+    all_outputs = []
+    all_ref_outputs = []
+    for _ in range(num_iterations):
+        outputs = _test_granular_accuracy_with_fp8(layer, bs, dtype, config)
+        ref_outputs = _test_granular_accuracy_with_fp8(ref_layer, bs, dtype, config)
+        all_outputs.append(outputs)
+        all_ref_outputs.append(ref_outputs)
 
     # Check output.
-    for te_output_no_cache, te_output_cache in zip(outputs, ref_outputs):
+    for te_output_no_cache, te_output_cache in zip(all_outputs, all_ref_outputs):
         assert_allclose(te_output_no_cache, te_output_cache, atol=0, rtol=0)
 
 @pytest.mark.parametrize("dtype", param_types)
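
For readers skimming the diff, a condensed sketch of the test's two-iteration pattern follows; make_layer and run_fwd_bwd are hypothetical stand-ins for the test fixtures and for _test_granular_accuracy_with_fp8, and the exact-match check mirrors the assert above.

# Hypothetical condensed form of the test pattern above (not the actual test code).
import torch


def compare_cache_vs_no_cache(make_layer, run_fwd_bwd, num_iterations=2):
    layer = make_layer(keep_fp8_weight_transpose_cache=False)      # cache disabled
    ref_layer = make_layer(keep_fp8_weight_transpose_cache=True)   # default behavior
    for _ in range(num_iterations):
        # Iteration 1 covers the first forward/backward; iteration 2 checks the
        # forward pass that runs after the transpose cache was invalidated.
        out = run_fwd_bwd(layer)
        ref_out = run_fwd_bwd(ref_layer)
        assert torch.equal(out, ref_out)  # exact match, i.e. atol=0, rtol=0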

transformer_engine/pytorch/module/base.py
Lines changed: 1 addition & 17 deletions

@@ -1150,7 +1150,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
                     quantizer is not None
                 ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe.
                 quantizer.internal = False
-                if not self.keep_fp8_weight_transpose_cache:
+                if IS_HIP_EXTENSION and not self.keep_fp8_weight_transpose_cache:
                     quantizer.columnwise_usage=False
                 param = quantizer(param)
 
@@ -1201,7 +1201,6 @@ def get_weight_workspace(
         skip_update_flag: Optional[torch.Tensor] = None,
         fsdp_group: Optional[dist_group_type] = None,
         workspace_dtype: Optional[torch.dtype] = None,
-        create_transpose_cache: bool = True,
     ) -> QuantizedTensor:
         """Get FP8 workspace buffer and maybe update its values
 
@@ -1224,8 +1223,6 @@ def get_weight_workspace(
             over `update_workspace` if provided.
         fsdp_group: bool, default = None
                     FSDP process group that the weights are distributed over.
-        create_transpose_cache: bool, default = True
-                    Create transpose buffer from `tensor`.
         workspace_dtype: torch.dtype, default = None
                     If weight workspace contains high-precision tensor - for example
                     for debug quantization, this is dtype of the tensor.
@@ -1269,19 +1266,6 @@ def get_weight_workspace(
         ):
             _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)
 
-        if not is_non_tn_fp8_gemm_supported() and not create_transpose_cache:
-            current_quantizer = None
-            if out is None:
-                current_quantizer = quantizer
-            else:
-                if hasattr(out, "quantize_"):
-                    current_quantizer = out._get_quantizer()
-                else:
-                    current_quantizer = quantizer
-
-            # NOTE: Not create transpose buffer internally.
-            current_quantizer.columnwise_usage = False
-
         # Construct workspace if needed
         if out is None:
             if tensor is None or quantizer is None:
transformer_engine/pytorch/module/grouped_linear.py
Lines changed: 0 additions & 3 deletions

@@ -501,7 +501,6 @@ def __init__(
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
-        keep_fp8_weight_transpose_cache: bool = True,
     ) -> None:
         super().__init__()
 
@@ -516,8 +515,6 @@ def __init__(
         self.ub_overlap_rs = ub_overlap_rs
         self.ub_overlap_ag = ub_overlap_ag
         self.ub_name = ub_name
-        self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
-
         assert (
             not ub_overlap_rs and not ub_overlap_ag
         ), "GroupedLinear doesn't support Userbuffer overlap."

transformer_engine/pytorch/module/layernorm_linear.py
Lines changed: 7 additions & 8 deletions

@@ -81,8 +81,6 @@
 from ..triton_kernels.layernorm import te_layernorm_bwd_triton
 from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
 
-from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache
-
 
 __all__ = ["LayerNormLinear"]
 
@@ -291,7 +289,7 @@ def forward(
 
         # Configure quantizer
         if weight_quantizer is not None:
-            weight_quantizer.set_usage(rowwise=True, columnwise=True)
+            weight_quantizer.set_usage(rowwise=True, columnwise=keep_fp8_weight_transpose_cache)
 
         # Get quantized weight
         update_workspace = is_first_microbatch is None or is_first_microbatch
@@ -303,7 +301,6 @@ def forward(
             skip_update_flag=skip_fp8_weight_update,
             fsdp_group=fsdp_group,
             workspace_dtype=activation_dtype,
-            create_transpose_cache=keep_fp8_weight_transpose_cache,
         )
         weightmat.update_usage(rowwise_usage=True)
 
@@ -350,6 +347,8 @@ def forward(
         # Forward GEMM
         # Note: y = x * w^T
         # ------------------------------------------------------
+        if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache:
+            assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled."
         nvtx_range_push(f"{nvtx_label}.gemm")
         gemm_out, *_, reduce_scatter_out = general_gemm(
             weightmat,
@@ -701,8 +700,6 @@ def backward(
         if ctx.grad_input_quantizer is not None:
             ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
 
-        if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-            create_fp8_weight_transpose_cache(weight)
 
         # Output buffers for Userbuffers reduce-scatter
         gemm_out = None
@@ -735,7 +732,7 @@ def backward(
         nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
 
         if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-            clear_fp8_weight_transpose_cache(weight)
+            weight.update_usage(columnwise_usage=False)
 
         # Prepare grad input tensor
         # Note: Perform tensor-parallel communication
@@ -1195,7 +1192,7 @@ def __init__(
         self.name = name
         if TEDebugState.debug_enabled:
            self._turn_off_unsupported_features_in_debug() # turn off userbuffers
-        self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
+        self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True
 
         if tp_group is None:
             self.tp_size = tp_size
@@ -1638,6 +1635,8 @@ def _get_quantizers(self, fp8_output, fp8_grad):
         input_quantizer.internal = True
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
+        if IS_HIP_EXTENSION:
+            weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache)
         if fp8_output:
             output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
         if torch.is_grad_enabled():

transformer_engine/pytorch/module/layernorm_mlp.py
Lines changed: 20 additions & 18 deletions

@@ -87,8 +87,6 @@
 from ..triton_kernels.layernorm import te_layernorm_bwd_triton
 from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
 
-from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache
-
 __all__ = ["LayerNormMLP"]
 
 
@@ -347,8 +345,8 @@ def forward(
             # which handles weight caching etc.
             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
-            fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
-            fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
+            fc1_weight_quantizer.set_usage(rowwise=True, columnwise=keep_fp8_weight_transpose_cache)
+            fc2_weight_quantizer.set_usage(rowwise=True, columnwise=keep_fp8_weight_transpose_cache)
             fc1_weight_final = module.get_weight_workspace(
                 tensor=fc1_weight,
                 quantizer=fc1_weight_quantizer,
@@ -357,7 +355,6 @@ def forward(
                 skip_update_flag=skip_fp8_weight_update,
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
-                create_transpose_cache=keep_fp8_weight_transpose_cache,
             )
             fc2_weight_final = module.get_weight_workspace(
                 tensor=fc2_weight,
@@ -367,7 +364,6 @@ def forward(
                 skip_update_flag=skip_fp8_weight_update,
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
-                create_transpose_cache=keep_fp8_weight_transpose_cache,
             )
             fc1_weight_final.update_usage(rowwise_usage=True)
             fc2_weight_final.update_usage(rowwise_usage=True)
@@ -412,6 +408,10 @@ def forward(
             gemm_gelu_fusion = False
         if debug:
             gemm_gelu_fusion = False
+
+        if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache:
+            assert fc1_weight_final._transpose is None or fc1_weight_final._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled."
+
         fc1_outputs = general_gemm(
             fc1_weight_final,
             ln_out_total,
@@ -482,6 +482,9 @@ def forward(
         # ------------------------------------------------------
         # FC2 GEMM
        # ------------------------------------------------------
+        if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache:
+            assert fc2_weight_final._transpose is None or fc2_weight_final._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled."
+
         gemm_out, *_, reduce_scatter_out = general_gemm(
             fc2_weight_final,
             act_out,
@@ -817,12 +820,9 @@ def backward(
         if isinstance(grad_output, QuantizedTensorBase):
             grad_output.update_usage(rowwise_usage=True)
         if ctx.fc2_weight_quantizer is not None and isinstance(
-            ctx.fc2_weight, QuantizedTensorBase
+            fc2_weight, QuantizedTensorBase
         ):
-            ctx.fc2_weight.update_usage(columnwise_usage=True)
-
-        if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-            create_fp8_weight_transpose_cache(fc2_weight)
+            fc2_weight.update_usage(columnwise_usage=True)
 
         # Perform GEMM
         gemm_output, *_ = general_gemm(
@@ -853,7 +853,7 @@ def backward(
         fc2_dgrad = gemm_output
 
         if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-            clear_fp8_weight_transpose_cache(fc2_weight)
+            fc2_weight.update_usage(columnwise_usage=False)
 
         # --------------------------------------------------
         # Finished FC2 DGRAD...
@@ -1041,18 +1041,16 @@ def fc2_wgrad_gemm(
                 ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
                 ub_type_fc1_wgrad = tex.CommOverlapType.RS
 
-        if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-            create_fp8_weight_transpose_cache(fc1_weight)
 
         # --------------------------------------------------
         # FC1 DGRAD
         # --------------------------------------------------
 
         # Make sure required data is available
         if ctx.fc1_weight_quantizer is not None and isinstance(
-            ctx.fc1_weight_quantizer, QuantizedTensorBase
+            fc1_weight, QuantizedTensorBase
         ):
-            ctx.fc1_weight.update_usage(columnwise_usage=True)
+            fc1_weight.update_usage(columnwise_usage=True)
 
         # Output buffers for Userbuffers reduce-scatter
         gemm_out = None
@@ -1082,7 +1080,7 @@ def fc2_wgrad_gemm(
         )
 
         if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-            clear_fp8_weight_transpose_cache(fc1_weight)
+            fc1_weight.update_usage(columnwise_usage=False)
 
         # Prepare grad input tensor
         # Note: Perform tensor-parallel communication
@@ -1552,7 +1550,7 @@ def __init__(
         self.set_parallel_mode = set_parallel_mode
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
-        self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
+        self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True
 
         # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
         self.gemm_gelu_fusion = (
@@ -1918,6 +1916,8 @@ def _get_quantizers(self, fp8_output):
         fc1_input_quantizer.internal = True
         fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         fc1_weight_quantizer.internal = True
+        if IS_HIP_EXTENSION:
+            fc1_weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache)
         fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
         fc2_input_quantizer.set_usage(
             rowwise=True,
@@ -1926,6 +1926,8 @@ def _get_quantizers(self, fp8_output):
         fc1_input_quantizer.internal = True
         fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
         fc2_weight_quantizer.internal = True
+        if IS_HIP_EXTENSION:
+            fc2_weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache)
         if fp8_output:
             fc2_output_quantizer = self.quantizers["scaling_fwd"][
                 tex.FP8FwdTensors.GEMM2_OUTPUT

transformer_engine/pytorch/module/linear.py
Lines changed: 11 additions & 9 deletions

@@ -68,11 +68,11 @@
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.utils import any_feature_enabled
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 __all__ = ["Linear"]
 
@@ -228,8 +228,8 @@ def forward(
         if fp8 or debug:
             # Configure quantizer
             if weight_quantizer is not None:
-                columnwise_usage = is_grad_enabled and inp.requires_grad
-                if not columnwise_usage:
+                columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache
+                if not columnwise_usage and keep_fp8_weight_transpose_cache:
                     columnwise_usage = (
                         is_fp8_activation_recompute_enabled()
                         and not in_fp8_activation_recompute_phase()
@@ -246,7 +246,6 @@ def forward(
                 skip_update_flag=skip_fp8_weight_update,
                 fsdp_group=fsdp_group,
                 workspace_dtype=activation_dtype,
-                create_transpose_cache=keep_fp8_weight_transpose_cache,
             )
             weightmat.update_usage(rowwise_usage=True)
 
@@ -293,6 +292,9 @@ def forward(
         # Forward GEMM
         # Note: y = x * w^T
         # ------------------------------------------------------
+        if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache:
+            assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled."
+
         nvtx_range_push(f"{nvtx_label}.gemm")
         gemm_out, *_, reduce_scatter_out = general_gemm(
             weightmat,
@@ -618,9 +620,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         if ctx.grad_input_quantizer is not None:
             ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
 
-        if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-            create_fp8_weight_transpose_cache(weight_fp8)
-
         # Output buffers for Userbuffers reduce-scatter
         gemm_out = None
         reduce_scatter_out = None
@@ -652,7 +651,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
 
         if ctx.fp8 and not ctx.keep_fp8_weight_transpose_cache:
-            clear_fp8_weight_transpose_cache(weight_fp8)
+            weight_fp8.update_usage(columnwise_usage=False)
 
         # Prepare grad input tensor
         # Note: Perform tensor-parallel communication
@@ -1044,7 +1043,7 @@ def __init__(
             self._turn_off_unsupported_features_in_debug() # turn off userbuffers
 
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
-        self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
+        self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True
 
         if device == "meta":
             assert parameters_split is None, "Cannot split module parameters on 'meta' device."
@@ -1431,6 +1430,9 @@ def _get_quantizers(self, fp8_output, fp8_grad):
         input_quantizer.internal = True
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
+        if IS_HIP_EXTENSION:
+            weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache)
+
         if fp8_output:
             output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
         if torch.is_grad_enabled():
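
Finally, a hedged end-user sketch of how the flag is exercised. On a ROCm build, constructing a module with keep_fp8_weight_transpose_cache=False skips caching the FP8 weight transpose between iterations; on non-HIP builds the diff forces the attribute back to True. Module, autocast, and recipe names follow the public TE PyTorch API; shapes and recipe settings are arbitrary assumptions.

# Usage sketch (assumes a ROCm build of Transformer Engine with this change applied).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(
    1024,
    1024,
    params_dtype=torch.bfloat16,
    keep_fp8_weight_transpose_cache=False,  # honored only when IS_HIP_EXTENSION is True
).cuda()

inp = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = layer(inp)
out.sum().backward()  # the transposed FP8 weight is rebuilt for dgrad, then released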
