
Commit 3613bf4

[BACKEND] Fix the register accessing order of dot operands of mmav2 (#4979)
1 parent 3c13f09 commit 3613bf4

5 files changed, +72 -49 lines changed

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 34 additions & 10 deletions
@@ -41,36 +41,60 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   if (inBitWidth == ouBitWidth)
     return values;
   if (inBitWidth == 16 && ouBitWidth == 32) {
+    // Register layout conversion:
+    //
+    //   [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
+    //   [2, 3], [6, 7]    [2], [3], [6], [7]
+    //
+    // Original access order:
+    //
+    //   [0, 1], [2, 3], [4, 5], [6, 7]
+    //
+    // Transformed access order:
+    //
+    //   [0], [2], [1], [3], [4], [6], [5], [7]
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 8) {
       ret.push_back(values[i]);
-      ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 4]);
-      ret.push_back(values[i + 5]);
       ret.push_back(values[i + 2]);
+      ret.push_back(values[i + 1]);
       ret.push_back(values[i + 3]);
+      ret.push_back(values[i + 4]);
       ret.push_back(values[i + 6]);
+      ret.push_back(values[i + 5]);
       ret.push_back(values[i + 7]);
     }
     return ret;
   }
   if (inBitWidth == 8 && ouBitWidth == 16) {
+    // Register layout conversion:
+    //
+    //   [0, 1, 2, 3], [8, 9, 10, 11]   ⟶   [0, 1], [2, 3], [8, 9], [10, 11]
+    //   [4, 5, 6, 7], [12, 13, 14, 15]      [4, 5], [6, 7], [12, 13], [14, 15]
+    //
+    // Original access order:
+    //
+    //   [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
+    //
+    // Transformed access order:
+    //
+    //   [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 16) {
-      ret.push_back(values[i + 0]);
+      ret.push_back(values[i]);
       ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 2]);
-      ret.push_back(values[i + 3]);
-      ret.push_back(values[i + 8]);
-      ret.push_back(values[i + 9]);
-      ret.push_back(values[i + 10]);
-      ret.push_back(values[i + 11]);
       ret.push_back(values[i + 4]);
       ret.push_back(values[i + 5]);
+      ret.push_back(values[i + 2]);
+      ret.push_back(values[i + 3]);
       ret.push_back(values[i + 6]);
       ret.push_back(values[i + 7]);
+      ret.push_back(values[i + 8]);
+      ret.push_back(values[i + 9]);
       ret.push_back(values[i + 12]);
       ret.push_back(values[i + 13]);
+      ret.push_back(values[i + 10]);
+      ret.push_back(values[i + 11]);
       ret.push_back(values[i + 14]);
       ret.push_back(values[i + 15]);
     }
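
Note: the two loops above are pure permutations of the register access order. Below is a minimal standalone C++ sketch (not part of the commit; plain integers stand in for the MLIR Values) that applies the same index patterns, so the "Transformed access order" comments can be checked directly.

#include <cstdio>
#include <vector>

// Applies one reorder pattern per group of `group` values and prints the result.
static void reorder(const std::vector<int> &values, size_t group,
                    const std::vector<int> &pattern) {
  for (size_t i = 0; i < values.size(); i += group)
    for (int off : pattern)
      std::printf("%d ", values[i + off]);
  std::printf("\n");
}

int main() {
  std::vector<int> regs16(8), regs8(16);
  for (int i = 0; i < 8; ++i)
    regs16[i] = i;
  for (int i = 0; i < 16; ++i)
    regs8[i] = i;

  // 16-bit -> 32-bit case: prints 0 2 1 3 4 6 5 7
  reorder(regs16, 8, {0, 2, 1, 3, 4, 6, 5, 7});

  // 8-bit -> 16-bit case: prints 0 1 4 5 2 3 6 7 8 9 12 13 10 11 14 15
  reorder(regs8, 16, {0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15});
}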

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 4 additions & 22 deletions
@@ -641,7 +641,6 @@ struct ConvertLayoutOpConversion
     // for the destination type, we need to pack values together
     // so they can be consumed by tensor core operations
     SmallVector<Value> vecVals;
-    SmallVector<Type> types;
     // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
     // instructions to pack & unpack sub-word integers. A workaround is to
     // store the results of ldmatrix in i32
@@ -655,37 +654,20 @@ struct ConvertLayoutOpConversion
               shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
           val = or_(i32_ty, val, ext);
         }
-        vecVals.push_back(val);
+        vecVals.push_back(bitcast(val, i32_ty));
       }
-      elems = elems / (32 / elemSize);
-      types = SmallVector<Type>(elems, i32_ty);
     } else {
       unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
       Type vecTy = vec_ty(elemTy, vecSize);
-      types = SmallVector<Type>(elems / vecSize, vecTy);
       for (unsigned i = 0; i < elems; i += vecSize) {
         Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
         for (unsigned j = 0; j < vecSize; j++)
           packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
-        vecVals.push_back(packed);
+        vecVals.push_back(bitcast(packed, i32_ty));
       }
     }
-
-    // This needs to be ordered the same way that
-    // ldmatrix.x4 would order it
-    // TODO: this needs to be refactor so we don't
-    // implicitly depends on how emitOffsetsForMMAV2
-    // is implemented
-    SmallVector<Value> reorderedVals;
-    for (unsigned i = 0; i < vecVals.size(); i += 4) {
-      reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
-      reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
-      reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
-      reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
-    }
-
-    Value view = packLLElements(loc, getTypeConverter(), reorderedVals,
-                                rewriter, dstTy);
+    Value view =
+        packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy);
     rewriter.replaceOp(op, view);
     return success();
   }
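
Note: with the register order now produced directly in the layout the mma lowering expects, the separate reorderedVals pass is dropped and the i32 bitcast happens where each value is packed. As a side note on the sub-word workaround mentioned in the comment above, here is a small plain-C++ sketch (not MLIR; the lane values and bit width are made-up example data) of the zext/shl/or packing that stores sub-word results in a single i32.

#include <cstdint>
#include <cstdio>

int main() {
  // Four 8-bit lane values (arbitrary example data).
  const uint8_t vals[4] = {0x11, 0x22, 0x33, 0x44};
  const unsigned elemSize = 8; // bits per element, so 32 / elemSize = 4 lanes

  // Mirror of the zext + shl + or sequence: lane j lands at bit offset 8 * j.
  uint32_t packed = 0;
  for (unsigned j = 0; j < 32 / elemSize; ++j)
    packed |= static_cast<uint32_t>(vals[j]) << (elemSize * j);

  std::printf("0x%08x\n", packed); // prints 0x44332211
}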

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp

Lines changed: 1 addition & 1 deletion
@@ -513,8 +513,8 @@ Value composeValuesToDotOperandLayoutStruct(
     for (int m = 0; m < n0; ++m)
       for (int k = 0; k < n1; ++k) {
         elems.push_back(vals.at({b, 2 * m, 2 * k}));
-        elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
         elems.push_back(vals.at({b, 2 * m + 1, 2 * k}));
+        elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
         elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1}));
       }
   assert(!elems.empty());
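
Note: the swap above changes how each 2x2 tile of values is flattened: both rows of the first column now come before the second column, matching the order in which the patched getValuesFromDotOperandLayoutStruct (see MMAv2.cpp below) reads them back. A small sketch, using bare (row, col) pairs instead of the real Value map, that prints the new flattening order:

#include <cstdio>

int main() {
  const int n0 = 2, n1 = 2; // small example sizes
  for (int m = 0; m < n0; ++m)
    for (int k = 0; k < n1; ++k)
      // Same order as the patched loop:
      // (2m, 2k), (2m+1, 2k), (2m, 2k+1), (2m+1, 2k+1)
      std::printf("(%d,%d) (%d,%d) (%d,%d) (%d,%d)\n",
                  2 * m, 2 * k, 2 * m + 1, 2 * k,
                  2 * m, 2 * k + 1, 2 * m + 1, 2 * k + 1);
}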

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 33 additions & 3 deletions
@@ -75,9 +75,39 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(

   // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K
   if (dot.getOpIdx() == 0) {
-    si = llvm::SmallVector<unsigned>{0, 8, 4, 12, 1, 9, 5, 13,
-                                     2, 10, 6, 14, 3, 11, 7, 15};
+    // Original register layout:
+    //
+    //   [0, 1, 2, 3], [8, 9, 10, 11]
+    //   [4, 5, 6, 7], [12, 13, 14, 15]
+    //
+    // Each element in the layout consists of two bf16 values.
+    // For example, the row [0, 1, 2, 3] expands to:
+    //
+    //   [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]]
+    //
+    // Here, 0/0 refers to the first half of element 0, and 0/1 refers to the
+    // second half, matching kWidth = 8.
+    //
+    // To derive four independent MMA operations, a stride of 4 is applied to
+    // the original register layout:
+    //
+    //   1st MMA: [0, 4, 8, 12]
+    //   2nd MMA: [1, 5, 9, 13]
+    //   3rd MMA: [2, 6, 10, 14]
+    //   4th MMA: [3, 7, 11, 15]
+    si = llvm::SmallVector<unsigned>{0, 4, 8, 12, 1, 5, 9, 13,
+                                     2, 6, 10, 14, 3, 7, 11, 15};
   } else {
+    // Original register layout:
+    //
+    //   [0, 1, 2, 3]^T, [4, 5, 6, 7]^T
+    //
+    // A stride of 4 is applied to derive four independent MMA operations:
+    //
+    //   1st MMA: [0, 4]
+    //   2nd MMA: [1, 5]
+    //   3rd MMA: [2, 6]
+    //   4th MMA: [3, 7]
     si = llvm::SmallVector<unsigned>{0, 4, 1, 5, 2, 6, 3, 7};
   }

@@ -112,8 +142,8 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
   for (auto i = 0; i < n0; ++i) {
     for (auto j = 0; j < n1; j++) {
       vals[{b, 2 * i, 2 * j}] = elems[offset++];
-      vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
       vals[{b, 2 * i + 1, 2 * j}] = elems[offset++];
+      vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
       vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++];
     }
   }
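
Note: to see how the corrected permutation splits the operand registers into the stride-4 MMAs described in the new comments, here is a standalone sketch (indices only, not part of the lowering) that applies the new opIdx == 0 table `si` and groups the result four registers at a time.

#include <cstdio>

int main() {
  // New permutation for dot operand A (opIdx == 0), kWidth = 8.
  const unsigned si[16] = {0, 4, 8, 12, 1, 5, 9, 13,
                           2, 6, 10, 14, 3, 7, 11, 15};

  // Every four consecutive entries feed one of the four "stride 4" MMAs.
  for (int mma = 0; mma < 4; ++mma) {
    std::printf("MMA %d registers:", mma + 1);
    for (int r = 0; r < 4; ++r)
      std::printf(" %u", si[4 * mma + r]);
    std::printf("\n");
  }
  // Output:
  //   MMA 1 registers: 0 4 8 12
  //   MMA 2 registers: 1 5 9 13
  //   MMA 3 registers: 2 6 10 14
  //   MMA 4 registers: 3 7 11 15
}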

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 0 additions & 13 deletions
@@ -80,19 +80,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
         ret.push_back(v);
       }
     }
-    // FIXME [Dot LL]
-    // The DotOperandEncodingAttr without LLs encodes the
-    // layout as
-    //   e0 e1
-    //   e2 e3
-    // rather than transposed that, as the PTX docs say
-    // We transpose every block of 4 elements (kWidth = 8 -> 4 bf16x2)
-    assert(ret.size() % 16 == 0);
-    for (int i = 0; i < ret.size() / 16; ++i) {
-      for (int j = 0; j < 4; ++j) {
-        std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]);
-      }
-    }

     return ret;
   }
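
Note: for reference, the deleted loop transposed the two middle groups of four within every block of 16 upcasted values; with the dot-operand register order fixed at the source, that compensation is no longer applied. A standalone sketch (plain ints, one 16-element block) showing the effect of the removed swap:

#include <cstdio>
#include <utility>
#include <vector>

int main() {
  std::vector<int> ret(16);
  for (int i = 0; i < 16; ++i)
    ret[i] = i;

  // The removed transpose: swap positions 4..7 with 8..11 in each block of 16.
  for (size_t i = 0; i < ret.size() / 16; ++i)
    for (int j = 0; j < 4; ++j)
      std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]);

  for (int v : ret)
    std::printf("%d ", v); // prints 0 1 2 3 8 9 10 11 4 5 6 7 12 13 14 15
  std::printf("\n");
}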
