Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1f20eee6dc367bd202895e3eedb03974a628ef16
86b69c31642e98f8357df62c09d118ad1da4e16a
6 changes: 6 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,12 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
return idx;
}

// Emit code to compute the (blockId, warpId, laneId) for the current thread.
//
// laneId and warpId are the remainder and the quotient of the thread id
// divided by `threadsPerWarp`. blockId is the target's cluster CTA id when
// `withCTAOffset` is true and the i32 constant 0 otherwise.
std::tuple</*blockId=*/Value, /*warpId=*/Value, /*laneId=*/Value>
emitHardwareTuple(Location loc, RewriterBase &rewriter,
const TargetInfoBase &target, bool withCTAOffset,
unsigned threadsPerWarp);

// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
//
Expand Down
35 changes: 21 additions & 14 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
return outIndices;
}

// Builds the (blockId, warpId, laneId) triple for the current thread.
//
// The lane and warp ids are the remainder and quotient of the thread id by
// `threadsPerWarpCst`. The block id comes from the target's cluster CTA id
// when `withCTAOffset` is set, and is the constant 0 otherwise.
std::tuple<Value, Value, Value>
emitHardwareTuple(Location loc, RewriterBase &rewriter,
                  const TargetInfoBase &target, bool withCTAOffset,
                  unsigned threadsPerWarpCst) {
  // Ops are created in the same order as before so the emitted IR is stable.
  Value tid = getThreadId(rewriter, loc);
  Value warpSize = i32_val(threadsPerWarpCst);
  Value lane = urem(tid, warpSize);
  Value warp = udiv(tid, warpSize);

  Value block;
  if (withCTAOffset)
    block = target.getClusterCTAId(rewriter, loc);
  else
    block = i32_val(0);

  return std::make_tuple(block, warp, lane);
}

SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Attribute layout, RankedTensorType type, bool withCTAOffset) {
Expand All @@ -116,12 +130,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(ll->getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);
Value blockId =
withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
auto [blockId, warpId, laneId] = emitHardwareTuple(
loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane));
unsigned rank = shape.size();
SmallVector<SmallVector<Value>> ret;
// Linear layout function is split in two parts below:
Expand Down Expand Up @@ -214,10 +224,9 @@ bool emitTransferBetweenRegistersAndShared(
std::min(regToSharedLayout->getNumConsecutiveInOut(),
maxVecElems.value_or(std::numeric_limits<int>::max()));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);
auto [blockId, warpId, laneId] =
emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false,
regToSharedLayout->getInDimSize(kLane));

int numElems = regToSharedLayout->getInDimSize(kRegister);
auto vecTy = vec_ty(elemLlvmTy, vecElems);
Expand Down Expand Up @@ -625,10 +634,8 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
auto instrShape = mmaLayout.getInstrShape();
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
auto [blockId, warpId, laneId] = emitHardwareTuple(
loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32);
// TODO: fix the bug in MMAEncodingAttr document
SmallVector<Value> multiDimWarpId(2);
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Triton/Transforms/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
Expand All @@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
Expand Down
3 changes: 0 additions & 3 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
std::optional<LinearLayout>
BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
assert(shape.size() == getOrder().size());

int rank = shape.size();
MLIRContext *ctx = getContext();
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

const auto &order = getOrder();
LinearLayout ctaLayout =
Expand Down
2 changes: 1 addition & 1 deletion python/test/regression/test_cast_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import triton.language as tl
from triton._internal_testing import is_hip_mi300, is_cuda, is_hip

input_dtypes = ["float16", "float32", "float64"]
input_dtypes = ["bfloat16", "float16", "float32", "float64"]
if is_cuda():
input_dtypes += ["int8", "float8_e5m2"]
cc = torch.cuda.get_device_capability(0)
Expand Down
46 changes: 46 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1999,3 +1999,49 @@ tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: te
}

}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {

// Converts an i8 tensor in the opIdx = 1 dot-operand layout to f16 with
// arith.sitofp and expects an llvm.sitofp from i8 to f16 in the output.
tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
// CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1
// CHECK: llvm.sitofp %{{.*}} : i8 to f16
%2 = arith.sitofp %1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
tt.return
}

}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
// Converts an i8 tensor in the opIdx = 0 dot-operand layout to f16 with
// arith.sitofp and expects an llvm.sitofp from i8 to f16 in the output.
tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
// CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0
// CHECK: llvm.sitofp %{{.*}} : i8 to f16
%2 = arith.sitofp %1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0 , parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}

}

// -----

#linear = #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

// Lowers ttg.upcast_mxfp with fp_type = e2m1 (32x32xi8 values plus 32x2xi8
// scales, producing 32x64xbf16). The expected output contains, in order,
// 4 llvm.inline_asm ops, 2 nvvm.shfl.sync ops and 32 llvm.fmul ops.
tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x2xi8, #linear>) {
// CHECK-LABEL: upcast_mxfp
// CHECK-COUNT-4: llvm.inline_asm
// CHECK-COUNT-2: nvvm.shfl.sync
// CHECK-COUNT-32: llvm.fmul
%0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
tt.return
}

}
8 changes: 4 additions & 4 deletions test/TritonIntelGPU/prefetch-block.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 1 : i32}
// CHECK-NEXT: [[B3:%.*]] = tt.advance [[B2]], {{.*}} : <tensor<32x256xf16, #blocked2>>
// CHECK-NEXT: [[B4:%.*]] = tt.make_tensor_ptr %arg1, {{.*}} : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>>

