@@ -61,19 +61,6 @@ getSwizzledShape(ArrayRef<OpFoldResult> packedShape,
   return newShape;
 }
 
-static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
-                                         ValueRange convertedInputOperands,
-                                         ValueRange convertedOutputOperands) {
-  SmallVector<Value> operands;
-  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
-  operands.append(convertedOutputOperands.begin(),
-                  convertedOutputOperands.end());
-  return mlir::clone(builder, op,
-                     {dropEncoding(cast<RankedTensorType>(
-                         convertedOutputOperands[0].getType()))},
-                     operands);
-}
-
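This removed helper was the fallback path when no MaterializeEncodingInfo could be resolved: it re-created the op with the encoding stripped from the result type. A minimal sketch of what dropEncoding (not shown in this diff) is assumed to do:

    // Sketch only; dropEncoding is assumed to rebuild the tensor type without
    // its encoding attribute, e.g.
    // tensor<16x16xf32, #iree_encoding.encoding<...>> -> tensor<16x16xf32>.
    static RankedTensorType dropEncoding(RankedTensorType type) {
      return RankedTensorType::get(type.getShape(), type.getElementType());
    }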
 static FailureOr<SmallVector<OpFoldResult>>
 getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
                      RankedTensorType tensorType,
@@ -111,91 +98,6 @@ getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
   return result;
 }
 
-RankedTensorType getExpandedType(RankedTensorType type, bool isBatched,
-                                 bool isTransposed,
-                                 SmallVectorImpl<ReassociationIndices> &ri) {
-  if (!isBatched) {
-    ri.assign({{0, 1}, {2, 3}});
-    if (!isTransposed) {
-      return RankedTensorType::get(
-          {1, type.getDimSize(0), 1, type.getDimSize(1)},
-          type.getElementType());
-    }
-    return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1},
-                                 type.getElementType());
-  }
-
-  ri.assign({{0}, {1, 2}, {3, 4}});
-  if (!isTransposed) {
-    return RankedTensorType::get(
-        {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
-        type.getElementType());
-  }
-  return RankedTensorType::get(
-      {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
-      type.getElementType());
-}
-
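The reassociation indices written into ri pair with the returned type so a later tensor.expand_shape can insert the unit dimensions. A worked sketch of the non-batched, non-transposed case, assuming a builder b is in scope:

    // Sketch only: expanding a 16x64 operand of a vecmat-style contraction.
    SmallVector<ReassociationIndices> ri;
    RankedTensorType expanded =
        getExpandedType(RankedTensorType::get({16, 64}, b.getF32Type()),
                        /*isBatched=*/false, /*isTransposed=*/false, ri);
    // expanded is tensor<1x16x1x64xf32>; ri is {{0, 1}, {2, 3}}.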
-/// Given an input Value and a desired output element type, create and return
-/// an element-wise linalg::GenericOp that extends the input Value to the
-/// output element type.
-static Value createElementWiseExtUIOp(RewriterBase &rewriter, Value input,
-                                      Location loc, Type outElemType) {
-  auto inputType = cast<RankedTensorType>(input.getType());
-  SmallVector<AffineMap> maps(
-      2, rewriter.getMultiDimIdentityMap(inputType.getRank()));
-  SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
-                                                 utils::IteratorType::parallel);
-  auto castedType = inputType.clone(outElemType);
-  SmallVector<OpFoldResult> inputMixedSizes =
-      tensor::getMixedSizes(rewriter, loc, input);
-  Value init =
-      rewriter.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
-  return rewriter
-      .create<linalg::GenericOp>(
-          loc, castedType, input, init, maps, iteratorTypes,
-          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
-            Value castRes =
-                b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
-                    ->getResult(0);
-            b.create<linalg::YieldOp>(nestedLoc, castRes);
-          })
-      .getResult(0);
-}
-
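For illustration, a hedged usage sketch (input, rewriter, and loc are assumed to be in scope): widening an unsigned-integer tensor to i32 so the zero-extension stays explicit in the IR for ukernels:

    // Sketch only: emits a linalg.generic whose body is a single arith.extui,
    // e.g. tensor<?x?xi8> (treated as unsigned) -> tensor<?x?xi32>.
    Value widened =
        createElementWiseExtUIOp(rewriter, input, loc, rewriter.getI32Type());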
-/// If needed, expand the input Value, and return the resulting input with
-/// the canonical mmt4d input shape. If the input element type is unsigned,
-/// create a producer linalg::GenericOp on the input that unsigned-extends the
-/// input to the output element type. This extension is required to keep the
-/// unsignedness information on the input for ukernels. If `transpose` is true,
-/// the `linalgOp`'s indexing maps are transposed.
-static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
-                             bool transpose, RewriterBase &rewriter,
-                             SmallVectorImpl<ReassociationIndices> &ri,
-                             ArrayRef<Type> elemTypes, int operandIdx) {
-  assert(linalgOp.getNumDpsInputs() == 2);
-  assert(linalgOp.getNumDpsInits() == 1);
-  auto cDims = linalg::inferContractionDims(linalgOp);
-  Location loc = linalgOp->getLoc();
-  Value expandedValue = value;
-  // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
-  // operand is a vector and must be expanded.
-  if ((cDims->m.empty() && operandIdx != 1) ||
-      (cDims->n.empty() && operandIdx != 0)) {
-    auto type = cast<RankedTensorType>(value.getType());
-    RankedTensorType newType = getExpandedType(
-        type, /*isBatched=*/!cDims->batch.empty(),
-        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri);
-    expandedValue =
-        rewriter.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
-  }
-  if (elemTypes[operandIdx].isUnsignedInteger()) {
-    return createElementWiseExtUIOp(rewriter, expandedValue, loc,
-                                    elemTypes.back());
-  }
-  return expandedValue;
-}
-
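Reading the expansion condition: in a vecmat (m empty) every operand except the RHS is a vector, and in a matvec (n empty) every operand except the LHS is. A hedged sketch of the vecmat LHS case, with lhs, ri, and elemTypes assumed in scope:

    // Sketch only: operandIdx 0 of a vecmat matches the condition, so the LHS
    // gets unit dimensions inserted via tensor.expand_shape before the mmt4d.
    Value newLhs = getMmt4dOperand(lhs, linalgOp, /*transpose=*/false, rewriter,
                                   ri, elemTypes, /*operandIdx=*/0);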
 static void transposeInPlace(MaterializeEncodingInfo &info) {
   // Vector cases: nothing to do.
   if (info.innerTileSizes.size() < 2) {
@@ -297,75 +199,6 @@ FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
                                            encodingInfo->outerDimsPerm);
 }
 
-static FailureOr<Operation *> lowerContractionOpWithEncoding(
-    RewriterBase &rewriter, linalg::LinalgOp linalgOp, ValueRange operands,
-    const MaterializeEncodingTypeConverter &typeConverter) {
-  if (!linalgOp.hasPureTensorSemantics())
-    return failure();
-
-  auto inputs = linalgOp.getDpsInputOperands();
-  auto outputs = linalgOp.getDpsInits();
-
-  auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
-  auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
-  auto resultType = cast<RankedTensorType>(outputs[0].getType());
-  auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
-  auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
-  auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
-  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
-    return failure();
-  }
-
-  if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
-      rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
-      resultEncoding.getOperandIndex().getValue() !=
-          IREE::Encoding::MATMUL_RESULT) {
-    return failure();
-  }
-
-  FailureOr<MaterializeEncodingInfo> encodingInfo =
-      typeConverter.getEncodingInfo(
-          cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
-
-  Operation *result;
-  if (failed(encodingInfo)) {
-    result = dropEncodingAndCloneOp(rewriter, linalgOp,
-                                    operands.take_front(inputs.size()),
-                                    operands.drop_front(inputs.size()));
-  } else {
-    bool transpose =
-        typeConverter.getTransposeNarrowN() && isNarrowNResult(resultEncoding);
-    SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
-    SmallVector<ReassociationIndices> ri;
-    Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, rewriter,
-                                   ri, elemTypes, /*operandIdx=*/0);
-    Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, rewriter,
-                                   ri, elemTypes, /*operandIdx=*/1);
-    Value newResult =
-        getMmt4dOperand(operands[2], linalgOp, transpose, rewriter, ri,
-                        elemTypes, /*operandIdx=*/2);
-    if (transpose) {
-      std::swap(newLhs, newRhs);
-    }
-    Type newResultType = newResult.getType();
-    auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
-    if (cDims->batch.empty()) {
-      result = rewriter.create<linalg::Mmt4DOp>(
-          linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-          ValueRange{newResult});
-    } else {
-      result = rewriter.create<linalg::BatchMmt4DOp>(
-          linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-          ValueRange{newResult});
-    }
-    if (!ri.empty()) {
-      result = rewriter.create<tensor::CollapseShapeOp>(
-          linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
-    }
-  }
-  return result;
-}
-
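Per the new call site in the final hunk, this lowering now lives behind IREE::Codegen::lowerContractionOpWithEncoding, which takes the narrow-N flag and an encoding-info callback instead of the whole type converter. A presumed declaration, inferred from that call site (the parameter types are assumptions, not copied from the header):

    // Inferred sketch: the callback decouples the shared lowering from
    // MaterializeEncodingTypeConverter.
    FailureOr<Operation *> lowerContractionOpWithEncoding(
        RewriterBase &rewriter, linalg::LinalgOp linalgOp, ValueRange operands,
        bool transposeNarrowN,
        llvm::function_ref<FailureOr<MaterializeEncodingInfo>(RankedTensorType)>
            getEncodingInfo);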
 /// Utility method to convert `tensor.empty` with encoding to a `tensor.empty`
 /// of the materialized type.
 static FailureOr<Operation *>
@@ -901,8 +734,17 @@ class MaterializeContractionOp
 
     auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
         this->getTypeConverter());
+    // TODO(hanchung): This is a transition state for moving the implementation
+    // details to backend attributes. We won't need the function type argument
+    // after all the backends that support encodings implement the attribute.
+    auto getEncodingInfoWrapper =
+        [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
+      return converter->getEncodingInfo(type);
+    };
     FailureOr<Operation *> convertedOp =
-        lowerContractionOpWithEncoding(rewriter, op, operands, *converter);
+        IREE::Codegen::lowerContractionOpWithEncoding(
+            rewriter, op, operands, converter->getTransposeNarrowN(),
+            getEncodingInfoWrapper);
     if (failed(convertedOp)) {
       return failure();
     }
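The wrapper lambda only adapts the converter's member function to the callback shape the shared helper expects; a plausible alias for that shape (the name and the std::function choice are assumptions):

    // Assumed callback type; per the TODO above, it goes away once backend
    // attributes carry the materialization details themselves.
    using ResolveEncodingInfoFn =
        std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType)>;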