diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt
index 0952ab984c..50d0247946 100644
--- a/cmake/llvm-hash.txt
+++ b/cmake/llvm-hash.txt
@@ -1 +1 @@
-1f20eee6dc367bd202895e3eedb03974a628ef16
+86b69c31642e98f8357df62c09d118ad1da4e16a
diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index 244bc6181a..a1c37efb52 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1123,6 +1123,12 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
   return idx;
 }
 
+// Emit code to compute the (blockId, warpId, laneId) for the current thread.
+std::tuple<Value, Value, Value>
+emitHardwareTuple(Location loc, RewriterBase &rewriter,
+                  const TargetInfoBase &target, bool withCTAOffset,
+                  unsigned threadsPerWarp);
+
 // Emit indices calculation within each ConversionPattern, and returns a
 // [elemsPerThread X rank] index matrix.
 //
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index 49f05a758e..a310cdba5f 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -99,6 +99,20 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
   return outIndices;
 }
 
+std::tuple<Value, Value, Value> emitHardwareTuple(Location loc,
+                                                  RewriterBase &rewriter,
+                                                  const TargetInfoBase &target,
+                                                  bool withCTAOffset,
+                                                  unsigned threadsPerWarpCst) {
+  Value threadId = getThreadId(rewriter, loc);
+  Value threadsPerWarp = i32_val(threadsPerWarpCst);
+  Value laneId = urem(threadId, threadsPerWarp);
+  Value warpId = udiv(threadId, threadsPerWarp);
+  Value blockId =
+      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
+  return {blockId, warpId, laneId};
+}
+
 SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset) {
@@ -116,12 +130,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   StringAttr kWarp = str_attr("warp");
   StringAttr kBlock = str_attr("block");
 
-  Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(ll->getInDimSize(kLane));
-  Value laneId = urem(threadId, threadsPerWarp);
-  Value warpId = udiv(threadId, threadsPerWarp);
-  Value blockId =
-      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
+  auto [blockId, warpId, laneId] = emitHardwareTuple(
+      loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane));
   unsigned rank = shape.size();
   SmallVector<SmallVector<Value>> ret;
   // Linear layout function is split in two parts below:
@@ -214,10 +224,9 @@ bool emitTransferBetweenRegistersAndShared(
       std::min(regToSharedLayout->getNumConsecutiveInOut(),
               maxVecElems.value_or(std::numeric_limits<int>::max()));
 
-  Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
-  Value laneId = urem(threadId, threadsPerWarp);
-  Value warpId = udiv(threadId, threadsPerWarp);
+  auto [blockId, warpId, laneId] =
+      emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false,
+                        regToSharedLayout->getInDimSize(kLane));
 
   int numElems = regToSharedLayout->getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
@@ -625,10 +634,8 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
     auto instrShape = mmaLayout.getInstrShape();
     SmallVector<Value> mmaColIdx(2);
    SmallVector<Value> mmaRowIdx(2);
-    Value threadId = getThreadId(rewriter, loc);
-    Value warpSize = i32_val(32);
-    Value laneId = urem(threadId, warpSize);
-    Value warpId = udiv(threadId, warpSize);
+    auto [blockId, warpId, laneId] = emitHardwareTuple(
+        loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32);
     // TODO: fix the bug in MMAEncodingAttr document
     SmallVector<Value> multiDimWarpId(2);
     auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
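For reference (editorial note, not part of the patch): emitHardwareTuple centralizes the (blockId, warpId, laneId) arithmetic that each call site above previously spelled out by hand. A minimal scalar sketch of what the emitted IR computes, in plain C++ rather than the LLVM-dialect builder macros; the names here are illustrative:

    #include <cstdint>
    #include <tuple>

    // laneId/warpId come from splitting the linear thread id by the warp
    // width; blockId is the cluster CTA id, or 0 when CTA offsets are unused.
    std::tuple<uint32_t, uint32_t, uint32_t>
    hardwareTuple(uint32_t threadId, uint32_t threadsPerWarp,
                  bool withCTAOffset, uint32_t clusterCTAId) {
      uint32_t laneId = threadId % threadsPerWarp;
      uint32_t warpId = threadId / threadsPerWarp;
      uint32_t blockId = withCTAOffset ? clusterCTAId : 0;
      return {blockId, warpId, laneId};
    }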
diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td
index 1f1de2c717..e3588f5877 100644
--- a/lib/Dialect/Triton/Transforms/Combine.td
+++ b/lib/Dialect/Triton/Transforms/Combine.td
@@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat<
         [(Constraint<CPred<"isZero($0)">> $c),
          (Constraint<CPred<"$0.hasOneUse()">, "dot result has a single use"> $res)]>;
 def CombineDotAddFPattern : Pat<
-        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm),
+        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
         (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
         [(Constraint<CPred<"isZero($0)">> $c),
          (Constraint<CPred<"::llvm::cast<IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
@@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat<
         [(Constraint<CPred<"isZero($0)">> $c),
          (Constraint<CPred<"$0.hasOneUse()">, "dot result has a single use"> $res)]>;
 def CombineDotAddFRevPattern : Pat<
-        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm),
+        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
         (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
         [(Constraint<CPred<"isZero($0)">> $c),
          (Constraint<CPred<"::llvm::cast<IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 3a3519f7c0..1025ab0990 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -525,10 +525,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 std::optional<LinearLayout>
 BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   assert(shape.size() == getOrder().size());
-
-  int rank = shape.size();
   MLIRContext *ctx = getContext();
-  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
 
   const auto &order = getOrder();
   LinearLayout ctaLayout =
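A note on the Combine.td hunks above (editorial, not part of the patch): the patterns fold addf(dot(a, b, c), d) into dot(a, b, d), which is only value-preserving when the original accumulator c is zero — that is what the isZero constraint checks. A plain C++ illustration of the identity, assuming exact arithmetic (with floats the rewrite additionally leans on the pattern's other constraints):

    #include <cstdio>

    // dot(a, b, c) accumulates c + sum over i of a[i] * b[i].
    float dotAcc(const float *a, const float *b, float c, int k) {
      for (int i = 0; i < k; ++i)
        c += a[i] * b[i];
      return c;
    }

    int main() {
      float a[] = {1, 2, 3}, b[] = {4, 5, 6}, d = 10;
      // With c == 0: dot(a, b, 0) + d == dot(a, b, d), so the trailing add
      // can be folded into the accumulator. With c != 0 the results differ.
      printf("%g == %g\n", dotAcc(a, b, 0, 3) + d, dotAcc(a, b, d, 3));
    }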
diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py
index 5a5d3b446f..2ea563b9d2 100644
--- a/python/test/regression/test_cast_matmul.py
+++ b/python/test/regression/test_cast_matmul.py
@@ -15,7 +15,7 @@ import triton.language as tl
 
 from triton._internal_testing import is_hip_mi300, is_cuda, is_hip
 
-input_dtypes = ["float16", "float32", "float64"]
+input_dtypes = ["bfloat16", "float16", "float32", "float64"]
 if is_cuda():
     input_dtypes += ["int8", "float8_e5m2"]
     cc = torch.cuda.get_device_capability(0)
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 32dff6e49f..61911356af 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -1999,3 +1999,49 @@ tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: te
   }
 }
 
+
+// -----
+
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+
+tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
+  // CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1
+  // CHECK: llvm.sitofp %{{.*}} : i8 to f16
+  %2 = arith.sitofp %1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+  tt.return
+}
+
+}
+
+// -----
+
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
+  // CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0
+  // CHECK: llvm.sitofp %{{.*}} : i8 to f16
+  %2 = arith.sitofp %1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0 , parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+  tt.return
+}
+
+}
+
+// -----
+
+#linear = #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+
+tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x2xi8, #linear>) {
+  // CHECK-LABEL: upcast_mxfp
+  // CHECK-COUNT-4: llvm.inline_asm
+  // CHECK-COUNT-2: nvvm.shfl.sync
+  // CHECK-COUNT-32: llvm.fmul
+  %0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+  tt.return
+}
+
+}
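For reference on the #ttg.linear encoding used by the upcast_mxfp test above (editorial note, not part of the patch): each entry of the lane basis says where bit k of the lane id lands in (row, column) coordinates, and contributions from set bits combine by XOR — the layout is linear over GF(2). A small sketch that tabulates the resulting lane -> (row, col) map for the test's basis, which yields the 16x2 thread shape mentioned in the UpcastMXFPToLLVM.cpp comments below:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Lane basis from the test: [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]].
      int basis[5][2] = {{0, 1}, {1, 0}, {2, 0}, {4, 0}, {8, 0}};
      for (uint32_t lane = 0; lane < 32; ++lane) {
        int row = 0, col = 0;
        for (int k = 0; k < 5; ++k)
          if (lane & (1u << k)) {
            row ^= basis[k][0];
            col ^= basis[k][1];
          }
        printf("lane %2u -> (%2d, %d)\n", lane, row, col); // rows 0..15, cols 0..1
      }
    }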
diff --git a/test/TritonIntelGPU/prefetch-block.mlir b/test/TritonIntelGPU/prefetch-block.mlir
index 31d6ed992c..ddcf0ee142 100644
--- a/test/TritonIntelGPU/prefetch-block.mlir
+++ b/test/TritonIntelGPU/prefetch-block.mlir
@@ -33,7 +33,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 1 : i32}
   // CHECK-NEXT: [[B3:%.*]] = tt.advance [[B2]], {{.*}} : >
   // CHECK-NEXT: [[B4:%.*]] = tt.make_tensor_ptr %arg1, {{.*}} : >>
 
-  // CHECK: spirv.INTEL.ControlBarrierArrive
+  // CHECK: spirv.INTEL.ControlBarrierArrive <Workgroup>, <Workgroup>, <None>
   // CHECK-NEXT: scf.for [[IV:%.*]] = [[CST_ZERO]] to [[CST_4096]] step [[CST_32]]
   // CHECK-SAME: iter_args([[CST:%.*]] = {{.*}}, [[A6:%.*]] = [[A4]], [[B6:%.*]] = [[B4]], [[A5:%.*]] = [[A3]], [[B5:%.*]] = [[B3]])
   // CHECK-NEXT: [[LD_A:%.*]] = tt.load [[A6]]
@@ -45,11 +45,11 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 1 : i32}
   // CHECK-DAG: tt.advance [[A6]], {{.*}} : >>
   // CHECK-NEXT: tt.advance [[B5]], {{.*}} : >
   // CHECK-DAG: tt.advance [[B6]], {{.*}} : >>
-  // CHECK: spirv.INTEL.ControlBarrierWait
-  // CHECK-NEXT: spirv.INTEL.ControlBarrierArrive
+  // CHECK: spirv.INTEL.ControlBarrierWait <Workgroup>, <Workgroup>, <None>
+  // CHECK-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup>, <Workgroup>, <None>
   // CHECK-NEXT: scf.yield {{.*}}
   // CHECK-NEXT: }
-  // CHECK-NEXT: spirv.INTEL.ControlBarrierWait
+  // CHECK-NEXT: spirv.INTEL.ControlBarrierWait <Workgroup>, <Workgroup>, <None>
 
   %c64_i32 = arith.constant 64 : i32
   %c16_i32 = arith.constant 16 : i32
diff --git a/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/amd/backend/include/hsa/amd_hsa_elf.h
index 65a77f041b..74f15d7d7a 100644
--- a/third_party/amd/backend/include/hsa/amd_hsa_elf.h
+++ b/third_party/amd/backend/include/hsa/amd_hsa_elf.h
@@ -136,7 +136,7 @@ enum : unsigned {
   EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c,
   EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d,
   EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e,
-  EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f,
+  EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4F = 0x04f,
   EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050,
   EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051,
   EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052,
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp
index 7ab6fd68a5..63fb972f79 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp
@@ -11,7 +11,6 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) {
 
   // CDNA ISA cases
   switch (kind) {
-  case llvm::AMDGPU::GK_GFX950:
   case llvm::AMDGPU::GK_GFX942:
   case llvm::AMDGPU::GK_GFX941:
   case llvm::AMDGPU::GK_GFX940:
diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
index bc83a32fdf..d637a87366 100644
--- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
+++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
@@ -35,7 +35,7 @@ class TritonGEN_Op<string mnemonic, list<Trait> traits = []> :
 def TritonGEN_MatrixElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>;
 
 def TritonGEN_MatrixDPASOp : TritonGEN_Op<"dpas">,
-  Results<(outs FixedVectorOfAnyRank<[TritonGEN_MatrixElemType]>:$d)>,
+  Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$d)>,
   Arguments<(ins
     FixedVectorOfRankAndType<[1], [TritonGEN_MatrixElemType]>:$c,
     FixedVectorOfRankAndType<[1], [TritonGEN_MatrixElemType]>:$a,
@@ -82,7 +82,7 @@ def TritonGEN_MatrixDPASOp : TritonGEN_Op<"dpas">,
 }
 
 def TritonGEN_Matrix2DBlockLoadOp : TritonGEN_Op<"2Dblockload">,
-  Results<(outs FixedVectorOfAnyRank<[TritonGEN_MatrixElemType]>:$res)>,
+  Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$res)>,
   Arguments<(ins
     Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
     I32:$base_width,
@@ -145,7 +145,7 @@ def TritonGEN_Matrix2DBlockStoreOp : TritonGEN_Op<"2Dblockstore">,
     I32Attr:$tile_width,
     I32Attr:$tile_height,
     I32Attr:$v_blocks,
-    FixedVectorOfAnyRank<[TritonGEN_MatrixElemType]>:$stored_val,
+    FixedVectorOf<[TritonGEN_MatrixElemType]>:$stored_val,
     DefaultValuedAttr:$cache_control
   )> {
 
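The inline asm added in the next file leans heavily on prmt.b32, PTX's byte-permute instruction. As a reading aid (editorial sketch of its default mode per the PTX ISA manual, not part of the patch): the two source registers form an 8-byte pool, each selector nibble's low three bits pick a byte, and a set high bit replicates that byte's sign bit instead:

    #include <cstdint>

    uint32_t prmt(uint32_t a, uint32_t b, uint32_t c) {
      uint8_t pool[8];
      for (int i = 0; i < 4; ++i) {
        pool[i] = (a >> (8 * i)) & 0xff;     // bytes 0-3 come from a
        pool[4 + i] = (b >> (8 * i)) & 0xff; // bytes 4-7 come from b
      }
      uint32_t out = 0;
      for (int i = 0; i < 4; ++i) {
        uint8_t sel = (c >> (4 * i)) & 0xf;
        uint8_t byte = pool[sel & 0x7];
        if (sel & 0x8) // sign-replicate mode for this destination byte
          byte = (byte & 0x80) ? 0xff : 0x00;
        out |= uint32_t(byte) << (8 * i);
      }
      return out;
    }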
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
index e5c8973f7a..47c7fcc063 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -6,6 +6,7 @@
 
 #include "PatternTritonGPUOpToLLVM.h"
 
+#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -19,6 +20,73 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 
+// Convert 8 fp4 elements packed into a 32bit reg into 8 bf16 elements packed
+// into 4 32bits regs.
+static constexpr const char *ptxAsm =
+    "{\n"
+    ".reg .b32 a<14>;\n"
+    "and.b32 a0, $4, -2004318072;\n\t"
+    "shr.u32 a1, a0, 3;\n\t"
+    "and.b32 a2, $4, 2004318071;\n\t"
+    "shr.u32 a3, a2, 16;\n\t"
+    "shr.u32 a4, a0, 19;\n\t"
+    "prmt.b32 a5, -1065353216, -1065336832, a2;\n\t"
+    "prmt.b32 a6, -1065353216, -1065336832, a3;\n\t"
+    "prmt.b32 a7, 1061109504, 1077952576, a2;\n\t"
+    "prmt.b32 a8, 1061109504, 1077952576, a3;\n\t"
+    "prmt.b32 a9, 32768, 0, a1;\n\t"
+    "prmt.b32 a10, 32768, 0, a4;\n\t"
+    "or.b32 a11, a7, a9;\n\t"
+    "or.b32 a12, a8, a10;\n\t"
+    "prmt.b32 $0, a5, a11, 20800;\n\t"
+    "prmt.b32 $1, a5, a11, 29538;\n\t"
+    "prmt.b32 $2, a6, a12, 20800;\n\t"
+    "prmt.b32 $3, a6, a12, 29538;\n\t"
+    "}";
+
+static Value createInlineAsmUpcast(Location loc, RewriterBase &rewriter,
+                                   Type retType, Value packedVec) {
+  PTXBuilder builder;
+  SmallVector<PTXBuilder::Operand *> operands;
+  for (int i = 0; i < 4; i++) {
+    operands.push_back(builder.newOperand("=r"));
+  }
+  operands.push_back(builder.newOperand(packedVec, "r"));
+  auto &ptxOp = *builder.create(ptxAsm);
+  ptxOp(operands, /*onlyAttachMLIRArgs=*/true);
+  Value result = builder.launch(rewriter, loc, retType, false);
+  return result;
+}
+
+static SmallVector<Value> convertMxfp4x2ToBf16x2PTX(RewriterBase &rewriter,
+                                                    Location loc,
+                                                    ArrayRef<Value> values) {
+  SmallVector<Value> results;
+  MLIRContext *ctx = rewriter.getContext();
+  assert(values.size() % 4 == 0);
+  for (int i = 0; i < values.size(); i += 4) {
+    Value v0 = values[i];
+    Value v1 = values[i + 1];
+    Value v2 = values[i + 2];
+    Value v3 = values[i + 3];
+    Value packedVec = undef(vec_ty(i8_ty, 4));
+    packedVec = insert_element(packedVec, v0, i32_val(0));
+    packedVec = insert_element(packedVec, v1, i32_val(1));
+    packedVec = insert_element(packedVec, v2, i32_val(2));
+    packedVec = insert_element(packedVec, v3, i32_val(3));
+    SmallVector<Type> rets(4, i32_ty);
+    Type retType = struct_ty(rets);
+    Value ret = createInlineAsmUpcast(loc, rewriter, retType, packedVec);
+    for (int i = 0; i < 4; i++) {
+      Value extractI32 = extract_val(ret, i);
+      Value vecbf16 = bitcast(extractI32, vec_ty(bf16_ty, 2));
+      results.push_back(extract_element(vecbf16, i32_val(0)));
+      results.push_back(extract_element(vecbf16, i32_val(1)));
+    }
+  }
+  return results;
+}
+
 namespace {
 class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
 private:
@@ -53,7 +121,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
         cast<DotOperandEncodingAttr>(op.getType().getEncoding()).getKWidth();
 
     if (fpType == ScaleDotElemType::E2M1)
-      xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
+      xVals = convertMxfp4x2ToBf16x2PTX(rewriter, loc, xVals);
 
     // Each thread owns elements of 4 mxfp vectors so we need 4 scales
     // Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2
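For checking the PTX path above against a reference (editorial sketch, not part of the patch): e2m1 (fp4) packs 1 sign bit, 2 exponent bits, and 1 mantissa bit, with exponent bias 1, so the representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4, and 6. A plain C++ decode of one nibble:

    #include <cstdint>
    #include <cstdio>

    float decodeE2M1(uint8_t nibble) {
      int sign = (nibble >> 3) & 1;
      int exp = (nibble >> 1) & 3;
      int man = nibble & 1;
      // exp == 0 is subnormal: man * 0.5; otherwise 2^(exp - 1) * (1 + man / 2).
      float mag = (exp == 0) ? 0.5f * man
                             : float(1 << (exp - 1)) * (1.0f + 0.5f * man);
      return sign ? -mag : mag;
    }

    int main() {
      for (uint8_t v = 0; v < 16; ++v)
        printf("0x%x -> %g\n", v, decodeE2M1(v));
    }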