// CHECK: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
// CHECK: spirv.INTEL.ControlBarrierArrive <Workgroup>, <Workgroup>, <None>
// CHECK-NEXT: scf.for [[IV:%.*]] = [[CST_ZERO]] to [[CST_4096]] step [[CST_32]]
// CHECK-SAME: iter_args([[CST:%.*]] = {{.*}}, [[A6:%.*]] = [[A4]], [[B6:%.*]] = [[B4]], [[A5:%.*]] = [[A3]], [[B5:%.*]] = [[B3]])
// CHECK-NEXT: [[LD_A:%.*]] = tt.load [[A6]]
Expand All @@ -45,11 +45,11 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 1 : i32}
// CHECK-DAG: tt.advance [[A6]], {{.*}} : <tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>>
// CHECK-NEXT: tt.advance [[B5]], {{.*}} : <tensor<32x256xf16, #blocked2>>
// CHECK-DAG: tt.advance [[B6]], {{.*}} : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>>
// CHECK: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
// CHECK-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
// CHECK: spirv.INTEL.ControlBarrierWait <Workgroup>, <Workgroup>, <None>
// CHECK-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup>, <Workgroup>, <None>
// CHECK-NEXT: scf.yield {{.*}}
// CHECK-NEXT: }
// CHECK-NEXT: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
// CHECK-NEXT: spirv.INTEL.ControlBarrierWait <Workgroup>, <Workgroup>, <None>

%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/backend/include/hsa/amd_hsa_elf.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ enum : unsigned {
EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c,
EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d,
EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e,
EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f,
EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4F = 0x04f,
EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050,
EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051,
EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052,
Expand Down
1 change: 0 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) {

// CDNA ISA cases
switch (kind) {
case llvm::AMDGPU::GK_GFX950:
case llvm::AMDGPU::GK_GFX942:
case llvm::AMDGPU::GK_GFX941:
case llvm::AMDGPU::GK_GFX940:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class TritonGEN_Op<string mnemonic, list<Trait> traits = []> :
def TritonGEN_MatrixElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>;

def TritonGEN_MatrixDPASOp : TritonGEN_Op<"dpas">,
Results<(outs FixedVectorOfAnyRank<[TritonGEN_MatrixElemType]>:$d)>,
Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$d)>,
Arguments<(ins
FixedVectorOfRankAndType<[1], [TritonGEN_MatrixElemType]>:$c,
FixedVectorOfRankAndType<[1], [TritonGEN_MatrixElemType]>:$a,
Expand Down Expand Up @@ -82,7 +82,7 @@ def TritonGEN_MatrixDPASOp : TritonGEN_Op<"dpas">,
}

def TritonGEN_Matrix2DBlockLoadOp : TritonGEN_Op<"2Dblockload">,
Results<(outs FixedVectorOfAnyRank<[TritonGEN_MatrixElemType]>:$res)>,
Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$res)>,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
I32:$base_width,
Expand Down Expand Up @@ -145,7 +145,7 @@ def TritonGEN_Matrix2DBlockStoreOp : TritonGEN_Op<"2Dblockstore">,
I32Attr:$tile_width,
I32Attr:$tile_height,
I32Attr:$v_blocks,
FixedVectorOfAnyRank<[TritonGEN_MatrixElemType]>:$stored_val,
FixedVectorOf<[TritonGEN_MatrixElemType]>:$stored_val,
DefaultValuedAttr<TritonGEN_StoreCacheControl, "::mlir::triton::TritonGEN::StoreCacheControl::DEFAULT">:$cache_control
)> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "PatternTritonGPUOpToLLVM.h"

#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -19,6 +20,73 @@ using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;

