diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 539c609396..7ef502ad25 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -239,14 +239,14 @@ jobs: cd python LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test" if [ ! -d "${LIT_TEST_DIR}" ]; then - echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 + echo "Could not find '${LIT_TEST_DIR}'" ; exit -1 fi lit -v "${LIT_TEST_DIR}" - name: Run python tests on CUDA run: | INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation" if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then - echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 + echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi cd python/test/unit python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py @@ -268,14 +268,16 @@ jobs: language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \ runtime/test_autotuner.py::test_kwargs[False]\ ../../tutorials/06-fused-attention.py::test_op --device cpu + - name: Run regression tests + run: | + cd python/test/regression + python3 -m pytest -s -n 8 . - name: Run C++ unittests run: | cd python cd "build/$(ls build | grep -i cmake)" ctest -j32 - name: Run Proton tests - env: - LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" run: | cd third_party/proton python3 -m pytest -s test @@ -395,14 +397,14 @@ jobs: cd python LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test" if [ ! -d "${LIT_TEST_DIR}" ]; then - echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 + echo "Could not find '${LIT_TEST_DIR}'" ; exit -1 fi lit -v "${LIT_TEST_DIR}" - name: Run python tests on HIP run: | INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation" if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then - echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 + echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py cd python/test/unit @@ -416,10 +418,15 @@ jobs: # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py + - name: Run regression tests + run: | + # Reenable test_functional_regression.py once it's fixed + cd python/test/regression + python3 -m pytest -s -n 8 ./test_cast_matmul.py - name: Run Proton tests run: | cd third_party/proton - python3 -m pytest test + python3 -m pytest -s test - name: Run C++ unittests run: | cd python diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 8e80983ae9..d84ac6f334 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -272,7 +272,7 @@ jobs: cd python LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test" if [ ! -d "${LIT_TEST_DIR}" ]; then - echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 + echo "Could not find '${LIT_TEST_DIR}'" ; exit -1 fi lit -v "${LIT_TEST_DIR}" @@ -280,7 +280,7 @@ jobs: run: | INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation" if [ ! 
-d "${INSTRUMENTATION_LIB_DIR}" ]; then - echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 + echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi cd python/test/unit python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py @@ -304,6 +304,11 @@ jobs: runtime/test_autotuner.py::test_kwargs[False]\ ../../tutorials/06-fused-attention.py::test_op --device cpu + - name: Run regression tests + run: | + cd python/test/regression + python3 -m pytest -s -n 8 . + - &run-cpp-unittests-step name: Run C++ unittests run: | @@ -311,9 +316,8 @@ jobs: cd "build/$(ls build | grep -i cmake)" ctest -j32 - - name: Run Proton tests - env: - LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + - &run-proton-tests-step + name: Run Proton tests run: | cd third_party/proton python3 -m pytest -s test @@ -398,7 +402,7 @@ jobs: run: | INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation" if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then - echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 + echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py cd python/test/unit @@ -413,11 +417,13 @@ jobs: # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py - - name: Run Proton tests + - name: Run regression tests run: | - cd third_party/proton - python3 -m pytest test + # Reenable test_functional_regression.py once it's fixed + cd python/test/regression + python3 -m pytest -s -n 8 ./test_cast_matmul.py + - *run-proton-tests-step - *run-cpp-unittests-step - *save-build-artifacts-step - *inspect-cache-directories-step diff --git a/bin/triton-lsp.cpp b/bin/triton-lsp.cpp index b185b03748..f95036dc6c 100644 --- a/bin/triton-lsp.cpp +++ b/bin/triton-lsp.cpp @@ -6,6 +6,5 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; registerTritonDialects(registry); - mlir::MLIRContext context(registry); return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); } diff --git a/docs/meetups/dev_conference_2024.md b/docs/meetups/dev_conference_2024.md new file mode 100644 index 0000000000..6816b4c597 --- /dev/null +++ b/docs/meetups/dev_conference_2024.md @@ -0,0 +1,3 @@ +The conference slides are available [here](https://drive.google.com/drive/folders/1osK9hwcX_lC1EjdZGB-v4w5oKx23UnU2?usp=drive_link) + +The conference videos are available [here](https://www.youtube.com/playlist?list=PLc_vA1r0qoiTjlrINKUuFrI8Ptoopm8Vz). 
diff --git a/include/triton/Dialect/Triton/IR/Types.h b/include/triton/Dialect/Triton/IR/Types.h index 9313e26911..74fa4ba961 100644 --- a/include/triton/Dialect/Triton/IR/Types.h +++ b/include/triton/Dialect/Triton/IR/Types.h @@ -34,6 +34,8 @@ Type getI32SameShape(Type type); Type getPointerTypeSameShape(Type type); +Type getPointerTypeToElement(Type type); + } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 3ebccfc801..0ccd97970a 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -116,9 +116,20 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { RankedTensorType dstTy = op.getType(); Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); + // FIXME [Dot LL] + // Do for all DotOperandEncodingAttr once we have LLs for all of them + auto isAmpereLargeKWidth = [](Attribute layout) { + if (auto dot = dyn_cast(layout)) { + if (auto mma = dyn_cast(dot.getParent())) { + return mma.isAmpere() && dot.getKWidth() == 8; + } + } + return false; + }; if (isa(srcLayout) && - isa( - dstLayout)) { + (isa( + dstLayout) || + isAmpereLargeKWidth(dstLayout))) { return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); } @@ -170,6 +181,37 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { SmallVector outVals = loadSharedToDistributed( dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); + // FIXME [Dot LL] + // Ampere case + // In this case, we need to pack the outputs into i32 + if (isa(dstTy.getEncoding())) { + if (elemLlvmTy.isInteger(8)) { + auto concat = [&](Value a1, Value a2, Value a3, Value a4) { + return or_(or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))), + or_(shl(zext(i32_ty, a3), i32_val(16)), + shl(zext(i32_ty, a4), i32_val(24)))); + }; + SmallVector outVals32(outVals.size() / 4); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1], + outVals[4 * i + 2], outVals[4 * i + 3]); + } + outVals = outVals32; + } else { + assert(elemLlvmTy.isBF16() && "Unexpected element type"); + auto concat = [&](Value a, Value b) { + return or_(zext(i32_ty, bitcast(a, i16_ty)), + shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); + }; + + SmallVector outVals32(outVals.size() / 2); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); + } + outVals = outVals32; + } + } + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); diff --git a/lib/Dialect/Triton/IR/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp index cf7c4a0bca..6e41e70a8e 100644 --- a/lib/Dialect/Triton/IR/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -1,6 +1,7 @@ #include "triton/Dialect/Triton/IR/Types.h" #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` @@ -157,6 +158,12 @@ Type getPointerTypeSameShape(Type type) { } } +Type getPointerTypeToElement(Type type) { + Type elementType = getElementTypeOrSelf(type); + PointerType ptrType = PointerType::get(elementType, 1); + return ptrType; +} + // upstream Triton only uses address space 1 for Pointer Type Type getPointerType(Type type, int addressSpace) { return 
PointerType::get(type, addressSpace); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index b6d855a053..cee1ae84ef 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -163,85 +163,6 @@ void LayoutRematerialization::cleanup() { op->erase(); } -// Look ahead to at the transitive uses and see if there is a convert to mma -// operations. -bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { - SmallVector queue = {op->getResult(0)}; - SetVector forwardSlice; - llvm::SmallDenseSet seen; - while (!queue.empty()) { - Value currentValue = queue.back(); - queue.pop_back(); - getForwardSlice(currentValue, &forwardSlice); - for (Operation *op : forwardSlice) { - // HACK: Stop propagation if the ReduceOp is using mma layout but is - // producing tensor smaller than the layout we would like to propagate. - // This is to avoid stepping into the known bug. - if (isa(op)) { - auto tensorType = - dyn_cast(op->getOperand(0).getType()); - if (tensorType && - isa(tensorType.getEncoding())) { - auto mmaInstrShape = - cast(encoding).getInstrShape(); - if (tensorType.getShape()[tensorType.getRank() - 2] < - mmaInstrShape[0] || - tensorType.getShape()[tensorType.getRank() - 1] < - mmaInstrShape[1]) { - return false; - } - } - } - - if (auto convertOp = dyn_cast(op)) { - Attribute dstEncoding = convertOp.getType().getEncoding(); - if (auto mmaLayout = dyn_cast(dstEncoding)) - return (mmaLayout.getVersionMajor() > 1) ? true - : mmaLayout == encoding; - if (isa(dstEncoding)) - return true; - if (isa(dstEncoding)) { - if (auto mmaLayout = dyn_cast(encoding)) { - return mmaLayout.getVersionMajor() > 1; - } else { - assert((mlir::isa(encoding))); - return true; - } - } - } - bool isMMAV3 = - isa(encoding) && - cast(encoding).getVersionMajor() == 3; - if (isMMAV3 && (isa(op) || isa(op))) - return true; - auto yield = dyn_cast(op); - if (!yield) - continue; - if (auto ifOp = dyn_cast(yield->getParentOp())) { - for (OpOperand &operand : yield->getOpOperands()) { - Operation *def = operand.get().getDefiningOp(); - if (def && - (forwardSlice.count(def) || operand.get() == currentValue) && - (seen.insert(operand.get()).second == true)) - queue.push_back(ifOp.getResult(operand.getOperandNumber())); - } - } - auto forOp = dyn_cast(yield.getOperation()->getParentOp()); - if (!forOp) - continue; - for (OpOperand &operand : yield->getOpOperands()) { - Operation *def = operand.get().getDefiningOp(); - if (def && (forwardSlice.count(def) || operand.get() == currentValue) && - (seen.insert(operand.get()).second == true)) - queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber())); - } - } - } - return false; -} - // Return true if the op is an op with a layout we don't want to change. We will // propagate the layout starting from anchor ops. bool isLayoutAnchor(Operation *op) { @@ -262,18 +183,8 @@ bool isLayoutAnchor(Operation *op) { } void LayoutPropagation::initAnchorLayout() { - auto maybeAddAnchor = [&](Value v) { + auto addAnchor = [&](Value v) { if (auto tensorType = dyn_cast(v.getType())) { - // Workaround, don't popagate MMA layout unless there is a convert - // back to mma further down to avoid generating reduction with MMA - // layout that may have lower performance. - // This can be improved with more aggressive backward propagation. 
- if (isa(tensorType.getEncoding()) && - v.getDefiningOp() && - !hasConvertToMMATransisitiveUse(v.getDefiningOp(), - tensorType.getEncoding())) { - return; - } layouts.insert({v, LayoutInfo(tensorType.getEncoding())}); } }; @@ -282,13 +193,13 @@ void LayoutPropagation::initAnchorLayout() { // you can pass a tensor with an encoding as an arg, instead of explicitly // calling tt.load. for (auto arg : funcOp.getArguments()) { - maybeAddAnchor(arg); + addAnchor(arg); } funcOp.walk([&](Operation *op) { if (isLayoutAnchor(op)) { for (auto result : op->getResults()) { - maybeAddAnchor(result); + addAnchor(result); } } }); diff --git a/python/test/regression/conftest.py b/python/test/regression/conftest.py index 7a02d322b4..d88687b45f 100644 --- a/python/test/regression/conftest.py +++ b/python/test/regression/conftest.py @@ -1,12 +1,22 @@ -# content of conftest.py - +import os import pytest +import tempfile def pytest_addoption(parser): - parser.addoption("--device", action="store", default='cuda') + parser.addoption("--device", action="store", default="cuda") @pytest.fixture def device(request): return request.config.getoption("--device") + + +@pytest.fixture +def fresh_triton_cache(): + with tempfile.TemporaryDirectory() as tmpdir: + try: + os.environ["TRITON_CACHE_DIR"] = tmpdir + yield tmpdir + finally: + os.environ.pop("TRITON_CACHE_DIR", None) diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py index 0e4b4bb05b..3f3801b3f5 100644 --- a/python/test/regression/test_cast_matmul.py +++ b/python/test/regression/test_cast_matmul.py @@ -1,22 +1,69 @@ """ +Mixed precision tests for matmul (tl.dot) with cast (tl.to) + issue: https://github.com/triton-lang/triton/issues/2523 -fused type convert and matmul, base on triton matmul, the different with matmul: -1. force C's dtype=dot_out_dtype to ["float16", "float32"] -2. 
accept A and B with dtype=["float32", "float64"] +TODO: float8 types """ import warnings import pytest import torch +import triton import triton.runtime as tr import triton.language as tl -from triton import cdiv, jit -input_dtypes = ["float32", "float64"] +input_dtypes = ["float16", "float32", "float64"] out_dtypes = ["float16", "float32"] +@triton.jit +def matmul_kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + dot_out_dtype: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # + BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr): + # matrix multiplication + pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + acc += tl.dot(a, b, out_dtype=dot_out_dtype) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + @pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype", [(M, K, N, w, x, o) # for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] # @@ -25,7 +72,7 @@ for o in out_dtypes]) def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype, device): if x_dtype == w_dtype: - pytest.xfail("skip same dtype") + pytest.xfail("skip the same input dtype") if device == "xpu" and "float64" in (w_dtype, x_dtype) and not tr.driver.active.get_current_target().arch['has_fp64']: pytest.xfail("float64 not supported on current xpu hardware") @@ -40,53 +87,7 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype, device): # launch kernel BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32 - grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1) - - @jit - def matmul_kernel(A, B, C, M, N, K, # - stride_am, stride_ak, # - stride_bk, stride_bn, # - stride_cm, stride_cn, # - dot_out_dtype: tl.constexpr, # - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # - BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr): - # matrix multiplication - pid = tl.program_id(0) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication 
- rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) - for k in range(0, tl.cdiv(K, BLOCK_K)): - k_remaining = K - k * BLOCK_K - _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) - a = a.to(C.dtype.element_ty) - b = b.to(C.dtype.element_ty) - acc += tl.dot(a, b, out_dtype=dot_out_dtype) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - acc = acc.to(C.dtype.element_ty) - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc, mask=mask) + grid = ((triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), 1) matmul_kernel[grid]( a, b, out_triton, M, N, K, # diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py deleted file mode 100644 index 8c50e5ad5b..0000000000 --- a/python/test/regression/test_performance.py +++ /dev/null @@ -1,267 +0,0 @@ -import pytest -import torch - -import triton -import triton.language as tl -import triton.ops -from triton.testing import get_dram_gbps, get_max_tensorcore_tflops, nvsmi - -DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]] - -####################### -# Utilities -####################### - - -def print_perf(cur_ms, cur_util, ref_util): - # print on the same line cur_ms, cur_util and ref_util with 3 decimal places - print(f'{cur_ms:.3f} ms \t cur: {cur_util:.3f} \t ref: {ref_util:.3f} \t dif={cur_util - ref_util:.3f}', end='\t') - - -####################### -# Matrix Multiplication -####################### - -sm_clocks = {'v100': 1350, 'a100': 1350} -mem_clocks = {'v100': 877, 'a100': 1215} - -matmul_data = { - 'a100': { - # square - (512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05}, - (1024, 1024, 1024): {'float16': 0.355, 'float32': 0.313, 'int8': 0.169}, - (2048, 2048, 2048): {'float16': 0.653, 'float32': 0.532, 'int8': 0.34}, - (8192, 8192, 8192): {'float16': 0.839, 'float32': 0.754, 'int8': 0.51}, - # tall-skinny - (16, 1024, 1024): {'float16': 0.015, 'float32': 0.009, 'int8': 0.005}, - (16, 4096, 4096): {'float16': 0.080, 'float32': 0.051, 'int8': 0.026}, - (16, 8192, 8192): {'float16': 0.083, 'float32': 0.077, 'int8': 0.043}, - (64, 1024, 1024): {'float16': 0.045, 'float32': 0.023, 'int8': 0.017}, - (64, 4096, 4096): {'float16': 0.170, 'float32': 0.000, 'int8': 0.097}, - (64, 8192, 8192): {'float16': 0.227, 'float32': 0.000, 'int8': 0.174}, - (1024, 64, 1024): {'float16': 0.040, 'float32': 0.046, 'int8': 0.017}, - (4096, 64, 4096): {'float16': 0.160, 'float32': 0.214, 'int8': 0.102}, - (8192, 64, 8192): {'float16': 0.272, 'float32': 0.000, 'int8': 0.177}, - # test EVEN_K==False - (8192, 8192, 8176): {'float16': 0.828, 'float32': 0.743, 'int8': 0.51}, - } -} - - -@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str) - for M, N, K in matmul_data[DEVICE_NAME].keys() - for dtype_str in ['float16']]) -def test_matmul(M, N, K, 
dtype_str): - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) - if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100': - pytest.skip('Only test float32 & int8 on a100') - if (M, N, K) in [(64, 4096, 4096), (64, 8192, 8192), (8192, 64, 8192)] and dtype_str == 'float32': - pytest.skip('Out of shared memory in float32') - dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str] - torch.manual_seed(0) - ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str] - cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) - if dtype == torch.int8: - a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda') - b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda') - b = b.t() # only test row-col layout - else: - a = torch.randn((M, K), dtype=dtype, device='cuda') - b = torch.randn((K, N), dtype=dtype, device='cuda') - fn = lambda: triton.ops.matmul(a, b) - ms = triton.testing.do_bench_cudagraph(fn) - cur_gpu_perf = 2. * M * N * K / ms * 1e-9 - cur_gpu_util = cur_gpu_perf / max_gpu_perf - print_perf(ms, cur_gpu_util, ref_gpu_util) - triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) - - -####################### -# Element-Wise -####################### - - -@triton.jit -def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - -elementwise_data = { - 'a100': { - 1024 * 16: {'float16': 0.031, 'float32': 0.060}, - 1024 * 64: {'float16': 0.120, 'float32': 0.224}, - 1024 * 256: {'float16': 0.394, 'float32': 0.691}, - 1024 * 1024: {'float16': 1.06, 'float32': 1.453}, - 1024 * 16384: {'float16': 0.832, 'float32': 0.862}, - 1024 * 65536: {'float16': 0.873, 'float32': 0.882}, - # Non pow 2 - 1020 * 100: {'float16': 0.173, 'float32': 0.327}, - 10003 * 7007: {'float16': 0.522, 'float32': 0.873}, - } -} - - -@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys()) -@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32']) -def test_elementwise(N, dtype_str): - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) - torch.manual_seed(0) - if dtype_str in ['bfloat16'] and DEVICE_NAME != 'a100': - pytest.skip('Only test bfloat16 on a100') - dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str] - ref_dtype_str = 'float16' if dtype_str == 'bfloat16' else dtype_str - ref_gpu_util = elementwise_data[DEVICE_NAME][N][ref_dtype_str] - max_gpu_perf = get_dram_gbps() - z = torch.empty((N, ), dtype=dtype, device='cuda') - x = torch.randn_like(z) - y = torch.randn_like(z) - grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) - fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) - ms = triton.testing.do_bench_cudagraph(fn) - cur_gpu_perf = 3. 
* N * z.element_size() / ms * 1e-6 - cur_gpu_util = cur_gpu_perf / max_gpu_perf - print_perf(ms, cur_gpu_util, ref_gpu_util) - triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) - - -####################### -# Flash-Attention -####################### - -flash_attention_data = { - "a100": { - (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542, - (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471, - (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155, - (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.232, - (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.231, - (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.138, - (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306, - (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266, - (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098, - (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.157, - (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.157, - (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.092, - (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541, - (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471, - (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150, - (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.291, - (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.255, - (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.144, - (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.306, - (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266, - (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098, - (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159, - (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159, - (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088, - } -} - - -@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize("mode", ['forward', 'backward']) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("seq_par", [True, False]) -@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]]) -def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str): - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) - is_backward = mode == 'backward' - capability = torch.cuda.get_device_capability() - if capability[0] < 8: - pytest.skip("Flash attention only supported for compute capability < 80") - torch.manual_seed(20) - dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str] - # init data - if dtype_str == 'float32': - N_CTX = 1024 - D_HEAD = 16 - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() - k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() - v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() - sm_scale = 0.2 - # benchmark - fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par) - if is_backward: - o = fn() - do = torch.randn_like(o) - fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench_cudagraph(fn) - # compute flops - flops_per_matmul = 2. 
* Z * H * N_CTX * N_CTX * D_HEAD * 0.5 - total_flops = 2 * flops_per_matmul - if is_backward: - total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) - cur_gpu_perf = total_flops / ms * 1e-9 - # maximum flops - cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) - cur_gpu_util = cur_gpu_perf / max_gpu_perf - ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)] - print_perf(ms, cur_gpu_util, ref_gpu_util) - triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) - - -####################### -# Reduction -####################### - - -@triton.jit -def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - # run in a loop to only to make it compute bound. - for i in range(100): - x = tl.sum(x, axis=0) + y - - tl.store(output_ptr + offsets, x, mask=mask) - - -reduction_data = { - 'a100': { - 1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.022, 'int32': 0.048}, - 1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.022, 'int32': 0.049}, - } -} - - -@pytest.mark.parametrize('N', reduction_data[DEVICE_NAME].keys()) -@pytest.mark.parametrize("dtype_str", ['float16', 'float32', 'int16', 'int32']) -def test_reductions(N, dtype_str): - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) - torch.manual_seed(0) - dtype = {'float16': torch.float16, 'float32': torch.float32, 'int16': torch.int16, 'int32': torch.int32}[dtype_str] - ref_gpu_util = reduction_data[DEVICE_NAME][N][dtype_str] - cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) - z = torch.empty((N, ), dtype=dtype, device='cuda') - if dtype == torch.float16 or dtype == torch.float32: - x = torch.randn_like(z) - y = torch.randn_like(z) - else: - info = torch.iinfo(dtype) - x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda') - y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda') - grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) - fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024) - ms = triton.testing.do_bench_cudagraph(fn) - cur_gpu_perf = 100. * 2. 
* N / ms * 1e-9 - cur_gpu_util = cur_gpu_perf / max_gpu_perf - print_perf(ms, cur_gpu_util, ref_gpu_util) - triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c5f07f065c..2a3606b581 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3267,21 +3267,6 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) - if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"): - if not is_cuda(): - pass - else: - ptx = pgm.asm["ptx"] - start = ptx.find("shfl.sync.bfly") - end = ptx.find("cvt.rn.f16.f32") - red_code = ptx[start:end] - assert len(red_code) > 0 - - # skip this check on hopper because there are some functions whose name contain "shared" in ptx. - # TODO: we should eliminate these unused functions in ptx code. - if not (capability[0] >= 9): - assert "shared" not in red_code - assert "bar.sync" not in red_code # torch result if in_dtype == 'int8': z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) @@ -3360,16 +3345,12 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx -@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", [ - (M, N, K, col_a, col_b, type_a, type_b, 4) - for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) - for col_a, col_b in itertools.product([True, False], repeat=2) - # We don't test e5m2 as it seems to overflow more easily - # Tested locally and it works fine other than for ~10 entries out of 10_000 - # which are of the size of 10**30 - for type_a in ["e4m3"] - for type_b in ["e4m3"] -]) +@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", + [(M, N, K, col_a, col_b, type_a, type_b, 4) + for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) + for col_a, col_b in itertools.product([True, False], repeat=2) + for type_a in ["e2m1", "e4m3", "e5m2"] + for type_b in ["e4m3", "e5m2"]]) def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device): if not is_cuda(): pytest.xfail("scaled_dot only supported on CUDA") @@ -3400,7 +3381,7 @@ def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, s a_scale = tl.load(scale_a_ptr) c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b) out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] - tl.store(out_ptr, c) + tl.store(out_ptr, c.to(tl.bfloat16)) @triton.jit def mxfp_to_bf16_kernel( @@ -3477,7 +3458,6 @@ def dot_scale_ref(x, scale, y, type_x, type_y): # Need to implement fp4 -> fp8 cast to support 1 byte types comp_dtype = torch.bfloat16 - out_dtype = torch.float32 x = x.contiguous() x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) @@ -3486,36 +3466,65 @@ def dot_scale_ref(x, scale, y, type_x, type_y): BLOCK_SIZE = 512 grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) + assert x_upcast.isfinite().all() y_upcast = y.view(type_fp8_y).to(comp_dtype) - return 
torch.matmul(x_upcast.to(out_dtype), y_upcast.to(out_dtype)) + + class AccumulateInFp32: + + def __enter__(self): + self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value + + with AccumulateInFp32(): + return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype)) torch.manual_seed(0) - def create_uint8(shape): - return torch.randint(0xff, shape, dtype=torch.uint8, device=device) + def create_uint8(shape, col_major=False, max_val=255): + if col_major: + shape = shape[:-2] + (shape[-1], shape[-2]) + ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) + if col_major: + ret = ret.mT + return ret + + DIV_FACTOR = 2 if type_a == "e2m1" else 1 + x = create_uint8((M, K // DIV_FACTOR), col_major=col_a) + y = create_uint8((K, N), col_major=col_b) + + # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) + # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow + m_bytes = int(type_a[1]) + bias_type_a = 1 << (m_bytes - 1) - 1 + max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a + scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64) + + def make_finite(x, dtype): + # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and + # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme) + if dtype not in ("e5m2", "e4m3"): + return x + mask = 0x7C if dtype == "e5m2" else 0x7F + finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask + x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x) + x.copy_(x_finite) + return x - x = create_uint8((K, M)).T if col_a else create_uint8((M, K)) - y = create_uint8((N, K)).T if col_b else create_uint8((K, N)) - scale_x = create_uint8((M, K // 32)) + x = make_finite(x, type_a) + y = make_finite(y, type_b) - z = x.new_empty((M, N), dtype=torch.float32) + z = x.new_empty((M, N), dtype=torch.bfloat16) pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, num_warps=num_warps) z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) - # dot_scale_ref computes the result in higher precision - # so we equalise all the non-finite values - # This also fixes a bug in our upcasting from e5m2 to bf16 where inf is not preserved - non_finite_z = ~z.isfinite() - z_ref[non_finite_z] = z[non_finite_z] - non_finite_ref = ~z_ref.isfinite() - z[non_finite_ref] = z_ref[non_finite_ref] - - # generous rtol set because the ref is more precise than the fused - # (computes in higher dtype) and we are sampling the whole range of floats - torch.testing.assert_close(z, z_ref, equal_nan=True, atol=1e-5, rtol=1e-2) + # generous rtol as we are sampling the whole range of floats + torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) # make sure ld/st are vectorized ptx = pgm.asm['ptx'] diff --git a/python/test/unit/language/test_pipeliner.py b/python/test/unit/language/test_pipeliner.py index cc02cf0e33..fa5f34290b 100644 --- a/python/test/unit/language/test_pipeliner.py +++ b/python/test/unit/language/test_pipeliner.py @@ -11,7 +11,7 @@ def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" -def is_cuda_tma_available(): +def is_hopper(): return is_cuda() and 
torch.cuda.get_device_capability()[0] >= 9 @@ -33,32 +33,54 @@ def check_capabilities(): @triton.jit def matmul_kernel( # - a_ptr, b_ptr, output_ptr, # - M, N, K, # + a_ptr, scale_ptr, b_ptr, output_ptr, # + M, N, K_MXFP, # K_MXFP is the number of mxfp vectors in a row of a. Otherwise it's just K stride_am, stride_ak, # + stride_sm, stride_sk, # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_STAGES: tl.constexpr): + NUM_STAGES: tl.constexpr, a_type: tl.constexpr, b_type: tl.constexpr): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_M) pid_m = pid % num_pid_m pid_n = pid // num_pid_m offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + IS_SCALED: tl.constexpr = a_type is not None and b_type is not None + DIV_FACTOR: tl.constexpr = 2 if IS_SCALED and a_type == "e2m1" else 1 + # We pass K_MXFP to make explicit that KB is multiple of 32 and KA is multiple of 16 or 32 + # for the pipeliner divisibility condition + KA = K_MXFP if not IS_SCALED else K_MXFP * (32 // DIV_FACTOR) + KB = K_MXFP if not IS_SCALED else K_MXFP * 32 + BLOCK_AK: tl.constexpr = BLOCK_K // DIV_FACTOR offs_k = tl.arange(0, BLOCK_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + offs_ak = tl.arange(0, BLOCK_AK) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + if IS_SCALED: + BLOCK_SK: tl.constexpr = BLOCK_K // 32 + offs_sk = tl.arange(0, BLOCK_SK) + scale_ptrs = scale_ptr + (offs_am[:, None] * stride_sm + offs_sk[None, :] * stride_sk) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): - mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_K < K) - mask_b = ((offs_k[:, None] + k * BLOCK_K) < K) & (offs_bn[None, :] < N) + for k in tl.range(0, tl.cdiv(KB, BLOCK_K), num_stages=NUM_STAGES): + mask_a = (offs_am[:, None] < M) & (offs_ak[None, :] + k * BLOCK_AK < KA) + mask_b = ((offs_k[:, None] + k * BLOCK_K) < KB) & (offs_bn[None, :] < N) a = tl.load(a_ptrs, mask=mask_a, other=0) b = tl.load(b_ptrs, mask=mask_b, other=0) - accumulator = tl.dot(a, b, acc=accumulator) - a_ptrs += BLOCK_K * stride_ak + if IS_SCALED: + # Adapted scale indexing and dot_scaled operation + mask_scale = (offs_am[:, None] < M) & (offs_sk[None, :] + k * BLOCK_SK < K_MXFP) + a_scale = tl.load(scale_ptrs, mask=mask_scale, other=0) + accumulator = tl.dot_scaled(a, a_scale, a_type, b, None, b_type, acc=accumulator) + else: + accumulator = tl.dot(a, b, acc=accumulator) + a_ptrs += BLOCK_AK * stride_ak b_ptrs += BLOCK_K * stride_bk - accumulator = accumulator.to(tl.float16) + if IS_SCALED: + scale_ptrs += BLOCK_SK * stride_sk + OUT_DTYPE = tl.bfloat16 if IS_SCALED else tl.float16 + accumulator = accumulator.to(OUT_DTYPE) offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) @@ -105,16 +127,142 @@ def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks, BLOCK_SIZE: offsets += BLOCK_SIZE -def test_pipeline_matmul(device): +@triton.jit +def mxfp_to_bf16_kernel( + x_ptr, + scale_ptr, + mxfp_ptr, + N, + e_bits: tl.constexpr, + m_bits: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + # x.shape == (N, 32) for fp8 or (N, 16) for fp4 + # scale.shape == (N,) + # out.shape == (N, 32) + 
is_fp8: tl.constexpr = e_bits + m_bits == 7 + # fp8: BLOCK_SIZE -> BLOCK_SIZE // 32, 32 + # fp4: BLOCK_SIZE // 2 -> BLOCK_SIZE // 32 , 16 + PARALLEL_DIM: tl.constexpr = BLOCK_SIZE // 32 + LAST_DIM: tl.constexpr = 32 if is_fp8 else 16 + LOAD_SIZE: tl.constexpr = LAST_DIM * PARALLEL_DIM + + offsets = (tl.program_id(0) * LOAD_SIZE + tl.arange(0, PARALLEL_DIM)[:, None] * LAST_DIM + + tl.arange(0, LAST_DIM)[None, :]) + x = tl.load(x_ptr + offsets, mask=offsets < N * LAST_DIM) + + offsets = tl.program_id(0) * PARALLEL_DIM + tl.arange(0, PARALLEL_DIM)[:, None] + scale = tl.load(scale_ptr + offsets, mask=offsets < N) + tl.static_assert(scale.dtype == tl.uint8) + tl.static_assert(x.dtype == tl.uint8) + + scale_bf16 = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) + if is_fp8: + if e_bits == 5 and m_bits == 2: + x_f8 = x.to(tl.float8e5, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! + non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits + non_finite_mask_bf16: tl.constexpr = ((1 << 8) - 1) << 7 + x_bf16 = tl.where( + x & non_finite_mask == non_finite_mask, + (x_bf16.to(tl.uint16, bitcast=True) | non_finite_mask_bf16).to(tl.bfloat16, bitcast=True), + x_bf16, + ) + else: + tl.static_assert(e_bits == 4 and m_bits == 3) + x_f8 = x.to(tl.float8e4nv, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + else: + # e2m1 + em0 = x & 0x70 + em1 = x & 0x7 + x0 = (em0.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << 8) + x1 = (em1.to(tl.uint16) << (2 + 4)) | ((x & 0x8).to(tl.uint16) << (8 + 4)) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x60) != 0, x0 + ((127 - 1) << 7), x0) + x1 = tl.where((em1 & 0x6) != 0, x1 + ((127 - 1) << 7), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 + x0 = tl.where(em0 == 0x10, 16128 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x1, 16128 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True) + # Multiplication preserves infs and NaNs in x_bf16 + mxfp = x_bf16 * scale_bf16 + # If scale is NaN, we encode it as an bf16 inf, so we need to correct for that + mxfp = tl.where(scale == 0xFF, float("nan"), mxfp) + + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) + + +def dot_scale_ref(x, scale, y, type_x, type_y): + e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] + type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] + + comp_dtype = torch.float32 + out_dtype = torch.bfloat16 + + x = x.contiguous() + x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) + + N = x_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=4) + y_upcast = y.view(type_fp8_y) + + class AccumulateInFp32: + + def __enter__(self): + self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value + + with AccumulateInFp32(): + return torch.matmul(x_upcast.to(out_dtype), y_upcast.to(out_dtype)) + + +@pytest.mark.parametrize("scale", [True, False]) +def test_pipeline_matmul(scale, 
device): check_capabilities() + if scale and not is_cuda(): + pytest.skip("NYI: scale_dot just implemented in CUDA") M, N, K = 512, 512, 128 BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32 NUM_STAGES = 4 - a = torch.randn(M, K, device=device, dtype=torch.float16) - b = torch.randn(K, N, device=device, dtype=torch.float16) - output = torch.empty((M, N), dtype=torch.float16, device=device) + + if scale: + # TODO Use e5m2 for Ampere, as it does not support fp_to_fp conversions for fp8e4m3 + BLOCK_K = 64 # 32 NYI + K = BLOCK_K * NUM_STAGES + a_type = "e2m1" + DIV_FACTOR = 2 if a_type == "e2m1" else 1 + a = torch.randint(256, (M, K // DIV_FACTOR), device=device, dtype=torch.uint8) + # Sample small-ish scales to avoid overflow + scale_a = torch.randint(74, (M, K // 32), device=device, dtype=torch.uint8) + # Ampere does not support fp8e4m3 + b_type = "e4m3" if is_hopper() else "e5m2" + b = torch.randint(256, (K, N), device=device, dtype=torch.uint8) + # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and + # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme) + if b_type == "e5m2": + finite = torch.arange(K * N, device=device, dtype=torch.uint8).reshape(K, N) % 0x7C + b = torch.where(b & 0x7C == 0x7C, finite | (0x80 & b), b) + output = torch.empty((M, N), dtype=torch.bfloat16, device=device) + else: + a = torch.randn(M, K, device=device, dtype=torch.float16) + b = torch.randn(K, N, device=device, dtype=torch.float16) + scale_a = None + a_type, b_type = None, None + output = torch.empty((M, N), dtype=torch.float16, device=device) grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) - if is_cuda_tma_available(): + use_tma = not scale and is_hopper() + + if use_tma: a_tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(a.data_ptr(), M, K, BLOCK_M, BLOCK_K, a.element_size()) b_tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(b.data_ptr(), K, N, BLOCK_K, BLOCK_N, @@ -124,19 +272,26 @@ def test_pipeline_matmul(device): handler = matmul_kernel_tma[grid](a_tma, b_tma, output_tma, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES) else: - handler = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, - NUM_STAGES=NUM_STAGES) - ref_out = torch.matmul(a, b) - atol = 1e-2 if is_hip_mi200() else None + # Pass K_MXFP to make explicit that KB is multiple of 32 and KA is multiple of 16 or 32ยบ + if scale: + K = scale_a.shape[-1] + stride_sm, stride_sk = scale_a.stride() if scale else (0, 0) + handler = matmul_kernel[grid](a, scale_a, b, output, M, N, K, a.stride(0), a.stride(1), stride_sm, stride_sk, + b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, NUM_STAGES=NUM_STAGES, a_type=a_type, b_type=b_type) + if scale: + ref_out = dot_scale_ref(a, scale_a, b, a_type, b_type) + else: + ref_out = torch.matmul(a, b) # Bigger tolerance for AMD MI200 devices. # MI200 devices use reduced precision fp16 and bf16 and flush input and # output denormal values to zero. 
Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - rtol = 1e-2 if is_hip_mi200() else None - torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol) + atol = 1e-2 if is_hip_mi200() or scale else None + rtol = 1e-2 if is_hip_mi200() or scale else None + torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol, equal_nan=scale) if is_cuda(): ttgir = handler.asm["ttgir"] - if is_cuda_tma_available(): + if use_tma: assert ttgir.count("triton_nvidia_gpu.async_tma_copy_global_to_local") != 0, "async tma copy not found" assert ttgir.count(f"num = {NUM_STAGES} : i32") == 0, "num_stages not match" # a_tma, b_tma, output_tma, barriar diff --git a/test/Conversion/amd/buffer_load_store.mlir b/test/Conversion/amd/buffer_load_store.mlir new file mode 100644 index 0000000000..209c7065d8 --- /dev/null +++ b/test/Conversion/amd/buffer_load_store.mlir @@ -0,0 +1,178 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load + tt.func @buffer_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { + // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 + // CHECK: %[[offset:.*]] = llvm.select %[[c_mask]] + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]] + %ret = amdgpu.buffer_load %arg0[%offset] : tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_mask + tt.func @buffer_load_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[offset:.*]] = llvm.select %[[mask]] + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]] + %ret = amdgpu.buffer_load %arg0[%offset], %7: tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_mask_other + tt.func @buffer_load_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, 
%c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + %other = arith.constant dense<0.00e+00> : tensor<128xf32, #blocked0> + // CHECK: %[[mask:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[offset:.*]] = llvm.select %[[mask]] + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]] + // CHECK: llvm.select + %ret = amdgpu.buffer_load %arg0[%offset], %7, %other: tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_store + tt.func @buffer_store(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { + // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 + // CHECK: %[[offset:.*]] = llvm.select %[[c_mask]] + // CHECK: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, %[[offset]] + amdgpu.buffer_store %value, %arg0[%offset] : tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_store_mask + tt.func @buffer_store_mask(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[mask1:.*]] = llvm.and %[[mask0]], {{.*}} + // CHECK: %[[offset:.*]] = llvm.select %[[mask1]] + // CHECK: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, %[[offset]] + amdgpu.buffer_store %value, %arg0[%offset], %7: tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec4 + tt.func @buffer_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + // Load 8 
elements from A with two vectorized load instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.load {{.*}} : vector<4xf32> + %9 = amdgpu.buffer_load %arg0[%4] : tensor<256xf32, #blocked0> + // Load 8 elements from B with two vectorized load instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.load {{.*}} : vector<4xf32> + %10 = amdgpu.buffer_load %arg1[%4] : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + // Store 8 elements into C with two vectorized store instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.store {{.*}} : vector<4xf32> + amdgpu.buffer_store %11, %arg2[%4]: tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec1 + tt.func @buffer_load_store_vec1(%arg0: !tt.ptr , %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg3 : i32 -> tensor<256xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<256xi32, #blocked0> + // Load 8 elements from A with eight scalar load instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.load {{.*}} : f32 + %9 = amdgpu.buffer_load %arg0[%4], %7 : tensor<256xf32, #blocked0> + // Load 8 elements from B with two scalar load instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.load {{.*}} : f32 + %10 = amdgpu.buffer_load %arg1[%4], %7 : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + // Store 8 elements into C with two scalar store instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.store {{.*}} : f32 + amdgpu.buffer_store %11, %arg2[%4], %7 : tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec2 + tt.func @buffer_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr{tt.divisibility = 4 : i32}, %arg2: !tt.ptr{tt.divisibility = 4: i32}, %arg3: i32{tt.divisibility = 4: i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg3 : i32 -> tensor<256xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<256xi32, #blocked0> + // Load 8 fp16 elements from A with four i32 scalar load instructions + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}} : i32 + %9 = amdgpu.buffer_load %arg0[%4], %7 : tensor<256xf16, #blocked0> + // Load 8 fp16 elements from B with four i32 scalar load instructions + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}} : i32 + %10 = amdgpu.buffer_load %arg1[%4], %7 : tensor<256xf16, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf16, #blocked0> + // Store 8 fp16 
elements into C with four i32 scalar store instructionss + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.store {{.*}} : i32 + amdgpu.buffer_store %11, %arg2[%4], %7 : tensor<256xf16, #blocked0> + tt.return + } +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 78c6f68bf6..682c1cb301 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2607,3 +2607,45 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %outLHS : tensor<128x64xf32, #blocked1> } } + +// ----- + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { + // CHECK-LABEL: matmul_add + tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %C : !tt.ptr) { + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %c_ptr_init = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr, #CL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #CL> + %cst = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL> + // CHECK: %[[T0:.*]] = tt.dot + // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma> + %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: scf.yield + scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL> + } + + // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked + tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr, #CL> + tt.return + } +} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index f7f824e30b..a7395f86dc 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -29,6 +29,7 @@ #include "mlir/IR/BuiltinOps.h" #include 
"mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Traits.h" // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" // clang-format on diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 4721d14ecb..538e31378f 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -26,9 +26,12 @@ #define TRITON_AMDGPU_OPS include "mlir/IR/OpBase.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" @@ -36,6 +39,11 @@ class TT_AMDGPU_Op traits = []> : Op { } +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { let summary = "A placeholder op for instruction scheduling hints within a basic block"; let description = [{ @@ -52,4 +60,73 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { let assemblyFormat = [{attr-dict}]; } +// +// AMD Buffer operations. +// +def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [ + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + MemoryEffects<[MemRead]>, + TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()">, + TypesMatchWith<"result and other have the same type", "result", "other", "$_self", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, +]>{ + let summary = "Load from a scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer load operation. Buffer store is similar to + a normal store but it accesses global memory via a scalar base pointer + and a tensor of offsets instead of a tensor of pointers. The other fields + are similar to a normal load, i.e., the `mask` is a boolean vector that + determines if a given element should be read from memory, and `other` is the + element that should be returned on lane `i` when `mask[i] == 0`. + }]; + let arguments = ( + ins + TT_Ptr:$ptr, + I32Tensor:$offsets, + Optional:$mask, + Optional:$other + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $ptr `[` $offsets `]` (`,` $mask^)? (`,` $other^)? 
+ attr-dict `:` type($result) + }]; +} + +def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [ + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, +]>{ + let summary = "Store into scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer store operation. Buffer store is similar to + normal store but it accesses global memory via a scalar base pointer + and a tensor of offsets instead of a tensor of pointers. The other fields + are similar to a normal store , i.e., the `mask` is a boolean vector that + determines if a given element should be written to memory, and `value` is the + tensor of elements that should be written on lane `i` when `mask[i] == 1`. + }]; + let arguments = ( + ins + TT_Tensor:$value, + TT_Ptr:$ptr, + I32Tensor:$offsets, + Optional:$mask + ); + + let assemblyFormat = [{ + $value `,` $ptr `[` $offsets `]` (`,` $mask^)? + attr-dict `:` type($value) + }]; +} + #endif diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 5631d56b24..a82a77e9f5 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +#include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp new file mode 100644 index 0000000000..be009af4d1 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -0,0 +1,175 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "BufferOpsEmitter.h" + +using mlir::triton::gpu::appendOrGetExternFuncOp; +using mlir::triton::gpu::getFunctionType; +using namespace triton::AMD; + +namespace { + +// Utility function to determine if a scalar/tensor value is zero +bool isZero(Value v) { + if (auto constantOp = v.getDefiningOp()) { + if (auto attr = dyn_cast(constantOp.getValue())) + return attr.getValue().isZero(); + if (auto attr = dyn_cast(constantOp.getValue())) + return attr.getValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + } + return false; +} +} // namespace + +namespace mlir::LLVM::AMD { +BufferEmitter::BufferEmitter(RewriterBase &rw, Location loc, TargetInfo ti) + : rewriter(rw), loc(loc), targetInfo(ti) {} + +Value BufferEmitter::createResourceDescriptor(Value basePtr) { + // 1. 
Create the resource descriptor + // bits 0-11: dst sel, ignored by these intrinsics + // bits 12-14: data format (ignored, must be nonzero, 7=float) + // bits 15-18: data format (ignored, must be nonzero, 4=32bit) + // bit 19: In nested heap (0 here) + // bit 20: Behavior on unmap (0 means "return 0 / ignore") + // bits 21-22: Index stride for swizzles (N/A) + // bit 23: Add thread ID (0) + // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) + // bits 25-26: Reserved (0) + // bit 27: Buffer is non-volatile (CDNA only) + // bits 28-29: Out of bounds select (RDNA only) + // (0 = structured, + // 1 = check index, + // 2 = none, + // 3 = either swizzles or testing against offset field) + // bits 30-31: Type (must be 0) + uint32_t flags = (7 << 12) | (4 << 15); + if (targetInfo.getISAFamily() == ISAFamily::RDNA2 || + targetInfo.getISAFamily() == ISAFamily::RDNA3) { + flags |= (1 << 24); + uint32_t oob = 3; + flags |= (oob << 28); + } + Value stride = int_val(16, 0); + Value flagsConst = int_val(32, flags); + Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); + Value numRecordsByte = int_val(32, std::numeric_limits::max() - 1); + + Value resource = rewriter.createOrFold( + loc, rsrcType, basePtr, stride, numRecordsByte, flagsConst); + return resource; +} + +Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset, + Value pred, Value falseVal) { + SmallVector args; + fillCommonArgs(type, rsrcDesc, offset, pred, args); + Type bufferType = getBufferOpType(type); + Value data = rewriter.create( + loc, bufferType, args, ArrayRef()); + data = bitcast(data, type); + if (!isZero(falseVal)) + data = select(pred, data, falseVal); + return data; +} + +void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data, + Value pred) { + VectorType vecTy = cast(data.getType()); + Type bufferType = getBufferOpType(vecTy); + if (vecTy != bufferType) + data = bitcast(data, bufferType); + SmallVector args{data}; + fillCommonArgs(vecTy, rsrcDesc, offset, pred, args); + rewriter.create(loc, TypeRange{}, args, + ArrayRef()); +} + +Type BufferEmitter::getBufferOpType(Type type) { + int64_t vecSize = 1; + Type elementType = type; + if (auto vecType = dyn_cast(type)) { + vecSize = vecType.getNumElements(); + elementType = vecType.getElementType(); + } + + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const size_t totalWidthBits = valueElemNBits * vecSize; + + // For bf16, always convert to i16 + Type bufferElementType = elementType; + if (elementType.isBF16()) + bufferElementType = rewriter.getI16Type(); + + // If we are dealing with a subword type (e.g., i8 or f16) but we + // still need multiple words, then pack the subwords into 32bit integers + // and update the vector length and the type + int64_t bufferVecSize = vecSize; + if (valueElemNBits < 32) { + if (totalWidthBits > 32) { + bufferElementType = rewriter.getI32Type(); + bufferVecSize = totalWidthBits / 32; + } else { + bufferElementType = rewriter.getIntegerType(totalWidthBits); + bufferVecSize = 1; + } + } + + // This is the buffer type that the buffer operation will use. It + // will be bitcast-able to the original type. 
So if the types + // ended up different, we simply have to emit a `bitcastOp` to convert + Type bufferType = type; + if (bufferVecSize != vecSize) + bufferType = VectorType::get(bufferVecSize, bufferElementType); + if (bufferVecSize == 1) + bufferType = getElementTypeOrSelf(bufferType); + + return bufferType; +} + +void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc, + Value vOffsetElems, Value pred, + SmallVector &args) { + + // 1. Create the (masked) offset + Type elementType = getElementTypeOrSelf(type); + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const int elementByteWidth = valueElemNBits / 8; + // Please note: the index passed is not in bytes, but in number of elements + // In order to pass the index to the buffer operation, we need to convert in + // bytes (i.e., we need to multiply by `elementByteWidth`) + Value vOffsetOutOfBunds = int_val( + 32, static_cast(std::numeric_limits::max() + int64_t(1))); + Value vOffsetBytes = mul(int_val(32, elementByteWidth), vOffsetElems); + Value maskedOffsetBytes = select(pred, vOffsetBytes, vOffsetOutOfBunds); + + // 2. Set the sgprOffset to 0 + Value sgprOffset = int_val(32, 0); + + // 3. Create the cache modifiers word + // bit 0: GLC = 0 (atomics drop value, less coherency) + // bits 1-2: SLC, DLC = 0 (similarly) + // bit 3: swizzled (0 for raw) + Value cacheModifiers = int_val(32, 0); + + // 5. Add the arguments + args.push_back(rsrcDesc); + args.push_back(maskedOffsetBytes); + args.push_back(sgprOffset); + args.push_back(cacheModifiers); +} +} // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h new file mode 100644 index 0000000000..ad6d46ff78 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h @@ -0,0 +1,93 @@ +#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H +#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H + +#include "TargetInfo.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir::LLVM::AMD { +// Utility class to take care of buffer operation emission. We may add more +// emitters into this as needed. Buffer operations accept a memory descriptor +// and an offset. +// +// The memory descriptor is stored in s_gprs and hence needs to +// be uniform across the wave. It contains two fields (among many others): +// +// - `base_pointer`: represents the (scalar) pointer to the memory area +// - `num_records`: represents the size of the memory region. This is a +// 32 bit unsigned integer +// +// The offset can be non-uniform across the wave (and hence stored in vgprs). +// +// The high level behaviour of a buffer operation can be described as: +// ``` +// def buffer_op(mem_desc, offset): +// address = splat(mem_desc.base_pointer) +// address += offset +// return buffer_op(address) +// ``` +// This means we don't need to store the addresses in vgprs and we need less +// VALU operations to compute the final address. +// +// Also note that buffer operations support out-of-boundary memory access. +// I.e., if offset[i] > mem_desc.num_records the operation is a nop for the i-th +// thread. 
+// +// This can be exploited to support masked operations, like in the following +// snippet: +// ``` +// def masked_op(base_ptr, offset, pred) +// mem_desc.base_ptr = base_ptr +// mem_desc.num_records = max_int_32 +// oob_offset = max_int_32+1 +// masked_offset = (pred ? offset : oob_offset) +// buffer_op(mem_desc, masked_offset) +// ``` +// To use buffer operations three main requirements need to be met: +// +// 1. The buffer pointer needs to be a scalar, it cannot be non-uniform across +// threads of the given wave +// 2. The offset needs to be expressed in 32 bits +// 3. The offset needs to be non-negative +// +// Failure to meet 1) will result in a scalarized loop (very poor performance). +// Failure to meet 2) and 3) will result in incorrect memory access. +struct BufferEmitter { + BufferEmitter(RewriterBase &rw, Location loc, + mlir::triton::AMD::TargetInfo ti); + + // Create a resource descriptor that points to the area of memory we want to + // load from + Value createResourceDescriptor(Value basePtr); + + // Emit a predicated rocdl.raw.ptr.buffer.load + Value emitLoad(Type type, Value rsrcDesc, Value offset, Value pred, + Value falseVal); + + // Emit a predicated rocdl.raw.ptr.buffer.store + void emitStore(Value rsrcDesc, Value offset, Value data, Value pred); + +private: + // Fill common buffer operation arguments. + void fillCommonArgs(Type type, Value rsrcDesc, Value vOffsetElems, Value pred, + SmallVector &args); + + // Given a type, the buffer type can be either the same type + // or a packed version. E.g., a vector of 8xfp16 can be bitcasted to + // a vector of 4xi32. This usually makes the life of the backend easier + Type getBufferOpType(Type type); + + // Rewriter utilities + RewriterBase &rewriter; + Location loc; + mlir::triton::AMD::TargetInfo targetInfo; +}; + +} // namespace mlir::LLVM::AMD + +#endif // TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index dc05155527..b6a514f450 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonAMDGPUToLLVM + BufferOpsEmitter.cpp ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 6009156cfc..f7dc8755fa 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,11 +1,17 @@ +#include "BufferOpsEmitter.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" #include "TargetInfo.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; @@ -98,6 +104,44 @@ struct LoadStoreConversionBase { ModuleAxisInfoAnalysis &axisAnalysisPass) : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {} + // Createa a LLVM 
vector of type `vecTy` containing all zeros + Value createZeroVector(OpBuilder &builder, Location loc, + VectorType vecTy) const { + mlir::Attribute zeroAttr = builder.getZeroAttr(vecTy.getElementType()); + auto denseValue = + DenseElementsAttr::get(cast(vecTy), zeroAttr); + Value zeroVal = builder.create(loc, vecTy, denseValue); + return zeroVal; + } + + // Given a vector of values `elems` and a starting point `start`, create a + // LLVM vector of length `vec` whose elements are `elems[start, ..., + // elems+vec-1]` + Value packElementRangeIntoVector(ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter, + Location loc, VectorType vecTy, + ArrayRef elems, int64_t start) const { + int64_t vec = vecTy.getNumElements(); + // If we need to mask the loaded value with other elements + Value v = undef(vecTy); + for (size_t s = 0; s < vec; ++s) { + Value otherElem = elems[start + s]; + Value indexVal = + LLVM::createIndexConstant(rewriter, loc, typeConverter, s); + v = insert_element(vecTy, v, otherElem, indexVal); + } + return v; + } + + // Return a tensor of pointers with the same type of `basePtr` and the same + // shape of `offset` + Type getPointerTypeWithShape(Value basePtr, Value offset) const { + Type basePtrType = basePtr.getType(); + auto offsetType = cast(offset.getType()); + return offsetType.cloneWith(std::nullopt, basePtrType); + } + + // Get contiguity for a tensor pointer `ptr` unsigned getContiguity(Value ptr) const { auto tensorTy = dyn_cast(ptr.getType()); if (!tensorTy) @@ -105,16 +149,63 @@ struct LoadStoreConversionBase { return axisAnalysisPass.getPtrContiguity(ptr); } + // Get contiguity for a scalar pointer `ptr` and a tensor `offset` + unsigned getContiguity(Value ptr, Value offset) const { + // Get contiguity from the offset + Type type = getPointerTypeWithShape(ptr, offset); + RankedTensorType tensorTy = cast(type); + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto uniqueContigPerThread = + triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + + // Get alignment from the pointer. Since this is a scalar pointer + // we should not take the pointer contiguity to consider alignment + auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); + auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto align = std::max(maxMultipleBytes / elemNumBytes, 1); + + // Final contiguity is a min of the offset contiguity and pointer alignment + contiguity = std::min(align, contiguity); + return contiguity; + } + + // Determine the vector size of a tensor of pointers unsigned getVectorSize(Value ptr) const { auto tensorTy = dyn_cast(ptr.getType()); if (!tensorTy) return 1; auto contiguity = getContiguity(ptr); auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); - // The maximum vector size is 128 bits on NVIDIA GPUs. 
return std::min(128 / pointeeBitWidth, contiguity); } + // Given a scalar pointer and a tensor of offsets, determine the vector size + unsigned getVectorSize(Value ptr, Value offset) const { + auto contiguity = getContiguity(ptr, offset); + auto pointeeBitWidth = triton::getPointeeBitWidth(ptr.getType()); + return std::min(128 / pointeeBitWidth, contiguity); + } + + // Unpack the elements contained in a `llvmStruct` into a `SmallVector` of + // `Value`s. While you do that, check also the alignment of the mask and + // update the vector length `vec` accordingly + SmallVector + getMaskElemsAndUpdateVeclen(ConversionPatternRewriter &rewriter, Location loc, + Value llMask, Value mask, unsigned &vec) const { + SmallVector maskElems; + if (llMask) { + vec = std::min(vec, getMaskAlignment(mask)); + maskElems = unpackLLElements(loc, llMask, rewriter); + } + return maskElems; + } + unsigned getMaskAlignment(Value mask) const { return axisAnalysisPass.getMaskAlignment(mask); } @@ -163,36 +254,18 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned numElems = getTotalElemsPerThread(ptr.getType()); - if (llMask) - vec = std::min(vec, getMaskAlignment(mask)); // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); // Get the LLVM values for mask - SmallVector maskElems; - if (llMask) { - maskElems = unpackLLElements(loc, llMask, rewriter); - assert(maskElems.size() == numElems); - } + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); - // Get the LLVM values for `other` - // TODO: (goostavz) handle when other is const but not splat, which - // should be rarely seen - bool otherIsSplatConstInt = false; - DenseElementsAttr constAttr; - int64_t splatVal = 0; - if (other && isa(valueElemTy) && - matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && - isa(constAttr.getElementType())) { - otherIsSplatConstInt = true; - splatVal = constAttr.getSplatValue().getSExtValue(); - } SmallVector otherElems; - if (other) { + if (other) otherElems = unpackLLElements(loc, llOther, rewriter); - } // vectorized iteration through all the pointer/mask/other elements const int valueElemNBits = @@ -204,8 +277,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto cacheMod = op.getCache(); SmallVector loadedVals; for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { - size_t in_off = 0; - const size_t maxWordWidth = std::max(32, valueElemNBits); const size_t totalWidth = valueElemNBits * vec; const size_t width = std::min(totalWidth, maxWordWidth); @@ -218,29 +289,100 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); - mlir::Attribute zeroAttr = rewriter.getZeroAttr(valueElemTy); - auto denseValue = - DenseElementsAttr::get(cast(vecTy), zeroAttr); - Value zeroVal = rewriter.create(loc, vecTy, denseValue); - - Value falseVal = zeroVal; + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); // If we need to mask the loaded value with other elements - if (otherElems.size() != 0) { - Value v = undef(vecTy); - for (size_t s = 0; s < vec; ++s) { - Value otherElem = otherElems[vecStart + s]; - Value indexVal = LLVM::createIndexConstant( - rewriter, loc, this->getTypeConverter(), s); - v = insert_element(vecTy, v, otherElem, 
indexVal); - } - falseVal = v; - } + if (otherElems.size() != 0) + falseVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + otherElems, vecStart); Value loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, ptrAlignmentBytes, cacheMod); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec); + rewriter, loc, this->getTypeConverter()->getIndexType(), ii); + Value loaded = extract_element(valueElemTy, loadVal, vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct BufferLoadOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferLoadOp>::ConvertOpToLLVMPattern; + + BufferLoadOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value other = op.getOther(); + + // Converted values + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + + // Determine the vectorization size + Type valueTy = op.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset); + + // Get the offset + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + assert(offsetElems.size() == numElems); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + // Get the `other` value (if any) + SmallVector otherElems; + if (llOther) + otherElems = unpackLLElements(loc, llOther, rewriter); + + // Create the resource descriptor and then emit the buffer_load intrinsic(s) + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr); + SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Value pred = mask ? 
maskElems[vecStart] : int_val(1, 1); + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + if (otherElems.size() != 0) + falseVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + otherElems, vecStart); + Value loadVal = bufferEmitter.emitLoad( + vecTy, rsrcDesc, offsetElems[vecStart], pred, falseVal); + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), ii); Value loaded = extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } @@ -283,6 +425,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(valueTy)); + // Determine the vectorization size unsigned vec = getVectorSize(ptr); unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); @@ -290,15 +433,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, auto valueElems = unpackLLElements(loc, llValue, rewriter); assert(ptrElems.size() == valueElems.size()); - // Determine the vectorization size - SmallVector maskElems; - if (llMask) { - maskElems = unpackLLElements(loc, llMask, rewriter); - assert(valueElems.size() == maskElems.size()); - - unsigned maskAlign = getMaskAlignment(mask); - vec = std::min(vec, maskAlign); - } + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); const size_t valueElemNBits = std::max(8, valueElemTy.getIntOrFloatBitWidth()); @@ -309,7 +445,6 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, const int numVecs = elemsPerThread / vec; Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { - size_t in_off = 0; Value pred = mask ? 
and_(maskElems[vecStart], rDataMask) : rDataMask; auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); @@ -320,21 +455,14 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, const size_t wordNElems = width / valueElemNBits; assert(wordNElems * nWords * numVecs == elemsPerThread); - Type valArgTy = IntegerType::get(ctx, width); - auto wordTy = vec_ty(valueElemTy, wordNElems); - SmallVector> asmArgs; Value elem = valueElems[vecStart]; Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); // Create the store val - Value storeVal = undef(vecTy); - for (size_t s = 0; s < vec; ++s) { - Value otherElem = valueElems[vecStart + s]; - Value indexVal = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), s); - storeVal = insert_element(vecTy, storeVal, otherElem, indexVal); - } + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); llStore(rewriter, loc, ptr, storeVal, pred, ptrAlignmentBytes, cacheMod); } // end vec rewriter.eraseOp(op); @@ -342,6 +470,71 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, } }; +struct BufferStoreOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferStoreOp>::ConvertOpToLLVMPattern; + + BufferStoreOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value data = op.getValue(); + + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llData = adaptor.getValue(); + + // Determine the vectorization size + Type valueTy = data.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset); + + // Get the offsets and value + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + SmallVector valueElems = unpackLLElements(loc, llData, rewriter); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr); + Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + Value pred = mask ? 
and_(maskElems[vecStart], rDataMask) : rDataMask; + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); + bufferEmitter.emitStore(rsrcDesc, offsetElems[vecStart], storeVal, pred); + } // end vec + + rewriter.eraseOp(op); + return success(); + } +}; + static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) { switch (memOrdering) { case MemSemantic::RELAXED: @@ -679,8 +872,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, axisInfoAnalysis, - benefit); + patterns + .add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); } } // namespace mlir::triton::AMD
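
For reference, the flag word assembled in `BufferEmitter::createResourceDescriptor` can be reproduced in isolation. The sketch below is a minimal restatement that follows the bit-field comments in that function; `isRDNA` stands in for the `ISAFamily` check, and the exact `num_records` value assumes the elided `std::numeric_limits` template argument is a 32-bit `int`.

```
#include <cstdint>
#include <limits>

// Flag word for the 128-bit buffer resource descriptor, per the bit-field
// comments in BufferEmitter::createResourceDescriptor.
uint32_t bufferRsrcFlags(bool isRDNA) {
  uint32_t flags = (7u << 12) | (4u << 15); // format fields: ignored by the
                                            // intrinsics, but must be nonzero
  if (isRDNA) {
    flags |= 1u << 24; // reserved bit that must be set on RDNA
    flags |= 3u << 28; // out-of-bounds select = 3 (raw buffers test the offset)
  }
  return flags;
}

// num_records is set just under INT32_MAX so that the 2^31 offset used for
// masked-off lanes (see fillCommonArgs) is always out of bounds.
uint32_t bufferNumRecords() {
  return static_cast<uint32_t>(std::numeric_limits<int32_t>::max()) - 1u;
}
```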
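
The packing rule in `BufferEmitter::getBufferOpType` can be summarized at the level of bit widths. The sketch below ignores the element kind (the bf16-to-i16 special case keeps the width unchanged) and only tracks how subword types are packed into 32-bit integers so the intrinsic sees a type it handles natively; the packed type stays bitcast-compatible with the original.

```
// Width-level sketch of BufferEmitter::getBufferOpType.
struct PackedType {
  unsigned elemBits; // element width seen by the intrinsic
  unsigned vecSize;  // 1 means a scalar rather than a one-element vector
};

PackedType packedBufferType(unsigned elemBits, unsigned vecSize) {
  unsigned totalBits = elemBits * vecSize;
  if (elemBits < 32) {
    if (totalBits > 32)
      return {32, totalBits / 32}; // e.g. 8 x f16 -> vector<4 x i32>
    return {totalBits, 1};         // e.g. 2 x f16 -> i32
  }
  return {elemBits, vecSize};      // 32-bit (or wider) elements pass through
}
```

This matches the lit tests above: `buffer_load_store_vec4` checks `vector<4xf32>` accesses, while `buffer_load_store_vec2` packs pairs of f16 into single i32-sized buffer operations.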
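
How wide each access can be is decided by the new `getContiguity(Value ptr, Value offset)` and `getVectorSize(Value ptr, Value offset)` overloads. A sketch of the arithmetic, with the axis-analysis results passed in as plain integers (the parameter names here are illustrative, not part of the patch):

```
#include <algorithm>

// offsetContig:         unique contiguous elements per thread along the
//                       fastest-varying dimension of the offset tensor
// ptrDivisibilityBytes: known divisibility of the scalar base pointer
// elemBits:             bit width of the pointee type
unsigned vectorWidth(unsigned offsetContig, unsigned ptrDivisibilityBytes,
                     unsigned elemBits) {
  unsigned elemBytes = std::max(elemBits / 8u, 1u);
  // Alignment of the scalar base pointer, expressed in elements.
  unsigned alignElems = std::max(ptrDivisibilityBytes / elemBytes, 1u);
  // Effective contiguity is limited by both the offset pattern and the
  // base-pointer alignment.
  unsigned contig = std::min(alignElems, offsetContig);
  // Cap at 128 bits per lane, as in getVectorSize.
  return std::min(128u / elemBits, contig);
}
```

For the `buffer_load_store_vec2` test above (f16 data, divisibility 4) this gives `vectorWidth(8, 4, 16) == 2`, i.e. four i32-sized buffer accesses per eight-element slice, which is what the `CHECK-COUNT-4` lines expect; the mask alignment folded in later by `getMaskElemsAndUpdateVeclen` does not lower it further in that case.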
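
The masking scheme described in the BufferOpsEmitter.h comment block, steering predicated-off lanes to an offset past `num_records` instead of branching, can be modelled per lane as follows. This is a behavioural sketch rather than the emitted IR; it assumes f32 elements and relies on the documented buffer-load behaviour that out-of-bounds lanes read zero.

```
#include <cstdint>
#include <cstring>

struct MemDesc {
  const char *base;    // scalar base pointer, uniform across the wave
  uint32_t numRecords; // region size in bytes (here INT32_MAX - 1)
};

// Hardware-level model: out-of-bounds lanes read zero.
static float modelRawBufferLoad(const MemDesc &d, uint32_t offsetBytes) {
  if (offsetBytes >= d.numRecords)
    return 0.0f;
  float v;
  std::memcpy(&v, d.base + offsetBytes, sizeof(v));
  return v;
}

// What BufferLoadOpConversion emits per lane: select the offset, load, then
// splice in `other`.
float modelMaskedBufferLoad(const MemDesc &d, uint32_t offsetElems, bool pred,
                            float other) {
  uint32_t oobOffset = 1u << 31;                           // > numRecords
  uint32_t offBytes = pred ? offsetElems * 4u : oobOffset; // llvm.select
  float loaded = modelRawBufferLoad(d, offBytes);
  return pred ? loaded : other; // emitLoad skips this select when other == 0
}
```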
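
The store path is symmetric, with one addition: `BufferStoreOpConversion` ANDs the user mask with `redundantDataMask`, which, as in the existing `StoreOpConversion`, keeps lanes holding replicated copies of the same element from issuing duplicate writes (this is the `llvm.and` that the `buffer_store_mask` lit test checks for). A per-lane sketch under the same assumptions as the load model:

```
#include <cstdint>
#include <cstring>

// Out-of-bounds buffer stores are dropped, so masking again reduces to
// steering the offset.
void modelMaskedBufferStore(char *base, uint32_t numRecords,
                            uint32_t offsetElems, float value, bool userMask,
                            bool nonRedundantLane) {
  bool pred = userMask && nonRedundantLane; // mask AND redundantDataMask
  uint32_t oobOffset = 1u << 31;            // > numRecords
  uint32_t offBytes = pred ? offsetElems * 4u : oobOffset;
  if (offBytes >= numRecords)
    return;                                 // store silently dropped
  std::memcpy(base + offBytes, &value, sizeof(value));
}
```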