
Commit bb69e9a

Victor Oliveira (victoroliv2) authored and committed
ONNX: Fix FP8 quantization for the second MLP in LayernormMLP
1 parent 08dc786 commit bb69e9a

File tree

1 file changed: 12 additions, 3 deletions


transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 12 additions & 3 deletions
@@ -2243,14 +2243,23 @@ def onnx_forward(
 
         assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
         assert_warmed_up(self)
+
+        # Get quantizers
         (
             fc1_input_quantizer,
             fc1_weight_quantizer,
+            fc1_output_quantizer,
+            _,
+            _,
+            _,
             fc2_input_quantizer,
             fc2_weight_quantizer,
-            output_quantizer,
-            *_,
+            fc2_output_quantizer,
+            _,
+            _,
+            _,
         ) = self._get_quantizers(False, is_grad_enabled)
+
         inp_dtype = inp.dtype
 
         fc1_weight, fc2_weight = self._get_weight_tensors()
@@ -2324,7 +2333,7 @@ def _clamped_swiglu(x, limit, alpha):
 
         fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias)
 
-        if output_quantizer is not None:
+        if fc2_output_quantizer is not None:
             raise NotImplementedError("ONNX export of quantized output is not supported")
 
         if self.return_layernorm_output:
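
Why the change matters: the tuple returned by _get_quantizers has more entries than the old code named, so unpacking with *_ silently bound FC1's trailing quantizers to the FC2 variables, and the FP8 quantizers used for the second GEMM were wrong during ONNX export. Below is a minimal sketch of that failure mode; the 12-slot ordering and the fc*_aux* names are illustrative assumptions inferred from the new unpacking, not Transformer Engine's actual _get_quantizers return value.

# A minimal sketch, not Transformer Engine code: it shows why the old
# positional unpacking mis-mapped the FC2 quantizers. The 12-slot ordering
# and placeholder strings are assumptions (six slots per GEMM: input,
# weight, output, plus three auxiliary/gradient-related slots); the real
# objects are quantizer instances.
quantizers = (
    "fc1_input", "fc1_weight", "fc1_output", "fc1_aux0", "fc1_aux1", "fc1_aux2",
    "fc2_input", "fc2_weight", "fc2_output", "fc2_aux0", "fc2_aux1", "fc2_aux2",
)

# Old unpacking: assumed FC2's quantizers followed FC1's input/weight directly,
# so the FC2 variables actually received FC1's output/auxiliary quantizers.
(
    fc1_input_quantizer,
    fc1_weight_quantizer,
    fc2_input_quantizer,   # actually receives "fc1_output"
    fc2_weight_quantizer,  # actually receives "fc1_aux0"
    output_quantizer,      # actually receives "fc1_aux1"
    *_,
) = quantizers
assert fc2_input_quantizer == "fc1_output"   # wrong quantizer for the second MLP

# New unpacking: every slot is named or explicitly discarded, so FC2's
# input/weight/output quantizers land in the right variables.
(
    fc1_input_quantizer,
    fc1_weight_quantizer,
    fc1_output_quantizer,
    _, _, _,
    fc2_input_quantizer,
    fc2_weight_quantizer,
    fc2_output_quantizer,
    _, _, _,
) = quantizers
assert fc2_input_quantizer == "fc2_input"
assert fc2_output_quantizer == "fc2_output"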