// Convert 8 fp4 elements packed into a 32bit reg into 8 bf16 elements packed
// into 4 32bits regs.
//
// Operand $4 is the only value read (the packed input); $0..$3 are the four
// 32-bit results written by the final prmt instructions. a0..a12 are scratch
// registers local to the braced PTX scope.
//
// The decimal immediates below, written in hex:
//   -2004318072 = 0x88888888   2004318071 = 0x77777777
//   -1065353216 = 0xC0800000  -1065336832 = 0xC0804000
//    1061109504 = 0x3F3F3F00   1077952576 = 0x40404040
//         32768 = 0x00008000        20800 = 0x5140, 29538 = 0x7362
static constexpr const char *ptxAsm =
"{\n"
".reg .b32 a<14>;\n"
// a0 keeps the top bit of every nibble of the input, a2 the low three bits
// (presumably the fp4 sign and magnitude bits respectively — TODO confirm).
"and.b32 a0, $4, -2004318072;\n\t"
"shr.u32 a1, a0, 3;\n\t"
"and.b32 a2, $4, 2004318071;\n\t"
// a3/a4 are the upper 16 bits of the magnitude/sign words, shifted down so
// the same prmt selectors can be reused for the second half of the input.
"shr.u32 a3, a2, 16;\n\t"
"shr.u32 a4, a0, 19;\n\t"
// The masked input words are used as the prmt selector operand, so each
// nibble picks a byte out of the two immediate operands (a byte lookup).
"prmt.b32 a5, -1065353216, -1065336832, a2;\n\t"
"prmt.b32 a6, -1065353216, -1065336832, a3;\n\t"
"prmt.b32 a7, 1061109504, 1077952576, a2;\n\t"
"prmt.b32 a8, 1061109504, 1077952576, a3;\n\t"
"prmt.b32 a9, 32768, 0, a1;\n\t"
"prmt.b32 a10, 32768, 0, a4;\n\t"
// Merge the a9/a10 lookups (driven by the top-bit words) into a7/a8.
"or.b32 a11, a7, a9;\n\t"
"or.b32 a12, a8, a10;\n\t"
// Interleave bytes of the two lookup results into the four output registers.
"prmt.b32 $0, a5, a11, 20800;\n\t"
"prmt.b32 $1, a5, a11, 29538;\n\t"
"prmt.b32 $2, a6, a12, 20800;\n\t"
"prmt.b32 $3, a6, a12, 29538;\n\t"
"}";

// Emits the inline-PTX sequence `ptxAsm` for a single packed input value.
//
// Four "=r" operands are declared first and `packedVec` is added last as the
// one "r" operand; this presumably lines up with the $0..$3 / $4 placeholders
// in the asm string by declaration order (confirm against PTXBuilder).
// Returns the value produced by launching the builder with type `retType`.
static Value createInlineAsmUpcast(Location loc, RewriterBase &rewriter,
                                   Type retType, Value packedVec) {
  constexpr int kNumOutputs = 4;

  PTXBuilder ptxBuilder;
  SmallVector<PTXBuilder::Operand *> asmArgs;
  for (int out = 0; out != kNumOutputs; ++out)
    asmArgs.push_back(ptxBuilder.newOperand("=r"));
  asmArgs.push_back(ptxBuilder.newOperand(packedVec, "r"));

  auto &upcastInstr = *ptxBuilder.create(ptxAsm);
  upcastInstr(asmArgs, /*onlyAttachMLIRArgs=*/true);
  return ptxBuilder.launch(rewriter, loc, retType, false);
}

// Upcasts packed mxfp4 values to bf16 using the inline-PTX helper.
//
// `values` is consumed in groups of four i8 values. Each group is packed into
// a <4 x i8> vector and handed to createInlineAsmUpcast, which yields a struct
// of four i32. Every i32 is bitcast to <2 x bf16> and both lanes are appended
// to the result, so each input group contributes eight bf16 values, in order.
//
// `values.size()` must be a multiple of 4 (asserted).
static SmallVector<Value> convertMxfp4x2ToBf16x2PTX(RewriterBase &rewriter,
                                                    Location loc,
                                                    ArrayRef<Value> values) {
  assert(values.size() % 4 == 0 &&
         "expected a multiple of 4 packed mxfp4 values");

  // NOTE(review): `ctx` looks unused, but the type helper macros (struct_ty in
  // particular) may expand to code that references it — confirm before
  // removing.
  MLIRContext *ctx = rewriter.getContext();

  // These types do not depend on the loop iteration; build them once.
  Type packedTy = vec_ty(i8_ty, 4);
  Type bf16x2Ty = vec_ty(bf16_ty, 2);
  SmallVector<Type> retFields(4, i32_ty);
  Type retType = struct_ty(retFields);

  SmallVector<Value> results;
  results.reserve(values.size() * 2);

  for (size_t base = 0; base < values.size(); base += 4) {
    // Pack four consecutive i8 values into one <4 x i8> vector.
    Value packedVec = undef(packedTy);
    for (int lane = 0; lane < 4; ++lane)
      packedVec =
          insert_element(packedVec, values[base + lane], i32_val(lane));

    Value ret = createInlineAsmUpcast(loc, rewriter, retType, packedVec);

    // Each i32 struct field holds two bf16 results.
    for (int reg = 0; reg < 4; ++reg) {
      Value extractI32 = extract_val(ret, reg);
      Value vecbf16 = bitcast(extractI32, bf16x2Ty);
      results.push_back(extract_element(vecbf16, i32_val(0)));
      results.push_back(extract_element(vecbf16, i32_val(1)));
    }
  }
  return results;
}

namespace {
class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
private:
Expand Down Expand Up @@ -53,7 +121,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
cast<DotOperandEncodingAttr>(op.getType().getEncoding()).getKWidth();

if (fpType == ScaleDotElemType::E2M1)
xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
xVals = convertMxfp4x2ToBf16x2PTX(rewriter, loc, xVals);

// Each thread owns elements of 4 mxfp vectors so we need 4 scales
// Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2
Expand Down
Loading