diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index 03d1f1891b..9d1020b95d 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -153,8 +153,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan warmup_time = n_warmup * estimate_ms rep_time = n_repeat * estimate_ms - times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all", - device_type=device) + times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all") times = torch.tensor(times, dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 3d911f7cb9..457cbeaaad 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -35,5 +35,4 @@ add_triton_library(TritonGPUToLLVM TritonGPUTransforms TritonIntelGPUTransforms TritonNvidiaGPUTransforms - NVGPUIR ) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 8762942c31..8ee1668669 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -41,36 +41,60 @@ SmallVector reorderValues(const SmallVector &values, Type inType, if (inBitWidth == ouBitWidth) return values; if (inBitWidth == 16 && ouBitWidth == 32) { + // Register layout conversion: + // + // [0, 1], [4, 5] ⟶ [0], [1], [4], [5] + // [2, 3], [6, 7] [2], [3], [6], [7] + // + // Original access order: + // + // [0, 1], [2, 3], [4, 5], [6, 7] + // + // Transformed access order: + // + // [0], [2], [1], [3], [4], [6], [5], [7] SmallVector ret; for (unsigned i = 0; i < values.size(); i += 8) { ret.push_back(values[i]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 5]); ret.push_back(values[i + 2]); + ret.push_back(values[i + 1]); ret.push_back(values[i + 3]); + ret.push_back(values[i + 4]); ret.push_back(values[i + 6]); + ret.push_back(values[i + 5]); ret.push_back(values[i + 7]); } return ret; } if (inBitWidth == 8 && ouBitWidth == 16) { + // Register layout conversion: + // + // [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11] + // [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15] + // + // Original access order: + // + // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15] + // + // Transformed access order: + // + // [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15] SmallVector ret; for (unsigned i = 0; i < values.size(); i += 16) { - ret.push_back(values[i + 0]); + ret.push_back(values[i]); ret.push_back(values[i + 1]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 8]); - ret.push_back(values[i + 9]); - ret.push_back(values[i + 10]); - ret.push_back(values[i + 11]); ret.push_back(values[i + 4]); ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); ret.push_back(values[i + 6]); ret.push_back(values[i + 7]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); ret.push_back(values[i + 12]); ret.push_back(values[i + 13]); + ret.push_back(values[i + 10]); + ret.push_back(values[i + 11]); ret.push_back(values[i + 14]); 
ret.push_back(values[i + 15]);
    }
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index e0a2a40ad1..3574329c28 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -238,6 +238,11 @@ static SmallVector eraseOrder(ArrayRef order,
 }
 
 SmallVector getWarpOrder(Attribute layout) {
+  if (auto dotLayout = dyn_cast(layout)) {
+    if (isa(dotLayout.getParent())) {
+      return getWarpOrder(dotLayout.getParent());
+    }
+  }
   auto order = getOrder(layout);
   if (auto mmaLayout = dyn_cast(layout)) {
     if (mmaLayout.isHopper()) {
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 7e6327e3c5..2839b36680 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -473,10 +473,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const {
   int nIndex = 1 + hasBatchDim;
   (void)mIndex, (void)nIndex;
 
-  assert(((shape[mIndex] == 1 || shape[mIndex] >= getMDim()) &&
-          (shape[nIndex] == 1 || shape[nIndex] >= getNDim())) &&
-         "Unsupported tensor shape for given mfma layout");
-
   assert(((getMDim() == 32 && getNDim() == 32) ||
           (getMDim() == 16 && getNDim() == 16)) &&
          "Unsupported mfma type");
@@ -580,55 +576,76 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
   //    held by exactly one thread, maintaining the same number of global loads
   //    as in a blocked layout.
+  //
+  // Another use of the linear layout is to support rare corner cases,
+  // for example when one instruction tile is larger than the tensor.
 
   auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent());
-  if (dotMfmaLayout.getOpIdx() == 0) {
-    return std::nullopt;
-  }
 
   auto rank = shape.size();
   bool hasBatchDim = rank == 3;
   int mIndex = 0 + hasBatchDim;
 
-  auto kWidth = dotMfmaLayout.getKWidth();
+  int32_t kWidth = dotMfmaLayout.getKWidth();
+  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int32_t kSize = shape[kDim];
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
-  if (kWidth != 8 || warpsPerCTA[mIndex] != 1) {
-    return std::nullopt;
-  }
-
   MLIRContext *ctx = dotMfmaLayout.getContext();
   SmallVector outDimNames = standardOutDimNames(ctx, rank);
 
   StringAttr kRegister = S("register");
   StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
 
+  // register order
+  // operand A: [1, 0] / [2, 1, 0]
+  // operand B: [0, 1] / [1, 2, 0]
+  // for both cases it is [k, nonk]/[k, nonk, batch]
   SmallVector order = triton::gpu::getOrder(dotMfmaLayout);
-  auto tileLayout = LinearLayout::empty();
+  // warp order
+  // common for both operands A and B: [0, 1] / [0, 1, 2]
+  // in both cases it is [M dim, N dim]/[batch, M dim, N dim]
+  SmallVector warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);
+
+  // A lane holds kWidth consecutive elements along the k dimension, so the
+  // base register vectors for one tile are initialized in the following way:
+  // {1, 0}, {2, 0} ... {kWidth/2, 0}
+  std::vector> registerBase;
+  for (int32_t elem = 1; elem < kWidth; elem *= 2)
+    registerBase.emplace_back(std::vector{elem, 0});
+
+  std::vector> laneBase;
+  int32_t kTileSize = -1;
   if (mfmaLayout.getMDim() == 32) {
-    // Based on canonical MFMA linear layout, which handles 4 consecutive
-    // elements along the register dimension, kWidth=8 means we have 8
-    // consecutive elements, so we have an additional {4, 0} base vector here.
-    // For lane dim, since the MFMA thread arrangement is {K, N} = {2, 32}, this
-    // means that mapping of first 5 base (up to thread 16) vectors will be an
-    // identity along N dim. Thread 32 will be mapped to element 8 in K
-    // dimension, because kWidth == 8.
-    tileLayout = LinearLayout(
-        {{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-         {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}},
-        {outDimNames[order[0]], outDimNames[order[1]]});
+    // The canonical MFMA linear layout handles 4 consecutive elements along
+    // the register dimension. The dot operand layout handles a variable number
+    // (kWidth) of consecutive elements. For the lane dim, since the MFMA
+    // thread arrangement is {K, N} = {2, 32}, the mapping of the first 5 base
+    // vectors (up to thread 16) is an identity along the N dim. Thread 32 is
+    // mapped to element kWidth in the K dimension.
+    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}};
+    kTileSize = kWidth * 2;
   } else {
     assert(mfmaLayout.getMDim() == 16);
     // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
    // means that mapping of first 4 base (up to thread 16) vectors will be an
-    // identity along N dim. Thread 16 will be mapped to element 8 in K
-    // dimension, because kWidth == 8. Thread 32 is mapped to element 16 as that
-    // is 2*kWidth in K dim.
-    tileLayout = LinearLayout(
-        {{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-         {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}},
-        {outDimNames[order[0]], outDimNames[order[1]]});
+    // identity along N dim. Thread 16 will be mapped to element kWidth in K
+    // dimension. Thread 32 is mapped to element 2*kWidth in K dim.
+    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}};
+    kTileSize = kWidth * 4;
   }
+  assert(kTileSize != -1);
+  // Add repeats of registers along the K dimension to the register base vectors
+  for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
+    registerBase.emplace_back(std::vector{elem, 0});
+
+  // The base vectors above are defined in a fixed order [non-k-dim, k-dim].
+  // The `order` array assigns them to the actual matrix dimensions.
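As a quick sanity check of the basis construction just described (the per-operand `order` mapping is spelled out right after this sketch), here is a small, self-contained Python sketch, not Triton code, that rebuilds the register/lane base vectors in [k, non-k] coordinates; warp/CTA bases and the shape broadcasting later done by combineCtaCgaWithShape are omitted:

```python
# Illustrative sketch (not Triton code): rebuild the register/lane basis
# vectors from the rules above, in [k, non-k] coordinates, i.e. before the
# `order` array maps them onto dim0/dim1.
def mfma_dot_operand_bases(m_dim, k_width, k_size):
    # Each lane owns kWidth consecutive elements along K.
    register_base, e = [], 1
    while e < k_width:
        register_base.append((e, 0))
        e *= 2
    if m_dim == 32:
        # Thread arrangement {K, N} = {2, 32}: thread 32 jumps kWidth along K.
        lane_base = [(0, 1), (0, 2), (0, 4), (0, 8), (0, 16), (k_width, 0)]
        k_tile = k_width * 2
    else:
        assert m_dim == 16
        # Thread arrangement {K, N} = {4, 16}: threads 16/32 jump kWidth/2*kWidth.
        lane_base = [(0, 1), (0, 2), (0, 4), (0, 8), (k_width, 0), (2 * k_width, 0)]
        k_tile = k_width * 4
    # Repeat the one-instruction tile along K until the tensor's K extent is covered.
    e = k_tile
    while e < k_size:
        register_base.append((e, 0))
        e *= 2
    return register_base, lane_base

# mDim=32, kWidth=8, K=128 gives registers
# [(1,0), (2,0), (4,0), (16,0), (32,0), (64,0)] and lanes
# [(0,1), (0,2), (0,4), (0,8), (0,16), (8,0)] -- the same bases the
# warp1onK_mfma32_rhs_kwidth8 test below expects for operand B
# (where K is dim0, so [k, non-k] already equals [dim0, dim1]).
print(mfma_dot_operand_bases(32, 8, 128))
```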
+ // For operand A: non-k-dim -> dim0, k-dim -> dim1 + // For operand B: non-k-dim -> dim1, k-dim -> dim0 + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); if (hasBatchDim) { assert(order[2] == 0); @@ -639,8 +656,10 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, } LinearLayout warpLayout = - identityND(S("warp"), warpsPerCTA, order, outDimNames); - LinearLayout ctaLayout = tileLayout * warpLayout; + identityND(kWarp, warpsPerCTA, warpOrder, outDimNames); + + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); } @@ -1001,6 +1020,8 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, mlir::dyn_cast(tensorTy.getEncoding()); if (!mmaLayout || !mmaLayout.isHopper()) return false; + if (isa(tensorTy.getElementType())) + return false; if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) return false; if (order[0] != 1) diff --git a/python/triton/testing.py b/python/triton/testing.py index 73fc865fa9..1dd079ab98 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -139,7 +139,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod return _summarize_statistics(torch.tensor(ret), quantiles, return_mode) -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device_type="xpu"): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -164,11 +164,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m fn() di.synchronize() - # We maintain a buffer of 256 MB that we clear - # before each kernel call to make sure that the L2 cache - # doesn't contain any input data before the run - cache_size = 256 * 1024 * 1024 - cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type) + cache = runtime.driver.active.get_empty_cache_for_benchmark() # Estimate the runtime of the function start_event = Event(enable_timing=True) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index e2f43f4ba6..e1a2ec68bd 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1034,7 +1034,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1048,7 +1048,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: 
atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.icmp "eq" @@ -1062,7 +1062,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1076,6 +1076,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_nomask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32 diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index d445299662..83653d57b6 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -241,3 +241,41 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], 
order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_nomask + // CHECK: atom.global.gpu.acq_rel.add.v4.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_withmask + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 86c9dd4339..6e1a368bf8 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -503,3 +503,10 @@ def get_current_target(self): def get_benchmarker(self): from triton.testing import do_bench return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # It's the same as the Nvidia backend. 
+ cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 21b74ecf99..9296962983 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -5,6 +5,8 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include @@ -36,16 +38,15 @@ int getWmmaVersion(StringRef archGen) { return 0; } -SmallVector warpsPerTile(tt::DotOp dotOp, - const ArrayRef shape, - int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTile(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) return {(unsigned)numWarps, 1, 1}; - auto filter = [&dotOp](Operation *op) { + auto filter = [dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; ForwardSliceOptions fwdOpt; @@ -55,7 +56,7 @@ SmallVector warpsPerTile(tt::DotOp dotOp, bwdOpt.filter = filter; auto slices = getSlice(dotOp, bwdOpt, fwdOpt); for (Operation *op : slices) - if (isa(op) && (op != dotOp)) + if (op->hasTrait() && (op != dotOp)) return {(unsigned)numWarps, 1}; SmallVector tensorShape = {shape[0], shape[1]}; @@ -63,9 +64,9 @@ SmallVector warpsPerTile(tt::DotOp dotOp, do { if (ret[0] * ret[1] >= numWarps) break; - if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >= - tensorShape[1] / shapePerWarp[1] / ret[1]) { - if (ret[0] < tensorShape[0] / shapePerWarp[0]) { + if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >= + tensorShape[1] / shapePerWarp.second / ret[1]) { + if (ret[0] < tensorShape[0] / shapePerWarp.first) { ret[0] *= 2; } else ret[1] *= 2; @@ -74,24 +75,89 @@ SmallVector warpsPerTile(tt::DotOp dotOp, } } while (true); - if (ret[1] * shapePerWarp[1] > tensorShape[1]) { + if (ret[1] * shapePerWarp.second > tensorShape[1]) { return {ret[1], ret[0]}; } return ret; } -SmallVector -warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTileMFMA(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { return warpsPerTile(dotOp, shape, numWarps, shapePerWarp); } -SmallVector -warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { - return warpsPerTile(dotOp, shape, numWarps, - {ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[0], - ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[1]}); +SmallVector +warpsPerTileWMMA(Operation *dotOp, ArrayRef shape, int numWarps) { + auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr(); + return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]}); +} + +// Chooses a proper MFMA instruction that can used to compute the given dot op. +// If enforcedNonKDim is not zero, it will be used to overwrite the default +// logic to chose a MFMA with matching M/N dim. 
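To make the selection rule in the comment above concrete, here is a hypothetical Python restatement of the M/N-dim choice implemented by the new chooseMfmaInstruction helper below; the kDim lookup through MfmaInsn::selectMfma is left out:

```python
# Hypothetical sketch of the M/N-dim selection in chooseMfmaInstruction
# (the kDim choice via the MFMA instruction database is omitted).
def choose_mfma_mn(M, N, input_k_size, enforced_non_k_dim=0):
    if enforced_non_k_dim != 0:
        return enforced_non_k_dim, enforced_non_k_dim
    min_size = min(M, N)
    if min_size >= 32:
        return 32, 32
    if min_size >= 16:
        return 16, 16
    # Small tiles: fall back to the skinny 4x64 / 64x4 / 4x4 variants.
    if M < 16 and N >= 64:
        return 4, 64
    if M >= 64 and N < 16:
        return 64, 4
    assert input_k_size >= 64, "k should be at least 64 to use this layout"
    return 4, 4

print(choose_mfma_mn(128, 128, 64))  # (32, 32)
print(choose_mfma_mn(16, 128, 64))   # (16, 16)
print(choose_mfma_mn(8, 128, 64))    # (4, 64)
```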
+FailureOr chooseMfmaInstruction(RankedTensorType cType, + Type aElemType, Type bElemType, + int inputKSize, int mfmaVersion, + int enforcedNonKDim) { + // number of matrix elements along k dim per one MFMA intruction + unsigned kDim = 0; + + auto resShape = cType.getShape(); + auto rank = resShape.size(); + auto M = resShape[rank - 2]; + auto N = resShape[rank - 1]; + + unsigned mDim = 0; + unsigned nDim = 0; + if (enforcedNonKDim != 0) { + mDim = nDim = enforcedNonKDim; + } else { + int minSize = std::min(M, N); + if (minSize >= 32) { + mDim = 32; + nDim = 32; + } + if (minSize >= 16 && minSize < 32) { + mDim = 16; + nDim = 16; + } + if (minSize < 16) { + if (M < 16 && N >= 64) { + mDim = 4; + nDim = 64; + } else if (M >= 64 && N < 16) { + mDim = 64; + nDim = 4; + } else { + assert(inputKSize >= 64 && + "k should be at least 64 to use this layout"); + mDim = 4; + nDim = 4; + } + } + } + assert(mDim != 0 && nDim != 0); + + auto maybeMfmaInsn = + MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion); + if (failed(maybeMfmaInsn)) + llvm::report_fatal_error("No match found in MFMA database\n"); + + kDim = maybeMfmaInsn->getKDim(); + assert(kDim != 0); + assert(M % mDim == 0 && N % nDim == 0); + assert(inputKSize % kDim == 0); + return maybeMfmaInsn; +} + +FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, + int nonKDim) { + RankedTensorType aType = dot.getA().getType(); + return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(), + dot.getB().getType().getElementType(), + aType.getShape().back(), mfmaVersion, nonKDim); } using OperandTypesVector = SmallVector; @@ -259,15 +325,16 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value, return castedTensor; } -class BlockedToMFMA : public RewritePattern { +class BlockedToMFMA : public OpRewritePattern { int mfmaVersion; - int enforcedNonKDim; + int nonKDim; int kPack; public: - BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack) - : RewritePattern(tt::DotOp::getOperationName(), 2, context), - mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {} + BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { @@ -285,75 +352,15 @@ class BlockedToMFMA : public RewritePattern { return false; } - /// @brief Choose MFMA instruction parameters - /// @param dot target dot operation - /// @return MfmaInsn or failure - FailureOr chooseMfmaInstruction(tt::DotOp dot) const { - // number of matrix elements along k dim per one MFMA intruction - unsigned kDim = 0; - auto opType = cast(dot.getA().getType()); - auto dataTypeA = opType.getElementType(); - auto dataTypeB = - cast(dot.getB().getType()).getElementType(); - - auto resType = cast(dot.getD().getType()); - auto resShape = resType.getShape(); - auto rank = resShape.size(); - auto M = resShape[rank - 2]; - auto N = resShape[rank - 1]; - - unsigned mDim = 0; - unsigned nDim = 0; - if (enforcedNonKDim != 0) { - mDim = enforcedNonKDim; - nDim = enforcedNonKDim; - } else { - int minSize = std::min(M, N); - if (minSize >= 32) { - mDim = 32; - nDim = 32; - } - if (minSize >= 16 && minSize < 32) { - mDim = 16; - nDim = 16; - } - if (minSize < 16) { - if (M < 16 && N >= 64) { - mDim = 4; - nDim = 64; - } else if (M >= 64 && N < 16) { - mDim = 64; - nDim = 4; - } else { - 
assert(opType.getShape()[rank - 1] >= 64 && - "k should be at least 64 to use this layout"); - mDim = 4; - nDim = 4; - } - } - } - assert(mDim != 0 && nDim != 0); - - auto maybeMfmaInsn = - MfmaInsn::selectMfma(mDim, nDim, dataTypeA, dataTypeB, mfmaVersion); - if (failed(maybeMfmaInsn)) - llvm::report_fatal_error("No match found in MFMA database\n"); - - kDim = maybeMfmaInsn->getKDim(); - assert(kDim != 0); - assert(M % mDim == 0 && N % nDim == 0); - assert(opType.getShape()[rank - 1] % kDim == 0); - return maybeMfmaInsn; - } - - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(tt::DotOp dotOp, PatternRewriter &rewriter) const override { - auto dotOp = cast(op); - RankedTensorType oldRetType = dotOp.getType(); if (!oldRetType.getEncoding() || !isa(oldRetType.getEncoding())) return failure(); + if (!isa_and_nonnull(dotOp.getType().getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); if (!supportMFMA(dotOp)) return failure(); @@ -362,7 +369,7 @@ class BlockedToMFMA : public RewritePattern { // get MFMA encoding for the given number of warps auto retShape = oldRetType.getShape(); - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); // operands @@ -374,7 +381,7 @@ class BlockedToMFMA : public RewritePattern { ttg::AMDMfmaEncodingAttr mfmaEnc; - auto mfmaInstr = chooseMfmaInstruction(dotOp); + auto mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); auto mDim = mfmaInstr.value().getMDim(); auto nDim = mfmaInstr.value().getNDim(); auto kDim = mfmaInstr.value().getKDim(); @@ -397,7 +404,7 @@ class BlockedToMFMA : public RewritePattern { mfmaAccType = rewriter.getF32Type(); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType); // Here is a brief explanation of kWidth, kBase, and kDim @@ -456,11 +463,12 @@ class BlockedToMFMA : public RewritePattern { convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(), oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); return success(); } }; + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, Type promotedType) { Type tensorPromotedType = cast(operand.getType()) @@ -566,18 +574,17 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { }); } -class BlockedToWMMA : public RewritePattern { +class BlockedToWMMA : public OpRewritePattern { int wmmaVersion; public: - BlockedToWMMA(MLIRContext *context, int wmmaVersion) - : RewritePattern(tt::DotOp::getOperationName(), 2, context), - wmmaVersion(wmmaVersion) {} + BlockedToWMMA(MLIRContext *context, int wmmaVersion, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {} - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(tt::DotOp dotOp, PatternRewriter &rewriter) const override { - auto ctx = op->getContext(); - auto dotOp = cast(op); + auto ctx = dotOp->getContext(); Value a = dotOp.getA(); Value b = dotOp.getB(); @@ -603,7 +610,7 @@ class BlockedToWMMA : public RewritePattern { if (wmmaVersion == 2 && llvm::isa(oldAType) && oldAType.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure(op, "not supported yet"); + return rewriter.notifyMatchFailure(dotOp, "not supported yet"); } // get operand types @@ -612,7 +619,7 @@ class BlockedToWMMA : public RewritePattern { 
return failure(); // get WMMA encoding for the given number of warps - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); ttg::AMDWmmaEncodingAttr wmmaEnc; @@ -626,7 +633,7 @@ class BlockedToWMMA : public RewritePattern { auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]); @@ -649,7 +656,7 @@ class BlockedToWMMA : public RewritePattern { Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding, oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); return success(); } }; diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 485e1dcb91..1db9b2c202 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -551,3 +551,12 @@ def is_active(): def get_benchmarker(self): from triton.testing import do_bench return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='xpu') diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 286f8cb52a..38ce62b0c2 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -452,3 +452,12 @@ def is_active(): def get_benchmarker(self): from triton.testing import do_bench return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 54371d063f..71fd3c0cd4 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -641,7 +641,6 @@ struct ConvertLayoutOpConversion // for the destination type, we need to pack values together // so they can be consumed by tensor core operations SmallVector vecVals; - SmallVector types; // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer // instructions to pack & unpack sub-word integers. 
A workaround is to // store the results of ldmatrix in i32 @@ -655,37 +654,20 @@ struct ConvertLayoutOpConversion shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); val = or_(i32_ty, val, ext); } - vecVals.push_back(val); + vecVals.push_back(bitcast(val, i32_ty)); } - elems = elems / (32 / elemSize); - types = SmallVector(elems, i32_ty); } else { unsigned vecSize = std::max(32 / elemSize, 1); Type vecTy = vec_ty(elemTy, vecSize); - types = SmallVector(elems / vecSize, vecTy); for (unsigned i = 0; i < elems; i += vecSize) { Value packed = rewriter.create(loc, vecTy); for (unsigned j = 0; j < vecSize; j++) packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); - vecVals.push_back(packed); + vecVals.push_back(bitcast(packed, i32_ty)); } } - - // This needs to be ordered the same way that - // ldmatrix.x4 would order it - // TODO: this needs to be refactor so we don't - // implicitly depends on how emitOffsetsForMMAV2 - // is implemented - SmallVector reorderedVals; - for (unsigned i = 0; i < vecVals.size(); i += 4) { - reorderedVals.push_back(bitcast(vecVals[i], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty)); - } - - Value view = packLLElements(loc, getTypeConverter(), reorderedVals, - rewriter, dstTy); + Value view = + packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy); rewriter.replaceOp(op, view); return success(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index bf033bdd53..1abb0c5216 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -513,8 +513,8 @@ Value composeValuesToDotOperandLayoutStruct( for (int m = 0; m < n0; ++m) for (int k = 0; k < n1; ++k) { elems.push_back(vals.at({b, 2 * m, 2 * k})); - elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); + elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); } assert(!elems.empty()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index af897ef546..258b8fc261 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -62,12 +62,86 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( auto elems = unpackLLElements(loc, value, rewriter); int offset{}; ValueTableV2 vals; + + // FIXME [Dot LL] + // [ez] Generalize the logic below for kWidth * elemBitWidth > 32 + auto dot = cast(type.getEncoding()); + auto largeK = dot.getKWidth() == 8 && + cast(dot.getParent()).isAmpere(); + if (largeK) { + llvm::SmallVector si; + + // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K + if (dot.getOpIdx() == 0) { + // Original register layout: + // + // [0, 1, 2, 3], [8, 9, 10, 11] + // [4, 5, 6, 7], [12, 13, 14, 15] + // + // Each element in the layout consists of two bf16 values. 
+ // For example, the row [0, 1, 2, 3] expands to: + // + // [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]] + // + // Here, 0/0 refers to the first half of element 0, and 0/1 refers to the + // second half, matching kWidth = 8. + // + // To derive four independent MMA operations, a stride of 4 is applied to + // the original register layout: + // + // 1st MMA: [0, 4, 8, 12] + // 2nd MMA: [1, 5, 9, 13] + // 3rd MMA: [2, 6, 10, 14] + // 4th MMA: [3, 7, 11, 15] + si = llvm::SmallVector{0, 4, 8, 12, 1, 5, 9, 13, + 2, 6, 10, 14, 3, 7, 11, 15}; + } else { + // Original register layout: + // + // [0, 1, 2, 3]^T, [4, 5, 6, 7]^T + // + // A stride of 4 is applied to derive four independent MMA operations: + // + // 1st MMA: [0, 4] + // 2nd MMA: [1, 5] + // 3rd MMA: [2, 6] + // 4th MMA: [3, 7] + si = llvm::SmallVector{0, 4, 1, 5, 2, 6, 3, 7}; + } + + auto step = si.size(); + SmallVector perm(step); + for (auto i = 0; i < elems.size() / step; ++i) { + for (auto j = 0; j < step; ++j) { + perm[j] = elems[i * step + si[j]]; + } + std::copy(perm.begin(), perm.end(), elems.begin() + i * step); + } + + if (dot.getOpIdx() == 1) { + // there are kWidth * 2 elems packed as bf16x2 + int elemsInTile = dot.getKWidth(); + // n0 and n1 are unrolled in the legacy path + // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no + // sense IMO + n0 *= 2; + n1 *= 2; + for (auto b = 0; b < batch; ++b) + for (auto j = 0; j < n1 / elemsInTile; ++j) + for (auto i = 0; i < n0; ++i) + for (auto k = 0; k < elemsInTile; ++k) { + vals[{b, i, elemsInTile * j + k}] = elems[offset++]; + } + return vals; + } + } + for (auto b = 0; b < batch; ++b) for (auto i = 0; i < n0; ++i) { for (auto j = 0; j < n1; j++) { vals[{b, 2 * i, 2 * j}] = elems[offset++]; - vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; vals[{b, 2 * i + 1, 2 * j}] = elems[offset++]; + vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++]; } } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index b19f3ac88e..760ba75d98 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -98,6 +98,23 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, return mask; } +std::string getRegisterSizeCode(int size, bool is_float) { + switch (size) { + case 1: + return "b"; + case 16: + return "h"; + case 32: + return is_float ? "f" : "r"; + case 64: + return is_float ? "d" : "l"; + case 128: + return "q"; + default: + llvm_unreachable("Unsupported register size"); + } +} + // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo, @@ -632,6 +649,20 @@ struct AtomicRMWOpConversion : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + bool supportsVectorized(Operation *moduleOp, RMWOp opType, + Type elementType) const { + // vectorized atomics are only supported on hopper, + // and only for specific atomic ops (add, min, max). + // Note that "packed types" like f16x2 are supported sm60+. 
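As a rough illustration of the rule just stated, the following hypothetical Python sketch mirrors how the conversion below splits work between true vectorization (.v2/.v4 atomics on Hopper) and the older f16x2 packing; the alignment capping of vec and the scalar path are omitted:

```python
# Hypothetical sketch (not the actual C++) of the new vec/packed split in
# AtomicRMWOpConversion: `vec` elements go into one atom.global...v<vec>
# instruction, `packed` elements into one f16x2 operand.
def choose_vec_and_packed(compute_capability, rmw_op, elem_type, vec):
    supports_vectorized = (
        compute_capability >= 90
        and rmw_op == "fadd"
        and elem_type in ("f16", "bf16", "f32")
    )
    if supports_vectorized:
        return vec, 1                      # e.g. atom.global...add.v4.f32
    packed = min(vec, 2 if elem_type == "f16" else 1)
    return 1, packed                       # e.g. atom.global...add.noftz.f16x2

print(choose_vec_and_packed(90, "fadd", "f16", 4))  # (4, 1): .v4.f16 on Hopper
print(choose_vec_and_packed(80, "fadd", "f16", 4))  # (1, 2): f16x2 packing on sm80
print(choose_vec_and_packed(90, "fadd", "f32", 2))  # (2, 1): .v2.f32 on Hopper
```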
+ auto computeCapability = getNVIDIAComputeCapability(moduleOp); + if (computeCapability < 90) { + return false; + } + + return opType == RMWOp::FADD && + (elementType.isF16() || elementType.isBF16() || elementType.isF32()); + } + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -664,45 +695,82 @@ struct AtomicRMWOpConversion : valueTy; const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); - // vec = 1, numElements = 1 for scalar - auto vec = getVectorSize(ptr); - auto vecOrig = vec; - int numElems = 1; - // tensor + // packed: e.g. packed=2 for f16x2 + // vec: e.g. .v2, .v4, .v8 version of atom instruction. + unsigned vec, vecOrig; + int numElems, packed; if (tensorTy) { + vec = getVectorSize(ptr); + if (llMask) { + vec = std::min(vec, getMaskAlignment(op.getMask())); + } + vecOrig = vec; + packed = 1; auto valTy = cast(val.getType()); - vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); - // mask + if (!supportsVectorized(moduleOp, atomicRmwAttr, + valTy.getElementType())) { + packed = + std::min(vecOrig, valTy.getElementType().isF16() ? 2 : 1); + vec = 1; + } numElems = tensorTy.getNumElements(); + } else { + // scalar + vec = 1; + vecOrig = 1; + numElems = 1; + packed = 1; } + assert((packed == 1 || vec == 1) && "packed or vec must be 1"); - if (vec == 1 && numElems > 1) + if (vec * packed == 1 && numElems > 1) op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " origin vec = " << vecOrig + << " packed = " << packed << " origin vec = " << vecOrig << " numElems = " << numElems; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); - auto vecTy = vec_ty(valueElemTy, vec); + auto packedTy = vec_ty(valueElemTy, packed); SmallVector resultVals(elemsPerThread); - for (size_t i = 0; i < elemsPerThread; i += vec) { - Value rmwVal = undef(vecTy); - for (int ii = 0; ii < vec; ++ii) { - Value iiVal = createIndexAttrConstant( - rewriter, loc, getTypeConverter()->getIndexType(), ii); - rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal); - } - + for (size_t i = 0; i < elemsPerThread; i += vec * packed) { Value rmwPtr = ptrElements[i]; Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; std::string sTy; PTXBuilder ptxBuilderAtomicRMW; - std::string tyId = valueElemNBits * vec == 64 - ? "l" - : (valueElemNBits * vec == 32 ? 
"r" : "h"); - auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l" + std::string tyId = + getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false); + + PTXBuilder::Operand *dstOpr; + if (vec > 1) { + dstOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + dstOpr->listAppend( + ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true)); + } + } else { + dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + } + auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l"); - auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + + PTXBuilder::Operand *valOpr; + if (vec > 1) { + valOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + valOpr->listAppend( + ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId)); + } + } else if (packed > 1) { + Value rmwVal = undef(packedTy); + for (int ii = 0; ii < packed; ++ii) { + rmwVal = insert_element(packedTy, rmwVal, valElements[i + ii], + i32_val(ii)); + } + valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + } else { + valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId); + } auto scope = stringifyMemSyncScope(op.getScope()).str(); auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope); @@ -725,7 +793,7 @@ struct AtomicRMWOpConversion rmwOp = "add"; rmwOp += (valueElemNBits == 16 ? ".noftz" : ""); sTy = "f" + sBits; - sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : ""; + sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : ""; break; case RMWOp::MAX: sTy = "s" + sBits; @@ -750,15 +818,33 @@ struct AtomicRMWOpConversion std::string semStr; llvm::raw_string_ostream os(semStr); os << op.getSem(); - atom.o(semStr).o(rmwOp).o(sTy); + atom.o(semStr).o(rmwOp).v(vec).o(sTy); if (tensorTy) { atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); - auto retType = vec == 1 ? valueElemTy : vecTy; + Type retType; + if (vec > 1) { + SmallVector retTys(vec, valueElemTy); + retType = struct_ty(retTys); + } else if (packed > 1) { + retType = packedTy; + } else { + retType = valueElemTy; + } + auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? 
ret : extract_element(valueElemTy, ret, i32_val(ii)); + + if (vec > 1) { + for (unsigned ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = extract_val(valueElemTy, ret, ii); + } + } else if (packed > 1) { + for (unsigned ii = 0; ii < packed; ++ii) { + resultVals[i + ii] = extract_element(valueElemTy, ret, i32_val(ii)); + } + } else { + resultVals[i] = ret; } + } else { auto ASMReturnTy = void_ty(ctx); atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index aeca44bb46..ad4e840ee0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -27,6 +27,60 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {} + llvm::SmallVector + unpackFP4Elements(Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallVector &vals, Value laneId) const { + auto fp4x2ToBf16x2 = [&loc, &rewriter](Value v) -> Value { + auto em0 = and_(v, i8_val(0x70)); + auto em1 = and_(v, i8_val(0x7)); + Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), + shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); + Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), + shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); + + // Three cases: + // 1) x is normal and non-zero: Correct bias + v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), + add(v0, i16_val((127 - 1) << 7)), v0); + v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), + add(v1, i16_val((127 - 1) << 7)), v1); + + // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in + // bf16 + v0 = select(icmp_eq(em0, i8_val(0x10)), + or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0); + v1 = select(icmp_eq(em1, i8_val(0x1)), + or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1); + // 3) x is zero, nothing to do + + // Swap as they come packed in big endian + return or_(zext(i32_ty, v0), shl(zext(i32_ty, v1), i32_val(16))); + }; + + auto fp4x8ToBf16x2 = [&loc, &rewriter, &fp4x2ToBf16x2]( + Value v) -> llvm::SmallVector { + llvm::SmallVector results(4); + for (int i = 0; i < 4; ++i) { + auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i))); + results[i] = fp4x2ToBf16x2(v_i); + } + return results; + }; + + // Split fp4x8 into 4 bf16x2 + llvm::SmallVector ret; + ret.reserve(vals.size() * 4); + for (int i = 0; i < vals.size(); ++i) { + auto vs = fp4x8ToBf16x2(vals[i]); + assert(vs.size() == 4); + for (auto v : vs) { + ret.push_back(v); + } + } + + return ret; + } + LogicalResult matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index e3f521f1b3..c27c63335e 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -543,6 +543,12 @@ class AMDMfmaLayoutTest : public ::testing::Test { /*isTransposed=*/true, ctaLayout); } + triton::gpu::DotOperandEncodingAttr + createDotOperand(int idx, triton::gpu::AMDMfmaEncodingAttr parent, + int kWidth) { + return triton::gpu::DotOperandEncodingAttr::get(&ctx, idx, parent, kWidth); + } + protected: MLIRContext ctx; const SmallVector ctaPerCGA{1, 1, 1}; @@ -588,6 +594,32 @@ TEST_F(AMDMfmaLayoutTest, mfma16) { ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); } 
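Back in UpcastMXFPToLLVM.cpp above, the new fp4x2ToBf16x2 lambda expands two packed e2m1 (fp4) values into bf16 bit patterns with pure integer arithmetic. The following standalone Python re-derivation, not Triton code, reproduces the same bit manipulation and checks it on a couple of values:

```python
# Standalone re-derivation of the fp4x2ToBf16x2 bit math: one byte holds two
# e2m1 (fp4) values, each expanded to a bf16 bit pattern.
import struct

def bf16_bits_to_float(bits):
    # bf16 is the upper 16 bits of an f32.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

def fp4x2_to_bf16x2(v):
    em0 = v & 0x70            # exponent+mantissa of the high nibble
    em1 = v & 0x07            # exponent+mantissa of the low nibble
    v0 = (em0 << 2) | ((v & 0x80) << 8)   # move e/m and sign into bf16 positions
    v1 = (em1 << 6) | ((v & 0x08) << 12)
    # Normal numbers: rebias the exponent from bias 1 (e2m1) to bias 127 (bf16).
    if em0 & 0x60:
        v0 += (127 - 1) << 7
    if em1 & 0x06:
        v1 += (127 - 1) << 7
    # Subnormal e2m1 (0bs001) maps to +-0.5 in bf16 (0x3F00 plus the sign bit).
    if em0 == 0x10:
        v0 = 16128 | (v0 & 0x8000)
    if em1 == 0x01:
        v1 = 16128 | (v1 & 0x8000)
    # Zero needs no special handling.
    return bf16_bits_to_float(v0), bf16_bits_to_float(v1)

# 0x2E packs +1.0 (high nibble 0b0010) and -4.0 (low nibble 0b1110);
# 0x10 packs the subnormal +0.5 and +0.0.
assert fp4x2_to_bf16x2(0x2E) == (1.0, -4.0)
assert fp4x2_to_bf16x2(0x10) == (0.5, 0.0)
```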
+TEST_F(AMDMfmaLayoutTest, mfma_dot_op) { + auto mfma2d = createMFMA(32, 32, {2, 4}); + auto dot2dOp0 = createDotOperand(0, mfma2d, 4); + auto dot2dOp1 = createDotOperand(1, mfma2d, 4); + ASSERT_THAT(dot2dOp0.getWarpOrder(), mfma2d.getWarpOrder()); + ASSERT_THAT(dot2dOp1.getWarpOrder(), mfma2d.getWarpOrder()); + + auto tmfma2d = createTransposedMFMA(32, 32, {2, 4}); + auto tdot2dOp0 = createDotOperand(0, tmfma2d, 4); + auto tdot2dOp1 = createDotOperand(1, tmfma2d, 4); + ASSERT_THAT(tdot2dOp0.getWarpOrder(), tmfma2d.getWarpOrder()); + ASSERT_THAT(tdot2dOp1.getWarpOrder(), tmfma2d.getWarpOrder()); + + auto mfma3d = createMFMA(32, 32, {2, 4, 1}); + auto dot3dOp0 = createDotOperand(0, mfma3d, 4); + auto dot3dOp1 = createDotOperand(1, mfma3d, 4); + ASSERT_THAT(dot3dOp0.getWarpOrder(), mfma3d.getWarpOrder()); + ASSERT_THAT(dot3dOp1.getWarpOrder(), mfma3d.getWarpOrder()); + + auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1}); + auto tdot3dOp0 = createDotOperand(0, tmfma3d, 4); + auto tdot3dOp1 = createDotOperand(1, tmfma3d, 4); + ASSERT_THAT(tdot3dOp0.getWarpOrder(), tmfma3d.getWarpOrder()); + ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder()); +} + } // anonymous namespace } // namespace mlir::triton::gpu diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 894d78e1b4..554d507015 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -58,8 +58,8 @@ class LinearLayoutConversionsTest : public ::testing::Test { isTransposed, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); } - DotOperandEncodingAttr amdDot(AMDMfmaEncodingAttr mfma, unsigned opIdx, - unsigned kWidth) { + DotOperandEncodingAttr mfmaDotOp(AMDMfmaEncodingAttr mfma, unsigned opIdx, + unsigned kWidth) { return DotOperandEncodingAttr::get(&ctx, opIdx, mfma, kWidth); } @@ -738,12 +738,84 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4x1Warps) { {S("dim0"), S("dim1"), S("dim2")})); } +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_lhs_kwidth8) { + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 16}, {0, 32}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, 
{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); - auto amdDot_1_8 = amdDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ( - toLinearLayout({128, 128}, amdDot_1_8), + toLinearLayout({128, 128}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, @@ -752,7 +824,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("dim0"), S("dim1")})); EXPECT_EQ( - toLinearLayout({128, 256}, amdDot_1_8), + toLinearLayout({128, 256}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, @@ -760,7 +832,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({32, 64}, amdDot_1_8), + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, @@ -769,7 +841,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("dim0"), S("dim1")})); EXPECT_EQ( - toLinearLayout({256, 256}, amdDot_1_8), + toLinearLayout({256, 256}, mfmaDot_1_8), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}}, @@ -778,10 +850,18 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); - auto amdDot_1_4 = amdDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); - EXPECT_EQ(toLinearLayout({256, 256}, amdDot_1_4), + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), LinearLayout( {{S("register"), {{1, 0}, @@ -798,12 +878,131 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_lhs_kwidth8) { + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {16, 0}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({1, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + { + {0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + }}, + 
{S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ( + toLinearLayout({128, 1}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {0, 128}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {0, 128}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8_1 = mfmaDotOp(parentMfma_1_8_1, /*opIdx=*/0, /*kWidth=*/8); + + EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1), + LinearLayout({{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 32}, + {0, 0, 64}, + {0, 0, 128}, + {0, 16, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 128, 0}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 0, 8}, + {0, 0, 16}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, /*isTransposed=*/false); - auto amdDot_1_4 = amdDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); EXPECT_EQ( - toLinearLayout({128, 128}, amdDot_1_4), + toLinearLayout({128, 128}, mfmaDot_1_4), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 64}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}, @@ -811,7 +1010,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({1, 128}, amdDot_1_4), + EXPECT_EQ(toLinearLayout({1, 128}, mfmaDot_1_4), LinearLayout( {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {0, 64}}}, {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {0, 0}}}, @@ -819,7 +1018,7 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { {S("block"), {}}}, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 1}, amdDot_1_4), + EXPECT_EQ(toLinearLayout({128, 1}, mfmaDot_1_4), LinearLayout( {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}}}, {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {8, 0}, {16, 0}}}, @@ -827,7 +1026,7 @@ 
            {S("block"), {}}},
           {S("dim0"), S("dim1")}));

-  EXPECT_EQ(toLinearLayout({256, 256}, amdDot_1_4),
+  EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4),
             LinearLayout(
                 {{S("register"),
                   {{1, 0},
@@ -843,11 +1042,19 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) {
            {S("block"), {}}},
           {S("dim0"), S("dim1")}));

+  EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_4),
+            LinearLayout(
+                {{S("register"), {{1, 0}, {2, 0}, {4, 0}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {0, 0}}},
+                 {S("warp"), {{0, 0}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
   auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16,
                              /*isTransposed=*/false);
-  auto amdDot_1_8 = amdDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8);
+  auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8);
   EXPECT_EQ(
-      toLinearLayout({256, 256}, amdDot_1_8),
+      toLinearLayout({256, 256}, mfmaDot_1_8),
       LinearLayout(
           {{S("register"),
             {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {128, 0}, {0, 128}}},
@@ -858,9 +1065,9 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) {

   auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16,
                                /*isTransposed=*/false);
-  auto amdDot_1_8_1 = amdDot(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8);
+  auto mfmaDot_1_8_1 = mfmaDotOp(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8);

-  EXPECT_EQ(toLinearLayout({1, 256, 256}, amdDot_1_8_1),
+  EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1),
             LinearLayout({{S("register"),
                            {{0, 1, 0},
                             {0, 2, 0},
@@ -881,6 +1088,167 @@ TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) {
           {S("dim0"), S("dim1"), S("dim2")}));
 }

+TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_lhs_kwidth4) {
+  auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
+                           /*isTransposed=*/false);
+  auto mfmaDotOp0_32 = mfmaDotOp(parentMfma32, /*opIdx=*/0, /*kWidth=*/4);
+  EXPECT_EQ(toLinearLayout({128, 128}, mfmaDotOp0_32),
+            LinearLayout(
+                {{S("register"),
+                  {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {0, 32}, {0, 64}, {64, 0}}},
+                 {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({64, 32}, mfmaDotOp0_32),
+            LinearLayout(
+                {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}},
+                 {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp0_32),
+            LinearLayout(
+                {{S("register"), {{0, 1}, {0, 2}, {0, 8}}},
+                 {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 4}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
+  // Dot operand based on transposed mfma layout has the same layout as the ordinary one
+  auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
+                            /*isTransposed=*/true);
+  auto tmfmaDotOp0_32 = mfmaDotOp(parentTMfma32, /*opIdx=*/0, /*kWidth=*/4);
+
+  EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp0_32),
+            toLinearLayout({128, 128}, mfmaDotOp0_32));
+  EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp0_32),
+            toLinearLayout({64, 32}, mfmaDotOp0_32));
+  EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp0_32),
+            toLinearLayout({16, 16}, mfmaDotOp0_32));
+}
+
+TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_lhs_kwidth4) {
+  auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16,
+                           /*isTransposed=*/false);
+  auto mfmaDotOp0_16 = mfmaDotOp(parentMfma16, /*opIdx=*/0, /*kWidth=*/4);
+  EXPECT_EQ(
+      toLinearLayout({128, 128}, mfmaDotOp0_16),
+      LinearLayout(
+          {{S("register"),
+            {{0, 1}, {0, 2}, {0, 16}, {0, 32}, {0, 64}, {32, 0}, {64, 0}}},
+           {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}},
+           {S("warp"), {{0, 0}, {0, 0}, {16, 0}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({64, 32}, mfmaDotOp0_16),
+            LinearLayout(
+                {{S("register"), {{0, 1}, {0, 2}, {0, 16}, {32, 0}}},
+                 {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {16, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp0_16),
+            LinearLayout(
+                {{S("register"), {{0, 1}, {0, 2}}},
+                 {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
+  // Dot operand based on transposed mfma layout has the same layout as the ordinary one
+  auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16,
+                            /*isTransposed=*/true);
+  auto tmfmaDotOp0_16 = mfmaDotOp(parentTMfma16, /*opIdx=*/0, /*kWidth=*/4);
+
+  EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp0_16),
+            toLinearLayout({128, 128}, mfmaDotOp0_16));
+  EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp0_16),
+            toLinearLayout({64, 32}, mfmaDotOp0_16));
+  EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp0_16),
+            toLinearLayout({16, 16}, mfmaDotOp0_16));
+}
+
+TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_rhs_kwidth4) {
+  auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
+                           /*isTransposed=*/false);
+  auto mfmaDotOp1_32 = mfmaDotOp(parentMfma32, /*opIdx=*/1, /*kWidth=*/4);
+  EXPECT_EQ(
+      toLinearLayout({128, 128}, mfmaDotOp1_32),
+      LinearLayout(
+          {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}, {32, 0}, {64, 0}}},
+           {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}},
+           {S("warp"), {{0, 32}, {0, 64}, {0, 0}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({32, 64}, mfmaDotOp1_32),
+            LinearLayout(
+                {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}},
+                 {S("warp"), {{0, 32}, {0, 0}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp1_32),
+            LinearLayout(
+                {{S("register"), {{1, 0}, {2, 0}, {8, 0}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {4, 0}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
+  // Dot operand based on transposed mfma layout has the same layout as the ordinary one
+  auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
+                            /*isTransposed=*/true);
+  auto tmfmaDotOp1_32 = mfmaDotOp(parentTMfma32, /*opIdx=*/1, /*kWidth=*/4);
+
+  EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp1_32),
+            toLinearLayout({128, 128}, mfmaDotOp1_32));
+  EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp1_32),
+            toLinearLayout({64, 32}, mfmaDotOp1_32));
+  EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp1_32),
+            toLinearLayout({16, 16}, mfmaDotOp1_32));
+}
+
+TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_rhs_kwidth4) {
+  auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16,
+                           /*isTransposed=*/false);
+  auto mfmaDotOp1_16 = mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/4);
+  EXPECT_EQ(toLinearLayout({128, 128}, mfmaDotOp1_16),
+            LinearLayout(
+                {{S("register"),
+                  {{1, 0}, {2, 0}, {16, 0}, {32, 0}, {64, 0}, {0, 64}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}},
+                 {S("warp"), {{0, 16}, {0, 32}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({32, 64}, mfmaDotOp1_16),
+            LinearLayout(
+                {{S("register"), {{1, 0}, {2, 0}, {16, 0}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}},
+                 {S("warp"), {{0, 16}, {0, 32}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp1_16),
+            LinearLayout(
+                {{S("register"), {{1, 0}, {2, 0}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
+  // Dot operand based on transposed mfma layout has the same layout as the ordinary one
+  auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16,
+                            /*isTransposed=*/true);
+  auto tmfmaDotOp1_16 = mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/4);
+
+  EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp1_16),
+            toLinearLayout({128, 128}, mfmaDotOp1_16));
+  EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp1_16),
+            toLinearLayout({64, 32}, mfmaDotOp1_16));
+  EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp1_16),
+            toLinearLayout({16, 16}, mfmaDotOp1_16));
+}
+
 TEST_F(LinearLayoutConversionsTest, WMMA_2x4Warps) {
   auto legacy = wmma(/*warps=*/{2, 4});