@@ -31,162 +31,6 @@ namespace {
// because LinearLayout seems to have some performance issues.
constexpr bool useLegacyMMAConversion = false;
3333
34- struct ConvertLayoutOpConversion
35- : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
36- public:
37- ConvertLayoutOpConversion (const LLVMTypeConverter &typeConverter,
38- const triton::intel::TargetInfo &targetInfo,
39- PatternBenefit benefit = 1 )
40- : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {}
41-
42- LogicalResult
43- matchAndRewrite (triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
44- ConversionPatternRewriter &rewriter) const override {
45- RankedTensorType srcTy = op.getSrc ().getType ();
46- RankedTensorType dstTy = op.getType ();
47- Attribute srcLayout = srcTy.getEncoding ();
48- Attribute dstLayout = dstTy.getEncoding ();
49- if (isa<DpasEncodingAttr>(srcLayout) &&
50- isa<DotOperandEncodingAttr>(dstLayout)) {
51- return lowerDpasToDotOperand (op, adaptor, rewriter);
52- }
53- return failure ();
54- }
55-
56- private:
57- using ValueTable = std::map<std::array<unsigned , 3 >, Value>;
58-
59- ValueTable getValuesFromDpasLayoutStruct (Location loc,
60- ConversionPatternRewriter &rewriter,
61- Value vals,
62- RankedTensorType srcType) const {
63- SmallVector<Value> elems = unpackLLElements (loc, vals, rewriter);
64- auto dpasLayout = dyn_cast<DpasEncodingAttr>(srcType.getEncoding ());
65-
66- size_t totalElems = elems.size ();
67- auto numElemsPerOperand =
68- product<unsigned >(dpasLayout.getDPASInstShapeC ()) /
69- product<unsigned >(dpasLayout.getThreadsPerWarp ());
70- Type elemTy =
71- this ->getTypeConverter ()->convertType (srcType.getElementType ());
72- VectorType dotOpTy = vec_ty (elemTy, numElemsPerOperand);
73- SmallVector<int64_t > repetitions =
74- dpasLayout.getDPASRepetitions (srcType.getShape (), 2 /* operand C*/ );
75- ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
76- size_t rank = repCluster.size ();
77- size_t outerDim = rank - 2 ;
78- size_t innerDim = rank - 1 ;
79-
80- auto tb = TritonLLVMOpBuilder (loc, rewriter);
81- int offset = 0 ;
82- ValueTable result;
83- for (unsigned b = 0 ; b < repetitions[0 ]; ++b) {
84- for (int i = 0 ; i < repetitions[1 ]; ++i) {
85- for (int j = 0 ; j < repetitions[2 ]; ++j) {
86- for (int repOuter = 0 ; repOuter < repCluster[outerDim]; ++repOuter) {
87- for (int repInner = 0 ; repInner < repCluster[innerDim];
88- ++repInner) {
89- Value matVal = rewriter.create <LLVM::UndefOp>(loc, dotOpTy);
90- for (int k = 0 ; k < numElemsPerOperand; ++k) {
91- matVal = tb.insert_element (dotOpTy, matVal, elems[offset++],
92- tb.i32_val (k));
93- }
94- result[{b, i * repCluster[outerDim] + repOuter,
95- j * repCluster[innerDim] + repInner}] = matVal;
96- }
97- }
98- }
99- }
100- }
101-
102- return result;
103- }
104-
105- Value composeValuesToDotOperandLayoutStruct (
106- Location loc, ConversionPatternRewriter &rewriter, const ValueTable &vals,
107- RankedTensorType dstType) const {
108- auto tb = TritonLLVMOpBuilder (loc, rewriter);
109- auto dotLayout = dyn_cast<DotOperandEncodingAttr>(dstType.getEncoding ());
110- auto dpasLayout = dyn_cast<DpasEncodingAttr>(dotLayout.getParent ());
111-
112- auto opIdx = static_cast <DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx ());
113- SmallVector<int64_t > repetitions =
114- dpasLayout.getDPASRepetitions (dstType.getShape (), opIdx);
115- ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
116- size_t rank = repCluster.size ();
117- unsigned repBatch = repetitions[0 ];
118- unsigned repOuter = 0u ;
119- unsigned repInner = 0u ;
120- unsigned repClusterOuter = 0u ;
121-
122- switch (opIdx) {
123- case DpasEncodingAttr::OpIdx::OperandA: {
124- // operand A
125- repOuter = repetitions[1 ];
126- repInner = repetitions[2 ];
127- repClusterOuter = repCluster[rank - 2 ];
128- } break ;
129- case DpasEncodingAttr::OpIdx::OperandB: {
130- // operand B
131- repOuter = repetitions[2 ];
132- repInner = repetitions[1 ];
133- repClusterOuter = repCluster[rank - 1 ];
134- } break ;
135- case DpasEncodingAttr::OpIdx::OperandC: {
136- llvm_unreachable (" unexpected OpIdx::OperandC" );
137- } break ;
138- }
139-
140- // TODO: Operands B requires extra steps to combine [8, 16] to [16, 16].
141- SmallVector<Value> elems;
142- for (unsigned b = 0 ; b < repBatch; ++b) {
143- for (int m = 0 ; m < repOuter; ++m) {
144- for (int k = 0 ; k < repInner; ++k) {
145- for (int repOuterIdx = 0 ; repOuterIdx < repClusterOuter;
146- ++repOuterIdx) {
147- unsigned offsetM = m * repClusterOuter + repOuterIdx;
148- unsigned offsetN = k;
149- Value matVal = vals.at ({b, offsetM, offsetN});
150- auto vecType = cast<VectorType>(matVal.getType ());
151- Type valTy = vecType.getElementType ();
152- for (int i = 0 ; i < vecType.getNumElements (); ++i) {
153- Value val = tb.extract_element (valTy, matVal, tb.i32_val (i));
154- elems.push_back (val);
155- }
156- }
157- }
158- }
159- }
160-
161- Type elemTy = getTypeConverter ()->convertType (dstType.getElementType ());
162- Type structTy = LLVM::LLVMStructType::getLiteral (
163- getContext (), SmallVector<Type>(elems.size (), elemTy));
164- return packLLElements (loc, this ->getTypeConverter (), elems, rewriter,
165- structTy);
166- }
167-
168- // dpas -> dot_operand
169- LogicalResult
170- lowerDpasToDotOperand (triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
171- ConversionPatternRewriter &rewriter) const {
172- Location loc = op.getLoc ();
173- RankedTensorType srcTy = op.getSrc ().getType ();
174- RankedTensorType dstTy = op.getType ();
175-
176- if (!intel::isDpasToDotShortcut (srcTy, dstTy))
177- return failure ();
178-
179- // reorder the elements to match the dot_operand layout.
180- ValueTable values =
181- getValuesFromDpasLayoutStruct (loc, rewriter, adaptor.getSrc (), srcTy);
182- Value view =
183- composeValuesToDotOperandLayoutStruct (loc, rewriter, values, dstTy);
184-
185- rewriter.replaceOp (op, view);
186- return success ();
187- }
188- };
189-
19034struct ConvertLayoutOpUsingLinearLayoutsConversion
19135 : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
19236 const TargetInfoBase &targetInfo;
@@ -1006,8 +850,6 @@ void mlir::triton::intel::populateConvertLayoutOpToLLVMPatterns(
1006850 // and be the only one left.
1007851 patterns.add <gpu::ConvertLayoutOpUsingLinearLayoutsConversion>(
1008852 typeConverter, targetInfo, benefit.getBenefit () + 2 );
1009- patterns.add <gpu::ConvertLayoutOpConversion>(typeConverter, targetInfo,
1010- benefit.getBenefit () + 1 );
1011853 mlir::triton::populateConvertLayoutOpToLLVMPatterns (typeConverter, targetInfo,
1012854 patterns, benefit);
1013855}