diff --git a/.gitignore b/.gitignore index 4ece24d518..68bba92666 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ *.nsys-rep *.ncu-rep *.sqlite +*.onnx .eggs build/ *.so diff --git a/qa/L0_unittest/test.sh b/qa/L0_unittest/test.sh index 121555a3d2..7e2020d4c4 100644 --- a/qa/L0_unittest/test.sh +++ b/qa/L0_unittest/test.sh @@ -6,5 +6,6 @@ set -e : ${TE_PATH:=/opt/transformerengine} -pip install pytest==6.2.5 +pip install pytest==6.2.5 onnxruntime pytest -v -s $TE_PATH/tests/test_transformerengine.py +pytest -v -s $TE_PATH/tests/test_onnx_export.py diff --git a/setup.py b/setup.py index 2145025789..5f824c63e2 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ from setuptools.command.build_ext import build_ext from distutils.version import LooseVersion from distutils.file_util import copy_file -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME path = os.path.dirname(os.path.realpath(__file__)) @@ -85,6 +85,7 @@ def make_abs_path(l): pytorch_sources = [ "transformer_engine/pytorch/csrc/extensions.cu", "transformer_engine/pytorch/csrc/common.cu", + "transformer_engine/pytorch/csrc/ts_fp8_op.cpp", ] pytorch_sources = make_abs_path(pytorch_sources) diff --git a/tests/libcustom_ort_fp8_qdq_ops.so b/tests/libcustom_ort_fp8_qdq_ops.so new file mode 100755 index 0000000000..ba6ad6d85e Binary files /dev/null and b/tests/libcustom_ort_fp8_qdq_ops.so differ diff --git a/tests/test_onnx_export.py b/tests/test_onnx_export.py new file mode 100644 index 0000000000..789eccf09f --- /dev/null +++ b/tests/test_onnx_export.py @@ -0,0 +1,1003 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +This file contains tests for exporting TransformerEngine models to ONNX. +""" + +import os +import pytest +import warnings +import numpy as np +import math +import onnxruntime as ort +import torch +from torch import nn as nn +from typing import Union, Tuple, List +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +import transformer_engine_extensions as tex +from transformer_engine.pytorch.cpp_extensions import * +from transformer_engine.pytorch.module import get_workspace +import transformer_engine.pytorch.cpp_extensions as texcpp +import transformer_engine.pytorch.softmax as softmax_defs +from transformer_engine.pytorch.utils import get_default_init_method + + +# Directory where generated ONNX test models are stored. +TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) +ONNX_FILES_DIR = os.path.join(TESTS_DIR, "./gen_onnx_models") + +# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT). +ORT_CUSTOM_OPS_LIB = "./tests/libcustom_ort_fp8_qdq_ops.so" + +# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14. +TRILU_OPSET = 14 +# Opset used in the ONNX files generated by the tests. 
+OPSET = 15 +assert OPSET >= TRILU_OPSET + + +def create_fp8_recipe(): + return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3) + + +def do_export( + model: torch.nn.Module, + inp: torch.Tensor, + fname: str, + use_fp8: bool=True, + opset: int=OPSET, + input_names: list=["input"], + output_names: list=["output"], +): + """Export to ONNX""" + fp8_recipe = create_fp8_recipe() + + with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings(): + warnings.filterwarnings( + action='ignore', + category=torch.jit.TracerWarning, + module=r'.*' + ) + + model.cuda().eval() + os.makedirs(ONNX_FILES_DIR, exist_ok=True) + fname = os.path.join(ONNX_FILES_DIR, fname) + torch.onnx.export(model, + inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,), + fname, + verbose=False, + opset_version=opset, + input_names=input_names, + output_names=output_names, + do_constant_folding=True, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH) + + +def to_numpy(tensor): + return tensor.cpu().numpy() + + +def set_layer_scale(module: torch.nn.Module, scales: List[float]): + module.fp8_init() + num_fp8_tensors = len(scales) + scale = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda") + scale_inv = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda") + amax_history_len = module.fp8_meta["recipe"].amax_history_len + amax_history = torch.zeros(amax_history_len, num_fp8_tensors, dtype=torch.float32, device="cuda") + for i, s in enumerate(scales): + scale[i] *= s + scale_inv[i] /= s + module.fp8_meta["scaling_fwd"].scale = scale + module.fp8_meta["scaling_fwd"].scale_inv = scale_inv + module.fp8_meta["scaling_fwd"].amax_history = amax_history + + +def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool): + """Transformer Engine forward prpoagtation. + + Return results after copying to the CPU and converting to numpy. + """ + fp8_recipe = create_fp8_recipe() + with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings(): + te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) + if not isinstance(te_outputs, tuple): + te_outputs = (te_outputs,) + te_outputs_np = [to_numpy(te_output) for te_output in te_outputs] + return te_outputs_np + + +def validate_result( + fname: str, + inps: Union[Tuple[torch.Tensor], torch.Tensor], + model: torch.nn.Module, + atol: float=1.e-8, # np.isclose default atol + rtol: float=1.e-5, # np.isclose default rtol + max_errors_printed: int=10, + is_fp8: bool=False, +): + """Validate the outputs of an ONNX model vs. 
ONNX Runtime.""" + + def create_ort_session(fname: str, is_fp8: bool): + def load_custom_ops(session_opts: ort.SessionOptions): + """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension.""" + if not os.path.exists(ORT_CUSTOM_OPS_LIB): + raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}") + session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB) + print("registered custom FP8 Q/DQ ops!") + + """Create an ONNX Runtime session for validation.""" + if is_fp8: + sess_options = ort.SessionOptions() + load_custom_ops(sess_options) + # Model loading successfully indicates that the custom op node could be resolved successfully + s = ort.InferenceSession(fname, sess_options=sess_options) + else: + s = ort.InferenceSession(fname) + return s + + def create_ort_input_dict(session, inps): + inp_dict = {} + if isinstance(inps, tuple) or isinstance(inps, list): + nonetype_inputs = 0 + for idx, inp in enumerate(inps): + if inp is None: + nonetype_inputs += 1 + continue + inp_dict[session.get_inputs()[idx - nonetype_inputs].name] = to_numpy(inp) + else: + inp_dict[session.get_inputs()[0].name] = to_numpy(inps) + return inp_dict + + # Run ORT session and TE model. + fname = os.path.join(ONNX_FILES_DIR, fname) + ort_s = create_ort_session(fname, is_fp8) + onnx_outputs = ort_s.run(None, input_feed=create_ort_input_dict(ort_s, inps)) + te_outputs = te_infer(model, inps, is_fp8) + + # Compare ORT and TE outputs. + assert len(onnx_outputs) == len(te_outputs) + for onnx_output, te_output in zip(onnx_outputs, te_outputs): + + # Compare ORT and PyTorch outputs. + # np.isclose: abs(a - b) <= (atol + rtol * abs(b)) + ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol) + + mismatches = ac.nonzero() + mismatched_ids = [loc for loc in zip(*mismatches)] + if mismatched_ids: + # Log some information in case of error. + print("*" * 100) + print(onnx_output.shape) + nb_vals = min(len(mismatched_ids), max_errors_printed) + print(f"Detected {len(mismatched_ids)} diverging values.\nShowing first {nb_vals} errors (ONNX -- TE):") + abs_err = abs(onnx_output - te_output) + for loc in mismatched_ids[:nb_vals]: + ref = te_output[loc] + print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}") + raise ValueError(f"Output validation of {fname} failed with {len(mismatched_ids)} errors") + + +def create_meta(scale_factor: float, size: int=1): + meta = tex.FP8TensorMeta() + meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") + meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor + meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor + return meta + + +def dtype2str(dtype: torch.dtype): + return { + torch.float32: "_fp32", + torch.float16: "_fp16", + torch.bfloat16: "_bf16", + }[dtype] + + +def as_te_type(dtype: torch.dtype): + return { + torch.float32: tex.DType.kFloat32, + torch.float16: tex.DType.kFloat16, + torch.bfloat16: tex.DType.kBFloat16, + }[dtype] + + +def get_attn_mask_str(use_mask, attn_mask_type): + # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names. 
+ if attn_mask_type is None: + return "_mask" if use_mask else "_no-mask" + attn_mask_str = "_padding-no-mask" + attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str + attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str + return attn_mask_str + + +@pytest.mark.parametrize("scale_factor, atol", [ + (1, 1e-7), + (224, 1e-7) +]) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16]) +def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtype): + class TestFP8_QDQ(nn.Module): + def __init__(self): + super().__init__() + self.fp8_tensor = 0 + self.meta = create_meta(scale_factor) + self.highprec_type = as_te_type(precision) + self.fp8_type = tex.DType.kFloat8E4M3 + + def forward(self, inp): + ret = cast_to_fp8( + inp, + self.meta, + self.fp8_tensor, + self.fp8_type) + + ret = cast_from_fp8( + ret, + self.meta, + self.fp8_tensor, + self.fp8_type, + self.highprec_type) + return ret + + # Set dimensions (these are arbitrary). + in_features = 64 + hidden_size = 256 + inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) + high_prec_str = dtype2str(precision) + fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx" + model = TestFP8_QDQ() + do_export(model, inp, fname) + validate_result(fname, inp, model, atol=atol, is_fp8=True) + + +@pytest.mark.parametrize("scale_factor", [448]) +@pytest.mark.parametrize( + "precision, atol", [ + [torch.float32, 1e-7], + [torch.float16, 2e-3] +]) +def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float): + class TestFP8_Gelu(nn.Module): + def __init__(self): + super().__init__() + self.fp8_tensor = 0 + self.meta = create_meta(scale_factor) + self.highprec_type = as_te_type(precision) + self.fp8_type = tex.DType.kFloat8E4M3 + + def forward(self, inp): + ret = fp8_gelu( + inp, + self.meta, + self.fp8_tensor, + self.fp8_type) + ret = cast_from_fp8( + ret, + self.meta, + self.fp8_tensor, + self.fp8_type, + self.highprec_type) + return ret + + # Set dimensions (these are arbitrary). + in_features = 64 + hidden_size = 256 + inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) + high_prec_str = dtype2str(precision) + fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx" + model = TestFP8_Gelu() + do_export(model, inp, fname) + validate_result(fname, inp, model, rtol=1e-1, atol=atol, is_fp8=True) + + +@pytest.mark.parametrize("scale_factors", + [(224, 224,), +]) +@pytest.mark.parametrize( + "precision, use_fp8, use_bias, use_gelu", [ + (torch.float32, False, False, False), + (torch.float16, False, False, False), + (torch.float32, False, True, False), + (torch.float16, False, True, False), + (torch.float32, False, True, True), + (torch.float16, False, True, True), + + # For FP8 GEMM GeLU is not used. 
+ (torch.float32, True, False, False), + (torch.float16, True, False, False), + # When enabling bias we must use float16 or bfloat16 (because of kernel limitations) + (torch.float16, True, True, False), + (torch.bfloat16, True, True, False), +]) +def test_export_gemm( + precision, # Precision of inputs, weights, output and bias + use_fp8, + use_bias, + use_gelu, + scale_factors +): + class TestFP8_GEMM(nn.Module): + def __init__(self, precision, use_bias, gelu, scale_factors): + super().__init__() + self.use_bias = use_bias + self.gelu = gelu + self.precision = precision + + self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT + self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT + nb_inp_scales, nb_weight_scales = 1, out_features + act_scale_factor, weight_scale_factor = scale_factors + self.meta_inp = create_meta(act_scale_factor, nb_inp_scales) + self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales) + + bias_size = nb_weight_scales + self.bias = torch.randn(bias_size, dtype=precision, device="cuda") + self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") + + self.inp_type = tex.DType.kFloat8E4M3 + self.weights_type = tex.DType.kFloat8E4M3 + self.outp_type = precision + + def forward(self, inp, weight): + inp_fp8 = cast_to_fp8( + inp, + self.meta_inp, + self.fp8_tensor_inp, + self.inp_type) + + weight_fp8 = cast_to_fp8( + weight, + self.meta_weight, + self.fp8_tensor_weight, + self.weights_type) + + ret = fp8_gemm( + weight_fp8, + self.meta_weight.scale_inv, + self.fp8_tensor_weight, + self.inp_type, + inp_fp8, + self.meta_inp.scale_inv, + self.fp8_tensor_inp, + self.weights_type, + self.outp_type, + get_workspace(), + bias=self.bias, + use_bias=self.use_bias, + fp32_output=(self.precision==torch.float32), + use_split_accumulator=False) + return ret + + class Test_GEMM(nn.Module): + def __init__(self, precision, use_bias=False, gelu=False): + super().__init__() + self.use_bias = use_bias + self.gelu = gelu + self.precision = precision + bias_size = out_features + self.bias = torch.randn(bias_size, dtype=precision, device="cuda") + self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") + + def forward(self, inp, weight): + outp_type = self.precision + + # note: due to logic in lines 104:116 and L129 in cpp_extensions.py + # it appears either bias OR gelu can be activated, not both + ret, _, _ = gemm( + weight, + inp, + outp_type, + get_workspace(), + + # test bias + bias=self.bias, + use_bias=self.use_bias, + + # test gelu + gelu=self.gelu, + gelu_input=self.gelu_input, + grad=False # only True for backward pass + ) + return ret + + # If gelu is applied then bias must be added, as defined by TE kernel. + if use_gelu: assert use_bias + # Set dimensions (these are arbitrary). 
+ out_features = 128 + hidden_size = 256 + in_features = 64 + inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda") + weight = torch.randn(out_features, in_features, dtype=precision, device="cuda") + fp8_str = "_fp8" if use_fp8 else "" + bias_str = "_bias" if use_bias else "" + gelu_str = "_gelu" if use_gelu else "" + high_prec_str = dtype2str(precision) + fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx" + if use_fp8: + model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors) + do_export(model, (inp, weight), fname, use_fp8) + if precision not in (torch.bfloat16, torch.float16): + validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True) + else: + model = Test_GEMM(precision, use_bias, use_gelu) + do_export(model, (inp, weight), fname, use_fp8) + validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2) + + +@pytest.mark.parametrize("use_fp8", [False, True]) +@pytest.mark.parametrize("scale_factor", [448, 112]) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16]) +def test_export_layernorm( + use_fp8: bool, + scale_factor: float, + precision: torch.dtype +): + # Set dimensions (these are arbitrary). + inp_shape = [64, 32] + + class Test_Layernorm(nn.Module): + def __init__(self) -> None: + super().__init__() + normalized_shape = torch.Size(inp.shape[1:]) + self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda") + self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda") + self.eps = 1e-6 # An arbitrary small value + + def forward(self, inp): + ret = texcpp.layernorm_fwd_inf( + inp, + self.weight, + self.bias, + self.eps) + return ret + + class TestFP8_Layernorm(nn.Module): + def __init__(self) -> None: + super().__init__() + normalized_shape = torch.Size(inp.shape[1:]) + self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda") + self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda") + self.eps = 1e-6 # An arbitrary small value + + self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT + self.meta = create_meta(scale_factor) + self.fp8_type = tex.DType.kFloat8E4M3 + + def forward(self, inp): + ret = texcpp.layernorm_fwd_fp8_inf( + inp, + self.weight, + self.bias, + self.eps, + self.meta, + self.fp8_tensor, + self.fp8_type) + + ret = cast_from_fp8( + ret, + self.meta, + self.fp8_tensor, + self.fp8_type, + tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16) + return ret + + inp = torch.randn(*inp_shape, device="cuda", dtype=precision) + model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm() + high_prec_str = dtype2str(precision) + fp8_str = f"_fp8-{scale_factor}" if use_fp8 else "" + fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx" + do_export(model, inp, fname) + if precision not in (torch.bfloat16, ): + # TODO: FP32 has a small threshold (1e-5) + validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8) + + +@pytest.mark.parametrize("softmax_def", [ + softmax_defs.ScaledUpperTriangMaskedSoftmax, + softmax_defs.ScaledMaskedSoftmax, + softmax_defs.ScaledSoftmax, +]) +# Softmax kernel only supports FP16 or BF16! 
+@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) +def test_export_softmax(softmax_def, precision): + class Test_Softmax(nn.Module): + def __init__(self, softmax_function, mask_inp=False): + super().__init__() + self.softmax_fn = softmax_function + self.mask_inp = mask_inp + + def forward(self, inp, mask): + scale_factor = 8 # arbitrary value + if self.mask_inp: + ret = self.softmax_fn.apply(inp, mask, scale_factor) + else: + ret = self.softmax_fn.apply(inp, scale_factor) + return ret + + # Set dimensions (these are arbitrary). + in_features = 64 + hidden_size = 256 + mask = None + input_names = ["input"] + inp_shape = [hidden_size, in_features, in_features, in_features] + if softmax_def == softmax_defs.ScaledUpperTriangMaskedSoftmax: + inp_shape = [hidden_size, in_features, in_features] + kernel_str = "ScaledUpperTriangMaskedSoftmax" + model = Test_Softmax(softmax_def) + elif softmax_def == softmax_defs.ScaledMaskedSoftmax: + # Generate a random mask with 50% probability for 0 or 1. + probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision) + mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) + input_names.append("mask") + kernel_str = "ScaledMaskedSoftmax" + model = Test_Softmax(softmax_def, mask_inp=True) + elif softmax_def == softmax_defs.ScaledSoftmax: + kernel_str = "ScaledSoftmax" + model = Test_Softmax(softmax_def) + input_tensor = torch.randn(*inp_shape, device="cuda") + input_tensor = input_tensor.to(torch.bfloat16) if precision == torch.bfloat16 else input_tensor.half() + high_prec_str = dtype2str(precision) + fname = f"{kernel_str}{high_prec_str}.onnx" + inp = (input_tensor, mask) + do_export(model, inp, fname, input_names=input_names) + if precision != torch.bfloat16: + validate_result(fname, inp, model, atol=1e-3) + + +@pytest.mark.parametrize("scale_factors", [[448, 448]]) +@pytest.mark.parametrize("use_fp8", [False, True]) +# Returning the bias is a TE fusion optimization we don't care about. +@pytest.mark.parametrize("return_bias", [False]) +@pytest.mark.parametrize( + "precision, use_bias",[ + (torch.float32, False), + (torch.float32, True), + (torch.float16, False), + (torch.float16, True), + # Todo: cannot configure BF16 when bias is disabled (ORT issue?) + (torch.bfloat16, False), + # Todo: cannot configure BF16 when bias is enabled (ORT issue?) + # (torch.bfloat16, True), +]) +def test_export_linear( + scale_factors: List[float], + use_fp8: bool, + use_bias: bool, + return_bias: bool, + precision: torch.dtype +): + # Set dimensions (these are arbitrary). 
+ in_features = 64 + out_features = 256 + hidden_size = 256 + + class Test_Linear(nn.Module): + def __init__(self, + in_features, + out_features, + use_bias, + return_bias, + precision + ): + super().__init__() + self.linear = te.Linear( + in_features, + out_features, + bias=use_bias, + return_bias=return_bias, + params_dtype=precision + ) + + def forward(self, inp): + ret = self.linear(inp) + return ret + + + inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) + fp8_str = "_fp8" if use_fp8 else "" + bias_str = "_bias" if use_bias else "" + high_prec_str = dtype2str(precision) + fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" + with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()): + model = Test_Linear( + in_features, + out_features, + use_bias, + return_bias, + precision + ).to(device='cuda') + if use_fp8: + set_layer_scale(model.linear, scale_factors) + do_export(model, inp, fname, use_fp8) + + if precision in (torch.bfloat16, ): + return + if not use_fp8: + validate_result(fname, inp, model, atol=5e-4) + else: + validate_result(fname, inp, model, atol=5e-4, is_fp8=use_fp8) + + +@pytest.mark.parametrize("scale_factors", [[448, 448]]) +@pytest.mark.parametrize("use_fp8", [False, True]) +# Returning the bias is a TE fusion optimization we don't care about. +@pytest.mark.parametrize("return_bias", [False]) +@pytest.mark.parametrize("return_layernorm_output", [False]) +@pytest.mark.parametrize( + "precision, use_bias",[ + (torch.float32, False), + (torch.float32, True), + (torch.float16, True), + (torch.float16, False), +]) +def test_export_layernorm_linear( + scale_factors: List[float], + use_fp8: bool, + use_bias: bool, + return_bias: bool, + return_layernorm_output: bool, + precision: torch.dtype +): + # Set dimensions (these are arbitrary). + in_features = 64 + out_features = 256 + hidden_size = 256 + + inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) + fp8_str = "_fp8" if use_fp8 else "" + bias_str = "_bias" if use_bias else "" + high_prec_str = dtype2str(precision) + fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx" + with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()): + model = te.LayerNormLinear( + hidden_size, + 3 * hidden_size, + bias=use_bias, + return_bias=return_bias, + return_layernorm_output=return_layernorm_output, + params_dtype=precision, + ).to(device='cuda') + if use_fp8: + set_layer_scale(model, scale_factors) + do_export(model, inp, fname, use_fp8) + if not use_fp8: + validate_result(fname, inp, model, atol=1e-3) + elif precision not in (torch.bfloat16,): + validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8) + + +@pytest.mark.parametrize("scale_factors", [[224, 224, 448, 448]]) +@pytest.mark.parametrize("use_fp8", [False, True]) +# Returning the bias is a TE fusion optimization we don't care about. +@pytest.mark.parametrize("return_bias", [False]) +@pytest.mark.parametrize("return_layernorm_output", [False]) +@pytest.mark.parametrize( + "precision, use_bias",[ + (torch.float32, False), + (torch.float32, True), + (torch.float16, True), + (torch.float16, False), +]) +def test_export_layernorm_mlp( + scale_factors: List[float], + use_fp8: bool, + use_bias: bool, + return_bias: bool, + return_layernorm_output: bool, + precision: torch.dtype +): + # Set dimensions (these are arbitrary). 
+ in_features = 64 + out_features = 256 + hidden_size = 256 + ffn_hidden_size = 256 + + inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) + fp8_str = "_fp8" if use_fp8 else "" + bias_str = "_bias" if use_bias else "" + high_prec_str = dtype2str(precision) + fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}.onnx" + with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()): + model = te.LayerNormMLP( + hidden_size, + ffn_hidden_size, + bias=use_bias, + return_bias=return_bias, + return_layernorm_output=return_layernorm_output, + params_dtype=precision, + ).to(device='cuda') + if use_fp8: + set_layer_scale(model, scale_factors) + do_export(model, inp, fname, use_fp8) + if not use_fp8: + validate_result(fname, inp, model, atol=5e-4) + else: + validate_result(fname, inp, model, atol=7e-3, is_fp8=use_fp8) + + +@pytest.mark.parametrize( + "precision, use_mask, attn_mask_type", [ + (torch.float32, False, None), # calls forward_torch_softmax + (torch.float32, True, None), # calls forward_torch_softmax + (torch.float16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax + (torch.float16, True, "padding"), # calls ScaledMaskedSoftmax + (torch.float16, False, "padding"), # calls ScaledSoftmax +]) +@pytest.mark.parametrize("attention_softmax_in_fp32", + [True, False]) +@pytest.mark.parametrize("apply_query_key_layer_scaling", + [True, False]) +def test_export_core_attention( + precision: torch.dtype, + use_mask: bool, + attn_mask_type: str, + attention_softmax_in_fp32: bool, + apply_query_key_layer_scaling: bool, +): + # Set dimensions (these are arbitrary). + kv_channels = 64 + num_attention_heads = 1 + qkv_size = (2048, 4, num_attention_heads, kv_channels) + + query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") + key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") + value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") + input_names = ["query", "key", "value"] + attention_mask = None + if use_mask: + # Generate a random mask with 50% probability for 0 or 1. 
+ probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision) + attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) + input_names.append("attention_mask") + inp = (query_layer, key_layer, value_layer, attention_mask) + + sm_prec_str = "_sm-fp32" if attention_softmax_in_fp32 else "_sm-fp16" + qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else "" + mask_str = get_attn_mask_str(use_mask, attn_mask_type) + high_prec_str = dtype2str(precision) + fname = f"te.core_attention{mask_str}{qk_scaling_str}{sm_prec_str}{high_prec_str}.onnx" + + if attn_mask_type is None: + attn_mask_type = 'causal' + model = te.transformer.CoreAttention( + num_attention_heads=num_attention_heads, + kv_channels=kv_channels, + attention_dropout=0.5, + attn_mask_type=attn_mask_type, + attention_softmax_in_fp32=attention_softmax_in_fp32, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + ).to(device='cuda') + do_export(model, + inp, + fname, + input_names=input_names, + use_fp8=True) + validate_result(fname, inp, model, atol=1e-2) + + +def set_mha_scales(module, + scale_factor_qkv: List[float]=[448, 448], + scale_factor_query: List[float]=[112, 112], + scale_factor_kv: List[float]=[224, 224], + scale_factor_proj: List[float]=[448, 448] +): + if module.attention_type == "self": + if module.input_layernorm: + # LayernormLinear layer scale init + set_layer_scale(module.layernorm_qkv, scale_factor_qkv) + else: + # Linear layer scale init + set_layer_scale(module.qkv, scale_factor_qkv) + else: + if module.input_layernorm: + # LayernormLinear layer scale init + set_layer_scale(module.layernorm_query, scale_factor_query) + else: + # Linear layer scale init + set_layer_scale(module.query_layer, scale_factor_query) + + # Linear layer scale init + set_layer_scale(module.key_value, scale_factor_kv) + + # Linear layer scale init + set_layer_scale(module.proj, scale_factor_proj) + +test_configs_multihead_attention = [ + #"use_mask, attn_mask_type" + (False, "causal"), # calls ScaledUpperTriangMaskedSoftmax + (True, "padding"), # calls ScaledMaskedSoftmax + (False, "padding"), # calls ScaledSoftmax +] +test_configs_attention_type = [ + #"input_layernorm, attention_type, fuse_qkv_params" + (True, "self", True), + (False, "self", True), + (True, "self", False), + (False, "self", False), + # disabled because query_bias (reqd for cross attention) is defined when fuse_qkv_params is False + # (True, "cross", True), + # (False, "cross", True), + (True, "cross", False), + # disabled because TypeError: cannot assign 'transformer_engine.pytorch.module.Linear' + # as parameter 'query' (torch.nn.Parameter or None expected) + # (False, "cross", False), +] +@pytest.mark.parametrize("use_fp8", [False, True]) +@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16]) +@pytest.mark.parametrize("return_layernorm_output", [False]) +@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type) +@pytest.mark.parametrize("scale_factor_qkv", [[448, 448]]) +@pytest.mark.parametrize("scale_factor_query", [[112, 112]]) +@pytest.mark.parametrize("scale_factor_kv", [[224, 224]]) +@pytest.mark.parametrize("scale_factor_proj", [[448, 448]]) +def test_export_multihead_attention( + use_fp8: bool, + use_mask: bool, + attn_mask_type: str, + precision: torch.dtype, + return_layernorm_output: bool, + input_layernorm: bool, + 
attention_type: str, + fuse_qkv_params: bool, + scale_factor_qkv: List[float], + scale_factor_query: List[float], + scale_factor_kv: List[float], + scale_factor_proj: List[float], +): + hidden_size = 256 + sequence_length = 128 + batch_size = 4 + num_attention_heads = 32 + kv_channels = 8 + attention_dropout = 0.1 + layernorm_epsilon = 1e-5 + init_method = output_layer_init_method = get_default_init_method() + attention_args = ( + hidden_size, + num_attention_heads, + kv_channels, + attention_dropout, + layernorm_epsilon, + init_method, + output_layer_init_method, + ) + hidden_states = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") + + attention_mask = None + if use_mask and attn_mask_type != "causal": + # Generate a random mask with 50% probability for 0 or 1. + probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) + attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) + + encoder_output = None + if attention_type == "cross": + encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") + inp = (hidden_states, attention_mask, encoder_output) + input_names = ["hidden_states", "attention_mask", "encoder_output"] + + fp8_str = "_fp8" if use_fp8 else "" + dtype_str = dtype2str(precision) + attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention" + fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else "" + attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) + input_ln_str = "_input-ln" if input_layernorm else "" + fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx" + + with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()): + model = te.transformer.MultiHeadAttention( + *attention_args, + attn_mask_type=attn_mask_type, + params_dtype=precision, + return_layernorm_output=return_layernorm_output, + input_layernorm=input_layernorm, + attention_type=attention_type, + fuse_qkv_params=fuse_qkv_params, + ).to(device='cuda') + if use_fp8: + set_mha_scales(model, + scale_factor_qkv, + scale_factor_query, + scale_factor_kv, + scale_factor_proj) + + do_export(model, inp, fname, use_fp8, input_names=input_names) + if not use_fp8: + validate_result(fname, inp, model, atol=1e-3) + elif precision != torch.float16: + validate_result(fname, inp, model, atol=5e-3, is_fp8=use_fp8) + +def set_transformer_layer_scales(module, + scales_self_attn: list, + scales_inter_attn: list, + scales_layernorm_mlp: list=[224, 224, 448, 448]): + # set mha scales + set_mha_scales(module.self_attention, *scales_self_attn) + if module.layer_type == "decoder": + set_mha_scales(module.inter_attention, *scales_inter_attn) + # set layernorm mlp scales + set_layer_scale(module.layernorm_mlp, scales_layernorm_mlp) + +@pytest.mark.parametrize("use_fp8", [False, True]) +@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) +@pytest.mark.parametrize("output_layernorm", [ + #True, # TO DO: handle this + False +]) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16]) +@pytest.mark.parametrize("fuse_qkv_params", [False, True]) +@pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False]) +@pytest.mark.parametrize("scale_factor_qkv", [[448, 448]]) +@pytest.mark.parametrize("scale_factor_query", [[112, 112]]) +@pytest.mark.parametrize("scale_factor_kv", [[224, 224]]) +@pytest.mark.parametrize("scale_factor_proj", [[448, 
448]]) +@pytest.mark.parametrize("scale_factor_layernorm_mlp", [[224, 224, 448, 448]]) +def test_export_transformer_layer( + use_fp8: bool, + use_mask: bool, + attn_mask_type: str, + output_layernorm: bool, + precision: torch.dtype, + fuse_qkv_params: bool, + apply_query_key_layer_scaling: bool, + scale_factor_qkv: List[float], + scale_factor_query: List[float], + scale_factor_kv: List[float], + scale_factor_proj: List[float], + scale_factor_layernorm_mlp: List[float], +): + # Layer configuration + hidden_size = 64 + sequence_length = 128 + batch_size = 1 + ffn_hidden_size = 256 + num_attention_heads = 4 + + input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda") + input_names = ["input"] + attention_mask = None + if use_mask and attn_mask_type != "causal": + # Generate a random mask with 50% probability for 0 or 1. + probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision) + attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) + input_names.append("attention_mask") + inp = (input_tensor, attention_mask) + + fp8_str = "_fp8" if use_fp8 else "" + fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" + qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else "" + high_prec_str = dtype2str(precision) + attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) + fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{qk_scaling_str}{high_prec_str}.onnx" + + with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()): + model = te.TransformerLayer( + hidden_size, + ffn_hidden_size, + num_attention_heads, + self_attn_mask_type=attn_mask_type, + output_layernorm=output_layernorm, + params_dtype=precision, + fuse_qkv_params=fuse_qkv_params, + apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda') + if use_fp8: + mha_scales = [ + scale_factor_qkv, + scale_factor_query, + scale_factor_kv, + scale_factor_proj + ] + set_transformer_layer_scales(model, + scales_self_attn=mha_scales, + scales_inter_attn=mha_scales, + scales_layernorm_mlp=scale_factor_layernorm_mlp) + + do_export(model, inp, fname, use_fp8) + if not use_fp8: + validate_result(fname, inp, model, atol=1e-3) + elif precision != torch.float16: + validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 1c5ddd5c09..b941896d49 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -10,3 +10,5 @@ from .transformer import TransformerLayer from .fp8 import fp8_autocast from .distributed import checkpoint +# Register custom op symbolic ONNX functions +from .te_onnx_extensions import * diff --git a/transformer_engine/pytorch/cpp_extensions.py b/transformer_engine/pytorch/cpp_extensions.py index 0db38d25ad..babd1ec7f7 100644 --- a/transformer_engine/pytorch/cpp_extensions.py +++ b/transformer_engine/pytorch/cpp_extensions.py @@ -12,9 +12,11 @@ def fp8_gemm( A: torch.Tensor, A_scale_inv: torch.Tensor, + A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], A_dtype: tex.DType, B: torch.Tensor, B_scale_inv: torch.Tensor, + B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], B_dtype: tex.DType, out_dtype: torch.dtype, workspace: torch.Tensor, @@ -41,19 +43,21 @@ def fp8_gemm( out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype] - tex.te_gemm( + _ = torch.ops.tex_ts.te_gemm_ts( A, A_scale_inv, + 
A_fp8_tensor, A_dtype, True, # transa B, B_scale_inv, + B_fp8_tensor, B_dtype, False, # transb out, out_dtype, bias if use_bias else empty_tensor, - empty_tensor, + empty_tensor, # this is pre_gelu_out False, # grad workspace, workspace.shape[0], @@ -87,6 +91,7 @@ def gemm( transa = layout[0] == "T" transb = layout[1] == "T" empty_tensor = torch.Tensor() + fp8_index = -1 # dummy index input_dtype = TE_DType[dtype] output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype @@ -115,13 +120,15 @@ def gemm( bias = bias if use_bias else empty_tensor - tex.te_gemm( + _ = torch.ops.tex_ts.te_gemm_ts( A, empty_tensor, + fp8_index, input_dtype, transa, B, empty_tensor, + fp8_index, input_dtype, transb, out, @@ -214,11 +221,12 @@ def fp8_gelu( otype: tex.DType, ) -> torch.Tensor: """GeLU with FP8 output""" - return tex.fp8_gelu( + return torch.ops.tex_ts.fp8_gelu_ts( inp, - fp8_meta_tensor.scale[fp8_tensor], - fp8_meta_tensor.amax_history[0][fp8_tensor], - fp8_meta_tensor.scale_inv[fp8_tensor], + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + fp8_tensor, otype, ) @@ -245,6 +253,48 @@ def layernorm_fwd_fp8( ) +def layernorm_fwd_fp8_inf( + inp: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + fp8_meta_tensor: tex.FP8TensorMeta, + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + otype: tex.DType, +) -> torch.Tensor: + """LayerNorm with FP8 output. + + This version of layernorm_fwd_fp8 is specialized for inference, and returns + only the normalized output. + """ + ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts( + inp, + weight, + bias, + eps, + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + fp8_tensor, + otype) + return ret + + +def layernorm_fwd_inf( + inp: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, +) -> torch.Tensor: + """LayerNorm with FP8 output""" + return torch.ops.tex_ts.layernorm_fwd_inf_ts( + inp, + weight, + bias, + eps, + ) + + def cast_to_fp8( inp: torch.Tensor, fp8_meta_tensor: tex.FP8TensorMeta, @@ -252,11 +302,12 @@ def cast_to_fp8( otype: tex.DType, ) -> torch.Tensor: """Cast input to FP8""" - return tex.cast_to_fp8( + return torch.ops.tex_ts.cast_to_fp8_ts( inp, - fp8_meta_tensor.scale[fp8_tensor], - fp8_meta_tensor.amax_history[0][fp8_tensor], - fp8_meta_tensor.scale_inv[fp8_tensor], + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + fp8_tensor, otype, ) @@ -269,9 +320,10 @@ def cast_from_fp8( otype: tex.DType, ) -> torch.Tensor: """Cast input from FP8""" - return tex.cast_from_fp8( + return torch.ops.tex_ts.cast_from_fp8_ts( inp, - fp8_meta_tensor.scale_inv[fp8_tensor], + fp8_meta_tensor.scale_inv, + fp8_tensor, itype, otype, ) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 43389294dd..a4e8fb8a7a 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -94,6 +94,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { return transformer_engine::DType::kFloat32; case at::kBFloat16: return transformer_engine::DType::kBFloat16; + case at::kBool: + return transformer_engine::DType::kByte; default: NVTE_ERROR("Invalid type"); } diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu index 4dc18d3fae..ebc5c6d7fa 100644 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ b/transformer_engine/pytorch/csrc/extensions.cu @@ 
-397,6 +397,23 @@ std::vector layernorm_fwd_fp8(const at::Tensor &input, } +at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +) { + // This is a specialized version of layernorm_fwd_fp8, optimized for inference, + // which only returns the normalized output. + std::vector out = layernorm_fwd_fp8( + input, weight, bias, eps, scale, amax, scale_inv, otype); + return out[0]; +} + + std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, @@ -428,6 +445,16 @@ std::vector layernorm_fwd(const at::Tensor &input, return {ln_out, mu, rsigma}; } +at::Tensor layernorm_fwd_inf(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps +) { + // This is a specialized version of layernorm_fwd, optimized for inference, + // which only returns the normalized output. + std::vector out = layernorm_fwd(input, weight, bias, eps); + return out[0]; +} at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e2717203d5..434eacb8eb 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -95,6 +95,15 @@ std::vector layernorm_fwd_fp8(const at::Tensor &input, transformer_engine::DType otype ); +at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype +); std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, @@ -102,6 +111,11 @@ std::vector layernorm_fwd(const at::Tensor &input, float eps ); +at::Tensor layernorm_fwd_inf(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + float eps +); at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp new file mode 100755 index 0000000000..b4e6b6900a --- /dev/null +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -0,0 +1,168 @@ +/************************************************************************* + * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include "extensions.h" + +namespace { + transformer_engine::DType reverse_map_dtype(int64_t dtype) { + if (dtype >= 0 && dtype < static_cast(transformer_engine::DType::kNumTypes)) { + return static_cast(dtype); + } else { + NVTE_ERROR("Type not supported."); + } + } +} //namespace + + +at::Tensor cast_to_fp8_ts(const at::Tensor &input, + const at::Tensor &scale, + const at::Tensor &amax, + const at::Tensor &scale_inv, + int64_t fp8_tensor, + int64_t otype) { + transformer_engine::DType otype_arg = reverse_map_dtype(otype); + at::Tensor output = cast_to_fp8(input, + scale[fp8_tensor], + amax[0][fp8_tensor], + scale_inv[fp8_tensor], + otype_arg); + return output.clone(); +} + +at::Tensor cast_from_fp8_ts(const at::Tensor &input, + const at::Tensor &scale_inv, + int64_t fp8_tensor, + int64_t itype, + int64_t otype) { + transformer_engine::DType itype_arg = reverse_map_dtype(itype); + transformer_engine::DType otype_arg = reverse_map_dtype(otype); + at::Tensor output = cast_from_fp8(input, + scale_inv[fp8_tensor], + itype_arg, + otype_arg); + return output.clone(); +} + +at::Tensor fp8_gelu_ts(at::Tensor input, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + int64_t fp8_tensor, + int64_t otype) { + transformer_engine::DType otype_arg = reverse_map_dtype(otype); + at::Tensor output = fp8_gelu(input, + scale[fp8_tensor], + amax[0][fp8_tensor], + scale_inv[fp8_tensor], + otype_arg); + return output.clone(); +} + +at::Tensor te_gemm_ts(at::Tensor A, + at::Tensor A_scale_inverse, + int64_t A_fp8_tensor, + int64_t A_type, + int64_t transa, + at::Tensor B, + at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, + int64_t B_type, + int64_t transb, + at::Tensor D, + int64_t D_type, + at::Tensor bias, + at::Tensor pre_gelu_out, + int64_t grad, + at::Tensor workspace, + int64_t workspaceSize, + int64_t accumulate, + int64_t use_split_accumulator) { + // cast inputs to types accepted by te_gemm + transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); + bool transa_arg = static_cast(transa); + transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); + bool transb_arg = static_cast(transb); + transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); + bool grad_arg = static_cast(grad); + size_t workspaceSize_arg = static_cast(workspaceSize); + bool accumulate_arg = static_cast(accumulate); + bool use_split_accumulator_arg = static_cast(use_split_accumulator); + + at::Tensor A_scale_inverse_arg = A_scale_inverse.clone(); + if (A_scale_inverse.numel()) + A_scale_inverse_arg = A_scale_inverse[A_fp8_tensor]; + + at::Tensor B_scale_inverse_arg = B_scale_inverse.clone(); + if (B_scale_inverse.numel()) + B_scale_inverse_arg = B_scale_inverse[B_fp8_tensor]; + + te_gemm(A, + A_scale_inverse_arg, + A_type_arg, + transa_arg, + B, + B_scale_inverse_arg, + B_type_arg, + transb_arg, + D, + D_type_arg, + bias, + pre_gelu_out, + grad_arg, + workspace, + workspaceSize_arg, + accumulate_arg, + use_split_accumulator_arg); + return D; +} + +at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + double eps, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + int64_t fp8_tensor, + int64_t otype) { + transformer_engine::DType otype_arg = reverse_map_dtype(otype); + float eps_float = static_cast(eps); + + at::Tensor output = layernorm_fwd_fp8_inf(input, + weight, + bias, + eps_float, + scale, + amax, + scale_inv, + otype_arg); 
+ + return output.clone(); +} + +at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + double eps) { + float eps_float = static_cast(eps); + + at::Tensor output = layernorm_fwd_inf(input, + weight, + bias, + eps_float); + + return output.clone(); +} + +TORCH_LIBRARY(tex_ts, m) { + m.def("cast_to_fp8_ts", &cast_to_fp8_ts); + m.def("cast_from_fp8_ts", &cast_from_fp8_ts); + m.def("fp8_gelu_ts", &fp8_gelu_ts); + m.def("te_gemm_ts", &te_gemm_ts); + m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); + m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); +} diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index dfbfadbb36..96b89bad11 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -4,11 +4,12 @@ """Top level Transformer Engine PyTorch modules""" import os +import pickle import warnings from abc import ABC, abstractmethod -from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping +from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping from functools import partial - +import numpy as np import torch from torch.nn.parameter import Parameter from torch.nn import init @@ -69,6 +70,8 @@ fp8_gelu, fp8_cast_transpose_bgrad_dgelu_fused, layernorm_fwd_fp8, + layernorm_fwd_fp8_inf, + layernorm_fwd_inf, cast_to_fp8, cast_from_fp8, ) @@ -144,8 +147,9 @@ def init_fp8_meta_tensors(self) -> None: self.set_meta_tensor(True) self.set_meta_tensor(False) - def get_extra_state(self) -> Union[List[Any], None]: + def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" + state = None if self.fp8: state = {} state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale @@ -162,10 +166,12 @@ def get_extra_state(self) -> Union[List[Any], None]: extra[k] = v state["extra_fp8_variables"] = extra - return state - return None + state_serialized = pickle.dumps(state) + state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8)) + + return state_tensor - def set_extra_state(self, state: Union[List[Any], None]) -> None: + def set_extra_state(self, state: torch.Tensor) -> None: """Load previous state.""" if state is None: return @@ -204,6 +210,11 @@ def set_extra_state(self, state: Union[List[Any], None]) -> None: self.fp8_meta["autocast_id_bwd"] = state[9] return + if isinstance(state, torch.Tensor): + state = pickle.loads(state.detach().numpy().tobytes()) + if state is None: + return + # Restore global FP8 buffer states. set_global_fp8_buffer(state["global_fp8_buffer"]) set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"]) @@ -521,13 +532,13 @@ def grad_output_preprocess( fp8_dtype_backward, ) else: + grad_output_t = None grad_output_c = cast_to_fp8( grad_output_mat, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ) - grad_output_t = None grad_bias = None return grad_output_mat, grad_output_c, grad_output_t, grad_bias @@ -537,6 +548,7 @@ def forward(self): """Needs override.""" + class _LayerNormLinear(torch.autograd.Function): """LayerNormLinear semi-top level module Calls custom cuda extensions. 
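
# The get_extra_state()/set_extra_state() change above serializes the FP8 metadata dict
# with pickle and wraps the bytes in a uint8 torch.Tensor, so the extra state is itself a
# tensor; set_extra_state() reverses the conversion before restoring the buffers. Minimal
# round-trip sketch (standalone and illustrative only; `state` and its inner keys stand in
# for the fp8_meta dictionary built above):
import pickle
import numpy as np
import torch

state = {"scale_fwd": torch.ones(2), "extra_fp8_variables": {"example": 1}}
state_tensor = torch.tensor(np.frombuffer(pickle.dumps(state), dtype=np.uint8))  # get_extra_state path
restored = pickle.loads(state_tensor.detach().numpy().tobytes())                 # set_extra_state path
assert restored["extra_fp8_variables"]["example"] == 1
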
@@ -564,6 +576,7 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], return_layernorm_output: bool, + is_training: bool ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -584,19 +597,37 @@ def forward( fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not return_layernorm_output: - ln_out, mu, rsigma = layernorm_fwd_fp8( - inputmat, - ln_weight, - ln_bias, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if is_training: + ln_out, mu, rsigma = layernorm_fwd_fp8( + inputmat, + ln_weight, + ln_bias, + eps, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) + else: + mu = rsigma = None + ln_out = layernorm_fwd_fp8_inf( + inputmat, + ln_weight, + ln_bias, + eps, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) else: - ln_out_return, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps - ) + if is_training: + ln_out_return, mu, rsigma = tex.layernorm_fwd( + inputmat, ln_weight, ln_bias, eps + ) + else: + ln_out_return, mu, rsigma = layernorm_fwd_inf( + inputmat, ln_weight, ln_bias, eps + ), None, None + ln_out = cast_to_fp8( ln_out_return, fp8_meta["scaling_fwd"], @@ -604,7 +635,12 @@ def forward( fp8_dtype_forward, ) else: - ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps) + if is_training: + ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps) + else: + ln_out, mu, rsigma = layernorm_fwd_inf( + inputmat, ln_weight, ln_bias, eps + ), None, None ln_out_return = ln_out # Column Parallel Linear @@ -622,21 +658,31 @@ def forward( bias = cast_if_needed(bias, bias_dtype) if use_bias else bias if update_fp8_weights: - fp8_cast_transpose_fused( - weight, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, - ) + if is_training: + fp8_cast_transpose_fused( + weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward, + cast_out=weight_fp8, + transpose_out=weight_t_fp8, + ) + else: + weight_t_fp8 = None + weight_fp8 = cast_to_fp8( + weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward) out = fp8_gemm( weight_fp8, - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ln_out_total, - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, activation_dtype, get_workspace(), @@ -658,29 +704,30 @@ def forward( use_bias=use_bias, ) - ctx.save_for_backward( - inputmat, - ln_weight, - mu, - rsigma, - weight, - weight_t_fp8, - ln_out, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, - ) + if is_training: + ctx.save_for_backward( + inputmat, + ln_weight, + mu, + rsigma, + weight, + weight_t_fp8, + ln_out, + fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + ) - ctx.activation_dtype = activation_dtype - ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = use_bias - ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tp_group = 
tp_group - ctx.return_layernorm_output = return_layernorm_output + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fp8_meta = fp8_meta + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.is_first_microbatch = is_first_microbatch + ctx.use_bias = use_bias + ctx.sequence_parallel = sequence_parallel + ctx.tensor_parallel = tensor_parallel + ctx.inp_shape = inp.shape + ctx.parallel_mode = parallel_mode + ctx.tp_group = tp_group + ctx.return_layernorm_output = return_layernorm_output # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -695,6 +742,7 @@ def forward( return out, ln_out_return.view_as(inp) return out + @staticmethod def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] @@ -748,10 +796,12 @@ def backward( # DGRAD: Evaluated unconditionally to feed into Linear backward dgrad = fp8_gemm( weight_t_fp8, - fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT], + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1], + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), @@ -784,12 +834,12 @@ def backward( ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) wgrad = fp8_gemm( ln_out_total_t, - fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ], + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), @@ -874,6 +924,7 @@ def backward( None, None, None, + None, ) @@ -1035,7 +1086,7 @@ def __init__( if self.parallel_mode == "column": set_tensor_model_parallel_attributes(self.bias, True, 0, 1) else: - self.register_buffer("bias", torch.Tensor(), persistent=False) + self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False) with torch.no_grad(): self.bias.zero_() @@ -1094,8 +1145,13 @@ def forward( inp = self.pre_forward(inp) bias_tensor = bias if bias is not None else self.bias - - out = _LayerNormLinear.apply( + if self.training: + fwd_fn = _LayerNormLinear.apply + args = [] + else: + fwd_fn = _LayerNormLinear.forward + args = [None] + args += ( inp, self.layer_norm_weight, self.layer_norm_bias, @@ -1115,7 +1171,9 @@ def forward( self.activation_dtype, self.parallel_mode, self.return_layernorm_output, + self.training, ) + out = fwd_fn(*args) self.post_forward() @@ -1133,7 +1191,6 @@ def forward( return out, ln_out return out - class _Linear(torch.autograd.Function): """Linear semi-top level module Calls custom cuda extensions. 
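
# Dispatch pattern introduced above in LayerNormLinear.forward (and reused for _Linear and
# LayerNormMLP below): in eval mode the autograd.Function's forward() is called directly
# rather than through .apply(), with a None placeholder prepended for the ctx argument
# (forward is a staticmethod whose first parameter is the autograd context). Together with
# the new is_training flag this selects the inference-only kernels and skips
# ctx.save_for_backward(). Minimal sketch with hypothetical names (_run_fn, func_cls):
def _run_fn(module, func_cls, *fwd_args):
    if module.training:
        return func_cls.apply(*fwd_args)        # autograd path; PyTorch supplies ctx
    return func_cls.forward(None, *fwd_args)    # inference path; ctx is unused
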
@@ -1157,6 +1214,7 @@ def forward( tensor_parallel: bool, activation_dtype: torch.dtype, parallel_mode: Union[str, None], + is_training: bool, ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -1173,19 +1231,27 @@ def forward( fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not fp8_meta["recipe"].override_linear_precision.wgrad: - inputmat, inputmat_t = fp8_cast_transpose_fused( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if is_training: + inputmat, inputmat_t = fp8_cast_transpose_fused( + inputmat, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) + else: + inputmat = cast_to_fp8( + inputmat, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) else: - inputmat = cast_to_fp8( + inputmat, inputmat_t = cast_to_fp8( inputmat, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, - ) + ), None # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: @@ -1202,21 +1268,32 @@ def forward( bias = cast_if_needed(bias, bias_dtype) if use_bias else bias if update_fp8_weights: - fp8_cast_transpose_fused( - weight, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, - ) + if is_training: + fp8_cast_transpose_fused( + weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward, + cast_out=weight_fp8, + transpose_out=weight_t_fp8, + ) + else: + weight_t_fp8 = None + weight_fp8 = cast_to_fp8( + weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward, + ) out = fp8_gemm( weight_fp8, - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, inputmat, - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, activation_dtype, get_workspace(), @@ -1238,28 +1315,29 @@ def forward( use_bias=use_bias, ) - ctx.save_for_backward( - inputmat_no_fp8 - if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad - else None, - inputmat_t - if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad - else None, - weight, - weight_t_fp8, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, - ) - ctx.activation_dtype = activation_dtype - ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = use_bias - ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tp_group = tp_group + if is_training: + ctx.save_for_backward( + inputmat_no_fp8 + if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad + else None, + inputmat_t + if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad + else None, + weight, + weight_t_fp8, + fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + ) + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fp8_meta = fp8_meta + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.is_first_microbatch = is_first_microbatch + ctx.use_bias = use_bias + ctx.sequence_parallel = sequence_parallel + ctx.tensor_parallel = tensor_parallel + ctx.inp_shape = inp.shape + 
ctx.parallel_mode = parallel_mode + ctx.tp_group = tp_group # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -1270,6 +1348,7 @@ def forward( # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) + @staticmethod def backward( ctx, grad_output: torch.Tensor @@ -1326,10 +1405,12 @@ def backward( # DGRAD dgrad = fp8_gemm( weight_t_fp8, - fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT], + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1], + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), @@ -1361,12 +1442,12 @@ def backward( if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: wgrad = fp8_gemm( inputmat_t_total, - fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ], + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), @@ -1429,6 +1510,7 @@ def backward( None, None, None, + None, ) @@ -1566,7 +1648,7 @@ def __init__( if self.parallel_mode == "column": set_tensor_model_parallel_attributes(self.bias, True, 0, 1) else: - self.register_buffer("bias", torch.Tensor(), persistent=False) + self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False) with torch.no_grad(): self.bias.zero_() @@ -1620,8 +1702,13 @@ def forward( inp = self.pre_forward(inp) bias_tensor = bias if bias is not None else self.bias - - out = _Linear.apply( + if self.training: + linear_fn = _Linear.apply + args = [] + else: + linear_fn = _Linear.forward + args = [None] + args += ( weight if weight is not None else self.weight, self.weight1_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None, @@ -1637,8 +1724,9 @@ def forward( self.tp_size > 1, self.activation_dtype, self.parallel_mode, + self.training, ) - + out = linear_fn(*args) self.post_forward() if self.gemm_bias_unfused_add: @@ -1681,6 +1769,7 @@ def forward( return_layernorm_output: bool, bias_gelu_nvfusion: bool, set_parallel_mode: bool, + is_training: bool ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -1700,15 +1789,26 @@ def forward( if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not return_layernorm_output: - ln_out, mu, rsigma = layernorm_fwd_fp8( - inputmat, - ln_weight, - ln_bias, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if is_training: + ln_out, mu, rsigma = layernorm_fwd_fp8( + inputmat, + ln_weight, + ln_bias, + eps, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) + else: + ln_out = layernorm_fwd_fp8_inf( + inputmat, + ln_weight, + ln_bias, + eps, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) else: ln_out_return, mu, rsigma = tex.layernorm_fwd( inputmat, ln_weight, ln_bias, eps @@ -1720,9 +1820,14 @@ def forward( fp8_dtype_forward, ) else: - ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps) - ln_out_return = ln_out + if is_training: + ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps) + 
else: + ln_out, mu, rsigma = layernorm_fwd_inf( + inputmat, ln_weight, ln_bias, eps + ), None, None + ln_out_return = ln_out # Column Parallel Linear if set_parallel_mode and sequence_parallel: ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -1739,30 +1844,48 @@ def forward( fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias if update_fp8_weights: - fp8_cast_transpose_fused( - fc1_weight, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, - cast_out=fc1_weight_fp8, - transpose_out=fc1_weight_t_fp8, - ) + if is_training: + fp8_cast_transpose_fused( + fc1_weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward, + cast_out=fc1_weight_fp8, + transpose_out=fc1_weight_t_fp8, + ) - fp8_cast_transpose_fused( - fc2_weight, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM2_WEIGHT, - fp8_dtype_forward, - cast_out=fc2_weight_fp8, - transpose_out=fc2_weight_t_fp8, - ) + fp8_cast_transpose_fused( + fc2_weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM2_WEIGHT, + fp8_dtype_forward, + cast_out=fc2_weight_fp8, + transpose_out=fc2_weight_t_fp8, + ) + else: + fc1_weight_t_fp8 = None + fc1_weight_fp8 = cast_to_fp8( + fc1_weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_dtype_forward, + ) + fc2_weight_t_fp8 = None + fc2_weight_fp8 = cast_to_fp8( + fc2_weight, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM2_WEIGHT, + fp8_dtype_forward, + ) fc1_out = fp8_gemm( fc1_weight_fp8, - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ln_out_total, - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, activation_dtype, get_workspace(), @@ -1780,10 +1903,12 @@ def forward( fc2_out = fp8_gemm( fc2_weight_fp8, - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, gelu_out, - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM2_INPUT], + fp8_meta["scaling_fwd"].scale_inv, + tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype_forward, activation_dtype, get_workspace(), @@ -1810,7 +1935,7 @@ def forward( gelu=not bias_gelu_nvfusion, ) - if bias_gelu_nvfusion: + if bias_gelu_nvfusion and is_training: fc1_out, _, _ = fc1_outputs gelu_out = bias_gelu_fused(fc1_out, fc1_bias) else: @@ -1824,35 +1949,35 @@ def forward( bias=fc2_bias, use_bias=use_bias, ) - - ctx.save_for_backward( - inputmat, - ln_weight, - mu, - rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight, - fc1_weight_t_fp8, - fc2_weight, - fc2_weight_t_fp8, - fc1_bias, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, - ) - ctx.activation_dtype = activation_dtype - ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = use_bias - ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel - ctx.inp_shape = inp.shape - ctx.tp_group = tp_group - ctx.bias_gelu_nvfusion = bias_gelu_nvfusion - ctx.return_layernorm_output = return_layernorm_output - ctx.set_parallel_mode = set_parallel_mode + if is_training: + ctx.save_for_backward( + inputmat, + ln_weight, + mu, + rsigma, + ln_out, + fc1_out, + gelu_out, + fc1_weight, + fc1_weight_t_fp8, + fc2_weight, + fc2_weight_t_fp8, + fc1_bias, + 
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + ) + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fp8_meta = fp8_meta + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.is_first_microbatch = is_first_microbatch + ctx.use_bias = use_bias + ctx.sequence_parallel = sequence_parallel + ctx.tensor_parallel = tensor_parallel + ctx.inp_shape = inp.shape + ctx.tp_group = tp_group + ctx.bias_gelu_nvfusion = bias_gelu_nvfusion + ctx.return_layernorm_output = return_layernorm_output + ctx.set_parallel_mode = set_parallel_mode # Row Parallel Linear if set_parallel_mode and sequence_parallel: @@ -1867,6 +1992,7 @@ def forward( return fc2_out, ln_out_return.view_as(inp) return fc2_out + @staticmethod def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] @@ -1925,10 +2051,12 @@ def backward( # FC2 DGRAD; Unconditional fc2_dgrad = fp8_gemm( fc2_weight_t_fp8, - fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT], + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1], + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), @@ -1941,12 +2069,12 @@ def backward( gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) fc2_wgrad = fp8_gemm( gelu_out_t, - fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT], + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype_forward, grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ], + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), @@ -2004,10 +2132,12 @@ def backward( # FC1 DGRAD: Unconditional fc1_dgrad = fp8_gemm( fc1_weight_t_fp8, - fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT], + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, dgelu, - ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT2], + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), @@ -2072,12 +2202,12 @@ def backward( ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) fc1_wgrad = fp8_gemm( ln_out_total_t, - fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], + fwd_scale_inverses, + tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, dgelu_t, - ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT2 - ], + ctx.fp8_meta["scaling_bwd"].scale_inv, + tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), @@ -2176,6 +2306,7 @@ def backward( None, None, None, + None, ) @@ -2370,7 +2501,7 @@ def __init__( ) ) else: - self.register_buffer("fc2_bias", torch.Tensor(), persistent=False) + self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False) # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -2422,7 +2553,13 @@ def forward( inp = self.pre_forward(inp, num_gemms=2) - out = _LayerNormMLP.apply( + if self.training: + fwd_fn = _LayerNormMLP.apply + args = [] + else: + fwd_fn = _LayerNormMLP.forward + args = [None] + args += ( inp, self.layer_norm_weight, self.layer_norm_bias, @@ -2447,7 +2584,9 @@ def forward( self.return_layernorm_output, self.bias_gelu_nvfusion, self.set_parallel_mode, + self.training, ) + out = fwd_fn(*args) self.post_forward() diff --git a/transformer_engine/pytorch/softmax.py 
b/transformer_engine/pytorch/softmax.py
index 8a34615d7f..2e9cf61388 100644
--- a/transformer_engine/pytorch/softmax.py
+++ b/transformer_engine/pytorch/softmax.py
@@ -5,9 +5,10 @@ """Fused scaled masked softmax functions"""
 import os
 from typing import Callable, Tuple, Union
-
 import torch
 from torch import nn
+import torch._C._onnx as _C_onnx
+from torch.onnx import _type_utils
 import transformer_engine_extensions as tex
@@ -46,6 +47,36 @@ def backward(
         return input_grads, None
 
+    @staticmethod
+    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
+        """ScaledUpperTriangMaskedSoftmax symbolic method"""
+        def triangular_mask():
+            dtype = _type_utils.JitScalarType.INT64
+            ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
+            k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
+            mask = g.op("Trilu", ones, k, upper_i=1)
+            mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
+            return mask
+
+        # Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward
+        if inputs.type().scalarType() == "BFloat16":
+            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        mask = triangular_mask()
+        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
+        inv_mask = g.op("Sub", one, mask)
+
+        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
+        softmax_mask = g.op("Mul", mask, neg_tenK)
+
+        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
+        scaled = g.op("Mul", inputs, scale_input)
+        masked_scaled = g.op("Mul", inv_mask, scaled)
+        masked = g.op("Add", masked_scaled, softmax_mask)
+        out = g.op("Softmax", masked)
+        if inputs.type().scalarType() == "BFloat16":
+            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        return out
+
 
 class ScaledMaskedSoftmax(torch.autograd.Function):
     """
@@ -78,6 +109,35 @@ def backward(
         )
         return input_grads, None, None
 
+    @staticmethod
+    def symbolic(
+        g: torch.Graph,
+        inputs: torch._C.Value,
+        mask: torch._C.Value,
+        scale: float) -> torch._C.Value:
+        """ScaledMaskedSoftmax symbolic method"""
+        # Captures the logic of function scaled_masked_softmax_warp_forward.
+        # output = softmax(mask(input*scale))
+        # Computed as:
+        # masked_scaled = (1 - mask)*(input*scale)
+        # softmax_mask = mask * -10000
+        # output = softmax(masked_scaled + softmax_mask)
+        if inputs.type().scalarType() == "BFloat16":
+            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
+        scaled = g.op("Mul", inputs, scale_input)
+        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
+        inv_mask = g.op("Sub", one, mask)
+        # Note: type is hard coded because softmax uses FP16 or BF16
+        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
+        softmax_mask = g.op("Mul", mask, neg_tenK)
+        masked_scaled = g.op("Mul", inv_mask, scaled)
+        masked = g.op("Add", masked_scaled, softmax_mask)
+        out = g.op("Softmax", masked)
+        if inputs.type().scalarType() == "BFloat16":
+            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        return out
+
 
 class ScaledSoftmax(torch.autograd.Function):
     """
@@ -107,6 +167,19 @@ def backward(
         )
         return input_grads, None, None
 
+    @staticmethod
+    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
+        """ScaledSoftmax symbolic method"""
+        if inputs.type().scalarType() == "BFloat16":
+            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
+        scaled = g.op("Mul", inputs, scale_input)
+        out = g.op("Softmax", scaled)
+        if inputs.type().scalarType() == "BFloat16":
+            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        return out
+
+
 class FusedScaleMaskSoftmax(nn.Module):
     """
@@ -163,7 +236,7 @@ def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
             if 0 <= sk <= 4096:
-                batch_per_block = self.get_batch_per_block(sk)
+                batch_per_block = self.get_batch_per_block(int(sk))
 
                 if self.attn_mask_type == "causal":
                     if attn_batches % batch_per_block == 0:
diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py
new file mode 100755
index 0000000000..765e30e7fc
--- /dev/null
+++ b/transformer_engine/pytorch/te_onnx_extensions.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""
+ONNX symbolic functions for Transformer Engine
+
+Warnings of the type pasted below are a known PyTorch issue
+(https://github.com/pytorch/pytorch/issues/81693):
+
+tests/test_onnx_export.py::test_export_cast_ops[112]
+  /opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py:649:
+  UserWarning: The shape inference of trt::TRT_FP8DequantizeLinear type is missing,
+  so it may result in wrong shape inference for the exported graph.
+  Please consider adding it in symbolic function. (Triggered internally at
+  /opt/pytorch/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1880.)
+  _C._jit_pass_onnx_graph_shape_type_inference(
+
+
+Scale tensors are treated as lists ("fs") instead of tensors ("v") because we need to access
+specific entries using the index passed as `fp8_tensor`. If you fail to do this you will get
+the following error when accessing a specific scale element (e.g. `scale_inv[fp8_tensor]`):
+    TypeError: 'torch._C.Value' object is not subscriptable
+"""
+
+import torch
+from torch.onnx import symbolic_helper, register_custom_op_symbolic
+import torch._C._onnx as _C_onnx
+import transformer_engine_extensions as tex
+
+# This file registers custom op symbolic ONNX functions and does not export any symbols.
+__all__ = []
+
+
+# Custom ops spec version
+VER = 1
+
+UNSPECIFIED_TYPE = -1
+
+
+def make_op_name(op_name: str) -> str:
+    """custom op name"""
+    return "trt::" + op_name
+
+
+def quantize(g, inputs, scale_inv, fp8_tensor):
+    """Helper Function for Quantization"""
+    output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
+
+    # Q inputs are currently constrained to FP32 due to a similar limitation in ORT
+    # custom ops, so cast the input if needed.
+    if inputs.type().scalarType() == "Half":
+        inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+
+    scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
+    q_op = g.op(
+        make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType(
+            inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
+    return q_op
+
+
+def dequantize(g, inputs, scale_inv, fp8_tensor, otype):
+    """Helper Function for Dequantization"""
+    output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
+
+    scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
+    out = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale).setType(
+        inputs.type().with_dtype(torch.float32).with_sizes(output_shape))
+
+    # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
+    # custom ops, so cast the output if needed.
+    if otype == int(tex.DType.kFloat16):
+        out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+    return out
+
+
+@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
+def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for cast_to_fp8"""
+    # pylint: disable=unused-argument
+    return quantize(g, inputs, scale_inv, fp8_tensor)
+
+
+@symbolic_helper.parse_args("v", "fs", "i", "i", "i")
+def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
+    """ONNX graph for cast_from_fp8"""
+    # pylint: disable=unused-argument
+    return dequantize(g, inputs, scale_inv, fp8_tensor, otype)
+
+
+@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
+def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for fp8_gelu"""
+    # pylint: disable=unused-argument
+    gelu = torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
+    out = quantize(g, gelu, scale_inv, fp8_tensor)
+    return out
+
+
+@symbolic_helper.parse_args("v", "fs", "i", "i", "i",
+                            "v", "fs", "i", "i", "i",
+                            "v", "i", "v", "v", "i",
+                            "v", "i", "i", "i")
+def onnx_te_gemm(
+    g,
+    weight,
+    weight_scale_inverse,
+    weight_fp8_tensor,
+    weight_type,
+    trans_weight,
+    inputs,
+    input_scale_inverse,
+    input_fp8_tensor,
+    input_type,
+    trans_input,
+    out,
+    out_type,
+    bias,
+    pre_gelu_out,
+    grad,
+    workspace,
+    workspaceSize,
+    accumulate,
+    use_split_accumulator):
+    """ONNX graph for te_gemm"""
+    # pylint: disable=unused-argument
+    is_fp16 = bias.type().scalarType() == "Half"
+    if input_type == int(tex.DType.kFloat8E4M3):
+        inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, UNSPECIFIED_TYPE)
+
+    if weight_type == int(tex.DType.kFloat8E4M3):
+        weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, UNSPECIFIED_TYPE)
+
+    output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)
+
+    empty_tensor_size = [0]
+    bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size
+    pre_gelu_out_empty = torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) \
+        == empty_tensor_size
+    if not bias_empty:
+        if pre_gelu_out_empty:
+            if is_fp16:
+                output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+            output = g.op('Add', output, bias)
+        else:
+            if is_fp16:
+                output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+            output = g.op('Add', output, bias)
+            output = torch.onnx.symbolic_opset9.gelu(g, output)
+    else:
+        if is_fp16:
+            output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+    return output
+
+
+@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i")
+def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, scale_inv, fp8_tensor, otype):
+    """ONNX graph for layernorm_fwd_fp8"""
+    # pylint: disable=unused-argument
+    ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps)
+    fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
+    return fp8_ln
+
+
+@symbolic_helper.parse_args("v", "v", "v", "f")
+def onnx_layernorm_fwd(g, inputs, weight, bias, eps):
+    """ONNX graph for layernorm_fwd"""
+    # pylint: disable=unused-argument
+    normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
+    if normalized_shape is None:
+        ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
+        assert ndim is not None
+        normalized_shape = list(range(0, ndim))
+    # Normalization axis = 0, so normalized_shape uses all dims except dim = 0
+    normalized_shape = normalized_shape[1:]
+
+    ln = torch.onnx.symbolic_opset9.layer_norm(
+        g,
+        inputs,
+        normalized_shape,
+        weight,
+        bias,
+        eps,
+        False  # cudnn_enable (not relevant)
+    )
+    return ln
+
+
+register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
+register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
+register_custom_op_symbolic('tex_ts::fp8_gelu_ts', onnx_fp8_gelu, VER)
+register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
+register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
+register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
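
Reviewer note: the symbolic functions above emit custom trt:: Q/DQ nodes, so a graph exported with them only loads in ONNX Runtime if the FP8 Q/DQ shared library is registered with the session first. Below is a minimal sketch of that consumer-side flow, assuming an already-exported model; the function name, file paths, input name lookup, and input shape are illustrative, not part of this change.

# Minimal sketch (illustrative, not part of this diff): running an exported model whose
# graph contains trt::TRT_FP8QuantizeLinear / trt::TRT_FP8DequantizeLinear nodes.
import numpy as np
import onnxruntime as ort

def run_exported_fp8_model(onnx_path, custom_ops_lib, np_input):
    """Run a single-input ONNX model after registering the FP8 Q/DQ custom-op library."""
    sess_options = ort.SessionOptions()
    # Without this, ORT cannot resolve the trt:: domain ops and session creation fails.
    sess_options.register_custom_ops_library(custom_ops_lib)
    session = ort.InferenceSession(onnx_path, sess_options, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    return session.run(None, {input_name: np_input})

# Hypothetical usage; paths and shape are placeholders:
# outs = run_exported_fp8_model("gen_onnx_models/te.linear.onnx",
#                               "libcustom_ort_fp8_qdq_ops.so",
#                               np.random.rand(64, 256).astype(np.float32))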