
Commit be47a27

Merge commit '3613bf40d90a38766ec65a250aeadb391f9f7fc9'
2 parents: c0c76b1 + 3613bf4

5 files changed: 168 additions and 34 deletions

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 34 additions & 10 deletions
@@ -41,36 +41,60 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   if (inBitWidth == ouBitWidth)
     return values;
   if (inBitWidth == 16 && ouBitWidth == 32) {
+    // Register layout conversion:
+    //
+    //   [0, 1], [4, 5]  ⟶  [0], [1], [4], [5]
+    //   [2, 3], [6, 7]      [2], [3], [6], [7]
+    //
+    // Original access order:
+    //
+    //   [0, 1], [2, 3], [4, 5], [6, 7]
+    //
+    // Transformed access order:
+    //
+    //   [0], [2], [1], [3], [4], [6], [5], [7]
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 8) {
       ret.push_back(values[i]);
-      ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 4]);
-      ret.push_back(values[i + 5]);
       ret.push_back(values[i + 2]);
+      ret.push_back(values[i + 1]);
       ret.push_back(values[i + 3]);
+      ret.push_back(values[i + 4]);
       ret.push_back(values[i + 6]);
+      ret.push_back(values[i + 5]);
       ret.push_back(values[i + 7]);
     }
     return ret;
   }
   if (inBitWidth == 8 && ouBitWidth == 16) {
+    // Register layout conversion:
+    //
+    //   [0, 1, 2, 3], [8, 9, 10, 11]   ⟶  [0, 1], [2, 3], [8, 9], [10, 11]
+    //   [4, 5, 6, 7], [12, 13, 14, 15]     [4, 5], [6, 7], [12, 13], [14, 15]
+    //
+    // Original access order:
+    //
+    //   [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
+    //
+    // Transformed access order:
+    //
+    //   [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 16) {
-      ret.push_back(values[i + 0]);
+      ret.push_back(values[i]);
       ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 2]);
-      ret.push_back(values[i + 3]);
-      ret.push_back(values[i + 8]);
-      ret.push_back(values[i + 9]);
-      ret.push_back(values[i + 10]);
-      ret.push_back(values[i + 11]);
       ret.push_back(values[i + 4]);
       ret.push_back(values[i + 5]);
+      ret.push_back(values[i + 2]);
+      ret.push_back(values[i + 3]);
       ret.push_back(values[i + 6]);
       ret.push_back(values[i + 7]);
+      ret.push_back(values[i + 8]);
+      ret.push_back(values[i + 9]);
       ret.push_back(values[i + 12]);
       ret.push_back(values[i + 13]);
+      ret.push_back(values[i + 10]);
+      ret.push_back(values[i + 11]);
       ret.push_back(values[i + 14]);
       ret.push_back(values[i + 15]);
     }
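
For reference, a minimal standalone C++ sketch (a mock with plain ints, not the Triton lowering itself) that reproduces the new 16-bit to 32-bit access order documented in the comment above:

#include <cstdio>
#include <vector>

int main() {
  // Stand-in register ids for one group of eight values.
  std::vector<int> values = {0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<int> ret;
  for (size_t i = 0; i < values.size(); i += 8)
    for (int idx : {0, 2, 1, 3, 4, 6, 5, 7})
      ret.push_back(values[i + idx]);
  // Prints "0 2 1 3 4 6 5 7": the transformed access order.
  for (int v : ret)
    printf("%d ", v);
  printf("\n");
}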

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 4 additions & 22 deletions
@@ -641,7 +641,6 @@ struct ConvertLayoutOpConversion
     // for the destination type, we need to pack values together
     // so they can be consumed by tensor core operations
     SmallVector<Value> vecVals;
-    SmallVector<Type> types;
     // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
     // instructions to pack & unpack sub-word integers. A workaround is to
     // store the results of ldmatrix in i32
@@ -655,37 +654,20 @@ struct ConvertLayoutOpConversion
             shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
         val = or_(i32_ty, val, ext);
       }
-      vecVals.push_back(val);
+      vecVals.push_back(bitcast(val, i32_ty));
     }
-    elems = elems / (32 / elemSize);
-    types = SmallVector<Type>(elems, i32_ty);
   } else {
     unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
     Type vecTy = vec_ty(elemTy, vecSize);
-    types = SmallVector<Type>(elems / vecSize, vecTy);
     for (unsigned i = 0; i < elems; i += vecSize) {
       Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
       for (unsigned j = 0; j < vecSize; j++)
         packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
-      vecVals.push_back(packed);
+      vecVals.push_back(bitcast(packed, i32_ty));
     }
   }
-
-  // This needs to be ordered the same way that
-  // ldmatrix.x4 would order it
-  // TODO: this needs to be refactor so we don't
-  // implicitly depends on how emitOffsetsForMMAV2
-  // is implemented
-  SmallVector<Value> reorderedVals;
-  for (unsigned i = 0; i < vecVals.size(); i += 4) {
-    reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
-    reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
-    reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
-    reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
-  }
-
-  Value view = packLLElements(loc, getTypeConverter(), reorderedVals,
-                              rewriter, dstTy);
+  Value view =
+      packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy);
   rewriter.replaceOp(op, view);
   return success();
 }
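
The deleted block above is the [0, 2, 1, 3] quad shuffle that used to run after ldmatrix; with this commit the producers emit registers in that order directly. A standalone sketch (plain ints, hypothetical values, not the Triton API) of what the removed reorder did:

#include <vector>

// Old post-ldmatrix reorder: swap the middle pair of every group of four.
std::vector<int> oldReorder(const std::vector<int> &vecVals) {
  std::vector<int> out;
  for (size_t i = 0; i + 3 < vecVals.size(); i += 4) {
    out.push_back(vecVals[i]);
    out.push_back(vecVals[i + 2]);
    out.push_back(vecVals[i + 1]);
    out.push_back(vecVals[i + 3]);
  }
  return out;
}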

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp

Lines changed: 1 addition & 1 deletion
@@ -513,8 +513,8 @@ Value composeValuesToDotOperandLayoutStruct(
       for (int m = 0; m < n0; ++m)
         for (int k = 0; k < n1; ++k) {
           elems.push_back(vals.at({b, 2 * m, 2 * k}));
-          elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
           elems.push_back(vals.at({b, 2 * m + 1, 2 * k}));
+          elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
           elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1}));
         }
   assert(!elems.empty());
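
A small sketch (standalone C++, hypothetical tile counts) that prints the new per-tile emission order, showing that both rows of column 2k now precede column 2k+1, the same order the consumer-side swap in MMAv2.cpp expects:

#include <cstdio>

int main() {
  const int n0 = 2, n1 = 2; // hypothetical tile counts
  for (int m = 0; m < n0; ++m)
    for (int k = 0; k < n1; ++k)
      printf("(%d,%d) (%d,%d) (%d,%d) (%d,%d)\n",
             2 * m, 2 * k,          // row 2m,   col 2k
             2 * m + 1, 2 * k,      // row 2m+1, col 2k
             2 * m, 2 * k + 1,      // row 2m,   col 2k+1
             2 * m + 1, 2 * k + 1); // row 2m+1, col 2k+1
}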

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 75 additions & 1 deletion
@@ -62,12 +62,86 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
   auto elems = unpackLLElements(loc, value, rewriter);
   int offset{};
   ValueTableV2 vals;
+
+  // FIXME [Dot LL]
+  // [ez] Generalize the logic below for kWidth * elemBitWidth > 32
+  auto dot = cast<DotOperandEncodingAttr>(type.getEncoding());
+  auto largeK = dot.getKWidth() == 8 &&
+                cast<NvidiaMmaEncodingAttr>(dot.getParent()).isAmpere();
+  if (largeK) {
+    llvm::SmallVector<unsigned> si;
+
+    // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K
+    if (dot.getOpIdx() == 0) {
+      // Original register layout:
+      //
+      //   [0, 1, 2, 3], [8, 9, 10, 11]
+      //   [4, 5, 6, 7], [12, 13, 14, 15]
+      //
+      // Each element in the layout consists of two bf16 values.
+      // For example, the row [0, 1, 2, 3] expands to:
+      //
+      //   [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]]
+      //
+      // Here, 0/0 refers to the first half of element 0, and 0/1 refers to
+      // the second half, matching kWidth = 8.
+      //
+      // To derive four independent MMA operations, a stride of 4 is applied
+      // to the original register layout:
+      //
+      //   1st MMA: [0, 4, 8, 12]
+      //   2nd MMA: [1, 5, 9, 13]
+      //   3rd MMA: [2, 6, 10, 14]
+      //   4th MMA: [3, 7, 11, 15]
+      si = llvm::SmallVector<unsigned>{0, 4, 8, 12, 1, 5, 9, 13,
+                                       2, 6, 10, 14, 3, 7, 11, 15};
+    } else {
+      // Original register layout:
+      //
+      //   [0, 1, 2, 3]^T, [4, 5, 6, 7]^T
+      //
+      // A stride of 4 is applied to derive four independent MMA operations:
+      //
+      //   1st MMA: [0, 4]
+      //   2nd MMA: [1, 5]
+      //   3rd MMA: [2, 6]
+      //   4th MMA: [3, 7]
+      si = llvm::SmallVector<unsigned>{0, 4, 1, 5, 2, 6, 3, 7};
+    }
+
+    auto step = si.size();
+    SmallVector<Value> perm(step);
+    for (auto i = 0; i < elems.size() / step; ++i) {
+      for (auto j = 0; j < step; ++j) {
+        perm[j] = elems[i * step + si[j]];
+      }
+      std::copy(perm.begin(), perm.end(), elems.begin() + i * step);
+    }
+
+    if (dot.getOpIdx() == 1) {
+      // there are kWidth * 2 elems packed as bf16x2
+      int elemsInTile = dot.getKWidth();
+      // n0 and n1 are unrolled in the legacy path
+      // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no
+      // sense IMO
+      n0 *= 2;
+      n1 *= 2;
+      for (auto b = 0; b < batch; ++b)
+        for (auto j = 0; j < n1 / elemsInTile; ++j)
+          for (auto i = 0; i < n0; ++i)
+            for (auto k = 0; k < elemsInTile; ++k) {
+              vals[{b, i, elemsInTile * j + k}] = elems[offset++];
+            }
+      return vals;
+    }
+  }
+
   for (auto b = 0; b < batch; ++b)
     for (auto i = 0; i < n0; ++i) {
       for (auto j = 0; j < n1; j++) {
         vals[{b, 2 * i, 2 * j}] = elems[offset++];
-        vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
         vals[{b, 2 * i + 1, 2 * j}] = elems[offset++];
+        vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
         vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++];
       }
     }
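
To see the in-place permutation concretely, here is a minimal standalone sketch (plain ints instead of LLVM Values) applying the operand-A index table si to two groups of 16 registers:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<unsigned> si = {0, 4, 8, 12, 1, 5, 9, 13,
                                    2, 6, 10, 14, 3, 7, 11, 15};
  std::vector<int> elems(32);
  for (int i = 0; i < 32; ++i)
    elems[i] = i; // stand-in register ids
  const size_t step = si.size();
  std::vector<int> perm(step);
  for (size_t i = 0; i < elems.size() / step; ++i) {
    for (size_t j = 0; j < step; ++j)
      perm[j] = elems[i * step + si[j]];
    std::copy(perm.begin(), perm.end(), elems.begin() + i * step);
  }
  // The first group now reads 0 4 8 12 1 5 9 13 2 6 10 14 3 7 11 15:
  // the registers of the four K-strided sub-MMAs, grouped per MMA.
  for (int v : elems)
    printf("%d ", v);
  printf("\n");
}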

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 54 additions & 0 deletions
@@ -27,6 +27,60 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
       : ConvertOpToLLVMPattern<UpcastMXFPOp>(typeConverter, benefit),
         targetInfo(targetInfo) {}
 
+  llvm::SmallVector<Value>
+  unpackFP4Elements(Location loc, ConversionPatternRewriter &rewriter,
+                    const llvm::SmallVector<Value> &vals, Value laneId) const {
+    auto fp4x2ToBf16x2 = [&loc, &rewriter](Value v) -> Value {
+      auto em0 = and_(v, i8_val(0x70));
+      auto em1 = and_(v, i8_val(0x7));
+      Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)),
+                     shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));
+      Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)),
+                     shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12)));
+
+      // Three cases:
+      // 1) x is normal and non-zero: Correct bias
+      v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)),
+                  add(v0, i16_val((127 - 1) << 7)), v0);
+      v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)),
+                  add(v1, i16_val((127 - 1) << 7)), v1);
+
+      // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
+      //    bf16
+      v0 = select(icmp_eq(em0, i8_val(0x10)),
+                  or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0);
+      v1 = select(icmp_eq(em1, i8_val(0x1)),
+                  or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1);
+      // 3) x is zero, nothing to do
+
+      // Swap as they come packed in big endian
+      return or_(zext(i32_ty, v0), shl(zext(i32_ty, v1), i32_val(16)));
+    };
+
+    auto fp4x8ToBf16x2 = [&loc, &rewriter, &fp4x2ToBf16x2](
+                             Value v) -> llvm::SmallVector<Value, 4> {
+      llvm::SmallVector<Value, 4> results(4);
+      for (int i = 0; i < 4; ++i) {
+        auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i)));
+        results[i] = fp4x2ToBf16x2(v_i);
+      }
+      return results;
+    };
+
+    // Split fp4x8 into 4 bf16x2
+    llvm::SmallVector<Value> ret;
+    ret.reserve(vals.size() * 4);
+    for (int i = 0; i < vals.size(); ++i) {
+      auto vs = fp4x8ToBf16x2(vals[i]);
+      assert(vs.size() == 4);
+      for (auto v : vs) {
+        ret.push_back(v);
+      }
+    }
+
+    return ret;
+  }
+
   LogicalResult
   matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
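
The lambda above performs the e2m1 to bf16 conversion with LLVM dialect ops; the same bit manipulation as a scalar C++ reference (an illustrative standalone function, not part of the patch) may make the three cases easier to follow:

#include <cstdint>

// Convert one 4-bit e2m1 value (low nibble: s e1 e0 m) to bf16 bits.
uint16_t fp4ToBf16Bits(uint8_t nibble) {
  uint16_t sign = (nibble & 0x8) ? 0x8000 : 0;
  uint16_t em = nibble & 0x7; // exponent and mantissa bits
  if (em == 0)
    return sign; // 3) zero: nothing to do
  if (em == 1)
    return sign | 0x3F00; // 2) subnormal: +-0.5 in bf16 (16128)
  // 1) normal: m lands on the top mantissa bit, e on the low exponent
  //    bits; then rebias from the fp4 bias (1) to the bf16 bias (127).
  return (sign | (em << 6)) + ((127 - 1) << 7);
}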
