
Commit 963ba2b
Merge commit '390e27f4813799c242d4e6b2f8d79eda3b51cd92'
2 parents: 5c0e236 + 390e27f

7 files changed: +114 -20 lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 0 deletions

@@ -1123,6 +1123,12 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
   return idx;
 }
 
+// Emit code to compute the (blockId, warpId, laneId) for the current thread.
+std::tuple</*blockId=*/Value, /*warpId=*/Value, /*laneId=*/Value>
+emitHardwareTuple(Location loc, RewriterBase &rewriter,
+                  const TargetInfoBase &target, bool withCTAOffset,
+                  unsigned threadsPerWarp);
+
 // Emit indices calculation within each ConversionPattern, and returns a
 // [elemsPerThread X rank] index matrix.
 //

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 21 additions & 14 deletions

@@ -99,6 +99,20 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
   return outIndices;
 }
 
+std::tuple<Value, Value, Value> emitHardwareTuple(Location loc,
+                                                  RewriterBase &rewriter,
+                                                  const TargetInfoBase &target,
+                                                  bool withCTAOffset,
+                                                  unsigned threadsPerWarpCst) {
+  Value threadId = getThreadId(rewriter, loc);
+  Value threadsPerWarp = i32_val(threadsPerWarpCst);
+  Value laneId = urem(threadId, threadsPerWarp);
+  Value warpId = udiv(threadId, threadsPerWarp);
+  Value blockId =
+      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
+  return {blockId, warpId, laneId};
+}
+
 SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset) {
@@ -116,12 +130,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   StringAttr kWarp = str_attr("warp");
   StringAttr kBlock = str_attr("block");
 
-  Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(ll->getInDimSize(kLane));
-  Value laneId = urem(threadId, threadsPerWarp);
-  Value warpId = udiv(threadId, threadsPerWarp);
-  Value blockId =
-      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
+  auto [blockId, warpId, laneId] = emitHardwareTuple(
+      loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane));
   unsigned rank = shape.size();
   SmallVector<SmallVector<Value>> ret;
   // Linear layout function is split in two parts below:
@@ -214,10 +224,9 @@ bool emitTransferBetweenRegistersAndShared(
       std::min(regToSharedLayout->getNumConsecutiveInOut(),
                maxVecElems.value_or(std::numeric_limits<int>::max()));
 
-  Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
-  Value laneId = urem(threadId, threadsPerWarp);
-  Value warpId = udiv(threadId, threadsPerWarp);
+  auto [blockId, warpId, laneId] =
+      emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false,
+                        regToSharedLayout->getInDimSize(kLane));
 
   int numElems = regToSharedLayout->getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
@@ -625,10 +634,8 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
     auto instrShape = mmaLayout.getInstrShape();
     SmallVector<Value> mmaColIdx(2);
    SmallVector<Value> mmaRowIdx(2);
-    Value threadId = getThreadId(rewriter, loc);
-    Value warpSize = i32_val(32);
-    Value laneId = urem(threadId, warpSize);
-    Value warpId = udiv(threadId, warpSize);
+    auto [blockId, warpId, laneId] = emitHardwareTuple(
+        loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32);
     // TODO: fix the bug in MMAEncodingAttr document
     SmallVector<Value> multiDimWarpId(2);
     auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
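
For intuition, here is a tiny standalone C++ model (not part of the commit) of the arithmetic the new emitHardwareTuple helper centralizes. The real helper builds MLIR values with urem/udiv rather than computing scalars; the function and driver below are hypothetical, written only to illustrate the decomposition.

#include <cstdio>
#include <tuple>

// Hypothetical scalar model of emitHardwareTuple: lane and warp ids come
// from the flat thread id, and blockId is the cluster CTA id only when
// withCTAOffset is requested, else 0.
static std::tuple<unsigned, unsigned, unsigned>
hardwareTupleModel(unsigned tid, unsigned threadsPerWarp, bool withCTAOffset,
                   unsigned clusterCTAId) {
  unsigned laneId = tid % threadsPerWarp; // urem(threadId, threadsPerWarp)
  unsigned warpId = tid / threadsPerWarp; // udiv(threadId, threadsPerWarp)
  unsigned blockId = withCTAOffset ? clusterCTAId : 0;
  return {blockId, warpId, laneId};
}

int main() {
  auto [blockId, warpId, laneId] = hardwareTupleModel(70, 32, false, 0);
  printf("block=%u warp=%u lane=%u\n", blockId, warpId, laneId); // 0 2 6
}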

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 0 additions & 3 deletions

@@ -525,10 +525,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 std::optional<LinearLayout>
 BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   assert(shape.size() == getOrder().size());
-
-  int rank = shape.size();
   MLIRContext *ctx = getContext();
-  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
 
   const auto &order = getOrder();
   LinearLayout ctaLayout =

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 17 additions & 0 deletions

@@ -2028,3 +2028,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 }
 
 }
+
+// -----
+
+#linear = #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+
+  tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x2xi8, #linear>) {
+    // CHECK-LABEL: upcast_mxfp
+    // CHECK-COUNT-4: llvm.inline_asm
+    // CHECK-COUNT-2: nvvm.shfl.sync
+    // CHECK-COUNT-32: llvm.fmul
+    %0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    tt.return
+  }
+
+}

third_party/amd/backend/include/hsa/amd_hsa_elf.h

Lines changed: 1 addition & 1 deletion

@@ -136,7 +136,7 @@ enum : unsigned {
   EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c,
   EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d,
   EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e,
-  EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f,
+  EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4F = 0x04f,
   EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050,
   EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051,
   EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052,

third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp

Lines changed: 0 additions & 1 deletion

@@ -11,7 +11,6 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) {
 
   // CDNA ISA cases
   switch (kind) {
-  case llvm::AMDGPU::GK_GFX950:
   case llvm::AMDGPU::GK_GFX942:
   case llvm::AMDGPU::GK_GFX941:
   case llvm::AMDGPU::GK_GFX940:

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 69 additions & 1 deletion

@@ -6,6 +6,7 @@
 
 #include "PatternTritonGPUOpToLLVM.h"
 
+#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -19,6 +20,73 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 
+// Convert 8 fp4 elements packed into a 32bit reg into 8 bf16 elements packed
+// into 4 32bits regs.
+static constexpr const char *ptxAsm =
+    "{\n"
+    ".reg .b32 a<14>;\n"
+    "and.b32 a0, $4, -2004318072;\n\t"
+    "shr.u32 a1, a0, 3;\n\t"
+    "and.b32 a2, $4, 2004318071;\n\t"
+    "shr.u32 a3, a2, 16;\n\t"
+    "shr.u32 a4, a0, 19;\n\t"
+    "prmt.b32 a5, -1065353216, -1065336832, a2;\n\t"
+    "prmt.b32 a6, -1065353216, -1065336832, a3;\n\t"
+    "prmt.b32 a7, 1061109504, 1077952576, a2;\n\t"
+    "prmt.b32 a8, 1061109504, 1077952576, a3;\n\t"
+    "prmt.b32 a9, 32768, 0, a1;\n\t"
+    "prmt.b32 a10, 32768, 0, a4;\n\t"
+    "or.b32 a11, a7, a9;\n\t"
+    "or.b32 a12, a8, a10;\n\t"
+    "prmt.b32 $0, a5, a11, 20800;\n\t"
+    "prmt.b32 $1, a5, a11, 29538;\n\t"
+    "prmt.b32 $2, a6, a12, 20800;\n\t"
+    "prmt.b32 $3, a6, a12, 29538;\n\t"
+    "}";
+
+static Value createInlineAsmUpcast(Location loc, RewriterBase &rewriter,
+                                   Type retType, Value packedVec) {
+  PTXBuilder builder;
+  SmallVector<PTXBuilder::Operand *> operands;
+  for (int i = 0; i < 4; i++) {
+    operands.push_back(builder.newOperand("=r"));
+  }
+  operands.push_back(builder.newOperand(packedVec, "r"));
+  auto &ptxOp = *builder.create(ptxAsm);
+  ptxOp(operands, /*onlyAttachMLIRArgs=*/true);
+  Value result = builder.launch(rewriter, loc, retType, false);
+  return result;
+}
+
+static SmallVector<Value> convertMxfp4x2ToBf16x2PTX(RewriterBase &rewriter,
+                                                    Location loc,
+                                                    ArrayRef<Value> values) {
+  SmallVector<Value> results;
+  MLIRContext *ctx = rewriter.getContext();
+  assert(values.size() % 4 == 0);
+  for (int i = 0; i < values.size(); i += 4) {
+    Value v0 = values[i];
+    Value v1 = values[i + 1];
+    Value v2 = values[i + 2];
+    Value v3 = values[i + 3];
+    Value packedVec = undef(vec_ty(i8_ty, 4));
+    packedVec = insert_element(packedVec, v0, i32_val(0));
+    packedVec = insert_element(packedVec, v1, i32_val(1));
+    packedVec = insert_element(packedVec, v2, i32_val(2));
+    packedVec = insert_element(packedVec, v3, i32_val(3));
+    SmallVector<Type> rets(4, i32_ty);
+    Type retType = struct_ty(rets);
+    Value ret = createInlineAsmUpcast(loc, rewriter, retType, packedVec);
+    for (int i = 0; i < 4; i++) {
+      Value extractI32 = extract_val(ret, i);
+      Value vecbf16 = bitcast(extractI32, vec_ty(bf16_ty, 2));
+      results.push_back(extract_element(vecbf16, i32_val(0)));
+      results.push_back(extract_element(vecbf16, i32_val(1)));
+    }
+  }
+  return results;
+}
+
 namespace {
 class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
 private:
@@ -53,7 +121,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
         cast<DotOperandEncodingAttr>(op.getType().getEncoding()).getKWidth();
 
     if (fpType == ScaleDotElemType::E2M1)
-      xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
+      xVals = convertMxfp4x2ToBf16x2PTX(rewriter, loc, xVals);
 
     // Each thread owns elements of 4 mxfp vectors so we need 4 scales
     // Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2
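
For reference, below is a standalone scalar sketch (not part of the commit) of the decoding that the prmt-based PTX above performs eight nibbles at a time. It assumes the standard e2m1 layout of 1 sign bit, 2 exponent bits (bias 1), and 1 mantissa bit, with bf16 taken as the upper 16 bits of the f32 encoding; the function name is hypothetical.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical scalar reference: decode one fp4 e2m1 nibble to bf16.
// The PTX path above does the equivalent lookup with byte shuffles (prmt)
// on a whole 32-bit register of packed nibbles at once.
static uint16_t fp4E2M1ToBf16(uint8_t nibble) {
  uint32_t sign = (nibble >> 3) & 0x1;
  uint32_t exp = (nibble >> 1) & 0x3;
  uint32_t man = nibble & 0x1;
  // Subnormal when exp == 0 (values 0 and 0.5); else (1 + m/2) * 2^(exp-1).
  float mag = (exp == 0) ? 0.5f * man
                         : (1.0f + 0.5f * man) * float(1u << (exp - 1));
  float val = sign ? -mag : mag;
  uint32_t bits;
  std::memcpy(&bits, &val, sizeof(bits));
  return uint16_t(bits >> 16); // bf16 = upper half of the f32 encoding
}

int main() {
  // Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (and negatives).
  for (uint8_t n = 0; n < 16; ++n)
    printf("0x%x -> bf16 0x%04x\n", n, fp4E2M1ToBf16(n));
}

Since every representable e2m1 magnitude is exactly representable in bf16, truncating the f32 bits loses nothing here, which is what makes a pure shuffle-based expansion possible.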
