Skip to content

Commit 2e0de69

Browse files
[Intel] Remove legacy DPAS conversion (#3529)
This PR removes `ConvertLayoutOpConversion`, which is the legacy way of converting DPAS layouts. The pass now relies on Linear Layout to perform such conversions. Signed-off-by: Whitney Tsang <[email protected]>
1 parent 03e9323 commit 2e0de69

File tree

1 file changed

+0
-158
lines changed

1 file changed

+0
-158
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -31,162 +31,6 @@ namespace {
3131
// because LinearLayout seems to have some performance issues.
3232
constexpr bool useLegacyMMAConversion = false;
3333

34-
struct ConvertLayoutOpConversion
35-
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
36-
public:
37-
ConvertLayoutOpConversion(const LLVMTypeConverter &typeConverter,
38-
const triton::intel::TargetInfo &targetInfo,
39-
PatternBenefit benefit = 1)
40-
: ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {}
41-
42-
LogicalResult
43-
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
44-
ConversionPatternRewriter &rewriter) const override {
45-
RankedTensorType srcTy = op.getSrc().getType();
46-
RankedTensorType dstTy = op.getType();
47-
Attribute srcLayout = srcTy.getEncoding();
48-
Attribute dstLayout = dstTy.getEncoding();
49-
if (isa<DpasEncodingAttr>(srcLayout) &&
50-
isa<DotOperandEncodingAttr>(dstLayout)) {
51-
return lowerDpasToDotOperand(op, adaptor, rewriter);
52-
}
53-
return failure();
54-
}
55-
56-
private:
57-
using ValueTable = std::map<std::array<unsigned, 3>, Value>;
58-
59-
ValueTable getValuesFromDpasLayoutStruct(Location loc,
60-
ConversionPatternRewriter &rewriter,
61-
Value vals,
62-
RankedTensorType srcType) const {
63-
SmallVector<Value> elems = unpackLLElements(loc, vals, rewriter);
64-
auto dpasLayout = dyn_cast<DpasEncodingAttr>(srcType.getEncoding());
65-
66-
size_t totalElems = elems.size();
67-
auto numElemsPerOperand =
68-
product<unsigned>(dpasLayout.getDPASInstShapeC()) /
69-
product<unsigned>(dpasLayout.getThreadsPerWarp());
70-
Type elemTy =
71-
this->getTypeConverter()->convertType(srcType.getElementType());
72-
VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
73-
SmallVector<int64_t> repetitions =
74-
dpasLayout.getDPASRepetitions(srcType.getShape(), 2 /*operand C*/);
75-
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
76-
size_t rank = repCluster.size();
77-
size_t outerDim = rank - 2;
78-
size_t innerDim = rank - 1;
79-
80-
auto tb = TritonLLVMOpBuilder(loc, rewriter);
81-
int offset = 0;
82-
ValueTable result;
83-
for (unsigned b = 0; b < repetitions[0]; ++b) {
84-
for (int i = 0; i < repetitions[1]; ++i) {
85-
for (int j = 0; j < repetitions[2]; ++j) {
86-
for (int repOuter = 0; repOuter < repCluster[outerDim]; ++repOuter) {
87-
for (int repInner = 0; repInner < repCluster[innerDim];
88-
++repInner) {
89-
Value matVal = rewriter.create<LLVM::UndefOp>(loc, dotOpTy);
90-
for (int k = 0; k < numElemsPerOperand; ++k) {
91-
matVal = tb.insert_element(dotOpTy, matVal, elems[offset++],
92-
tb.i32_val(k));
93-
}
94-
result[{b, i * repCluster[outerDim] + repOuter,
95-
j * repCluster[innerDim] + repInner}] = matVal;
96-
}
97-
}
98-
}
99-
}
100-
}
101-
102-
return result;
103-
}
104-
105-
Value composeValuesToDotOperandLayoutStruct(
106-
Location loc, ConversionPatternRewriter &rewriter, const ValueTable &vals,
107-
RankedTensorType dstType) const {
108-
auto tb = TritonLLVMOpBuilder(loc, rewriter);
109-
auto dotLayout = dyn_cast<DotOperandEncodingAttr>(dstType.getEncoding());
110-
auto dpasLayout = dyn_cast<DpasEncodingAttr>(dotLayout.getParent());
111-
112-
auto opIdx = static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
113-
SmallVector<int64_t> repetitions =
114-
dpasLayout.getDPASRepetitions(dstType.getShape(), opIdx);
115-
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
116-
size_t rank = repCluster.size();
117-
unsigned repBatch = repetitions[0];
118-
unsigned repOuter = 0u;
119-
unsigned repInner = 0u;
120-
unsigned repClusterOuter = 0u;
121-
122-
switch (opIdx) {
123-
case DpasEncodingAttr::OpIdx::OperandA: {
124-
// operand A
125-
repOuter = repetitions[1];
126-
repInner = repetitions[2];
127-
repClusterOuter = repCluster[rank - 2];
128-
} break;
129-
case DpasEncodingAttr::OpIdx::OperandB: {
130-
// operand B
131-
repOuter = repetitions[2];
132-
repInner = repetitions[1];
133-
repClusterOuter = repCluster[rank - 1];
134-
} break;
135-
case DpasEncodingAttr::OpIdx::OperandC: {
136-
llvm_unreachable("unexpected OpIdx::OperandC");
137-
} break;
138-
}
139-
140-
// TODO: Operands B requires extra steps to combine [8, 16] to [16, 16].
141-
SmallVector<Value> elems;
142-
for (unsigned b = 0; b < repBatch; ++b) {
143-
for (int m = 0; m < repOuter; ++m) {
144-
for (int k = 0; k < repInner; ++k) {
145-
for (int repOuterIdx = 0; repOuterIdx < repClusterOuter;
146-
++repOuterIdx) {
147-
unsigned offsetM = m * repClusterOuter + repOuterIdx;
148-
unsigned offsetN = k;
149-
Value matVal = vals.at({b, offsetM, offsetN});
150-
auto vecType = cast<VectorType>(matVal.getType());
151-
Type valTy = vecType.getElementType();
152-
for (int i = 0; i < vecType.getNumElements(); ++i) {
153-
Value val = tb.extract_element(valTy, matVal, tb.i32_val(i));
154-
elems.push_back(val);
155-
}
156-
}
157-
}
158-
}
159-
}
160-
161-
Type elemTy = getTypeConverter()->convertType(dstType.getElementType());
162-
Type structTy = LLVM::LLVMStructType::getLiteral(
163-
getContext(), SmallVector<Type>(elems.size(), elemTy));
164-
return packLLElements(loc, this->getTypeConverter(), elems, rewriter,
165-
structTy);
166-
}
167-
168-
// dpas -> dot_operand
169-
LogicalResult
170-
lowerDpasToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
171-
ConversionPatternRewriter &rewriter) const {
172-
Location loc = op.getLoc();
173-
RankedTensorType srcTy = op.getSrc().getType();
174-
RankedTensorType dstTy = op.getType();
175-
176-
if (!intel::isDpasToDotShortcut(srcTy, dstTy))
177-
return failure();
178-
179-
// reorder the elements to match the dot_operand layout.
180-
ValueTable values =
181-
getValuesFromDpasLayoutStruct(loc, rewriter, adaptor.getSrc(), srcTy);
182-
Value view =
183-
composeValuesToDotOperandLayoutStruct(loc, rewriter, values, dstTy);
184-
185-
rewriter.replaceOp(op, view);
186-
return success();
187-
}
188-
};
189-
19034
struct ConvertLayoutOpUsingLinearLayoutsConversion
19135
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
19236
const TargetInfoBase &targetInfo;
@@ -1006,8 +850,6 @@ void mlir::triton::intel::populateConvertLayoutOpToLLVMPatterns(
1006850
// and be the only one left.
1007851
patterns.add<gpu::ConvertLayoutOpUsingLinearLayoutsConversion>(
1008852
typeConverter, targetInfo, benefit.getBenefit() + 2);
1009-
patterns.add<gpu::ConvertLayoutOpConversion>(typeConverter, targetInfo,
1010-
benefit.getBenefit() + 1);
1011853
mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo,
1012854
patterns, benefit);
1013855
}

0 commit comments

Comments
 (0)