1414#define LINALG_IR_LINALGINTERFACES
1515
1616include "mlir/Interfaces/DestinationStyleOpInterface.td"
17+ include "mlir/Interfaces/IndexingMapOpInterface.td"
1718include "mlir/IR/OpBase.td"
1819
1920// The 'LinalgContractionOpInterface' provides access to the
@@ -222,59 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
222223 ];
223224}
224225
225- def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
226- let description = [{
227- Interface for operations that connect an iteration domain to operands via
228- affine maps. Provides methods to access indexing maps between iteration
229- domain and operand index spaces.
230- }];
231- let cppNamespace = "::mlir::linalg";
232- let methods = [
233- InterfaceMethod<
234- /*desc=*/[{
235- Return the indexing maps attribute within the current operation.
236- }],
237- /*retTy=*/"ArrayAttr",
238- /*methodName=*/"getIndexingMaps"
239- >,
240- InterfaceMethod<
241- /*desc=*/[{
242- Return the indexing maps within the current operation.
243- }],
244- /*retTy=*/"SmallVector<AffineMap>",
245- /*methodName=*/"getIndexingMapsArray",
246- /*args=*/(ins),
247- /*methodBody=*/"",
248- /*defaultImplementation=*/[{
249- auto range = $_op.getIndexingMaps()
250- .template getAsValueRange<AffineMapAttr>();
251- return {range.begin(), range.end()};
252- }]
253- >,
254- InterfaceMethod<
255- /*desc=*/[{
256- Return the input or output indexing map for `opOperand`.
257- }],
258- /*retTy=*/"AffineMap",
259- /*methodName=*/"getMatchingIndexingMap",
260- /*args=*/(ins "OpOperand*":$opOperand),
261- /*methodBody=*/"",
262- /*defaultImplementation=*/[{
263- assert(opOperand->getOwner() == this->getOperation());
264- auto indexingMaps =
265- $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
266- return *(indexingMaps.begin() + opOperand->getOperandNumber());
267- }]
268- >,
269- ];
270- }
271-
272226// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
273227def LinalgStructuredInterface
274- : OpInterface<"LinalgOp", [
275- DestinationStyleOpInterface,
276- IndexingMapOpInterface
277- ]> {
228+ : OpInterface<"LinalgOp",
229+ [DestinationStyleOpInterface, IndexingMapOpInterface]
230+ > {
278231 let cppNamespace = "::mlir::linalg";
279232 let methods = [
280233 //===------------------------------------------------------------------===//
@@ -464,30 +417,6 @@ def LinalgStructuredInterface
464417 return getBlock()->getArguments().take_back($_op.getNumDpsInits());
465418 }]
466419 >,
467- InterfaceMethod<
468- /*desc=*/[{
469- Return the `opOperand` shape or an empty vector for scalars or vectors
470- not wrapped within a tensor or a memref.
471- }],
472- /*retTy=*/"ArrayRef<int64_t>",
473- /*methodName=*/"getShape",
474- /*args=*/(ins "OpOperand*":$opOperand),
475- /*methodBody=*/"",
476- /*defaultImplementation=*/[{
477- assert(opOperand->getOwner() == this->getOperation());
478- Type t = opOperand->get().getType();
479- // A VectorType is an elemental type, do not consider its rank for the operand.
480- if (isa<VectorType>(t))
481- return {};
482- if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
483- // Failsafe.
484- assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
485- "expected a ranked tensor or memref in LinalgInterface::getRank");
486- return shapedType.getShape();
487- }
488- return {};
489- }]
490- >,
491420 InterfaceMethod<
492421 /*desc=*/[{
493422 Return the block argument for an `opOperand`.
@@ -620,7 +549,12 @@ def LinalgStructuredInterface
620549 /*args=*/(ins),
621550 /*methodBody=*/"",
622551 /*defaultImplementation=*/[{
623- return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
552+ for (OpOperand &opOperand : this->getOperation()->getOpOperands()) {
553+ if (auto shapedType = dyn_cast<ShapedType>(opOperand.get().getType())) {
554+ if (ShapedType::isDynamicShape(shapedType.getShape())) return true;
555+ }
556+ }
557+ return false;
624558 }]
625559 >,
626560 InterfaceMethod<
@@ -738,53 +672,6 @@ def LinalgStructuredInterface
738672 //===------------------------------------------------------------------===//
739673 // Linalg generalization hooks.
740674 //===------------------------------------------------------------------===//
741- InterfaceMethod<
742- /*desc=*/[{
743- Hook to provide a custom AffineMap used to compute all the operand
744- subshapes given loop bounds. This is used to answer the question: "given
745- an iteration space over the codomain, what are the subshapes of the
746- operands involved in the computation".
747- The default behavior is to just concatenate all the indexing maps.
748- A custom AffineMap allows providing a map that can be used to
749- compute subshapes even in cases where the concatenation of indexing maps
750- (i.e. the data traversal order) is not a simple permutation of the loop
751- traversal order. It is then possible to define ops with skewed data
752- traversal order for which we can still easily compute hyperrectangular
753- loop bounds and subviews.
754- }],
755- /*retTy=*/"AffineMap",
756- /*methodName=*/"getLoopsToShapesMap",
757- /*args=*/(ins),
758- /*methodBody=*/"",
759- /*defaultImplementation=*/[{
760- auto maps = $_op.getIndexingMapsArray();
761- return concatAffineMaps(maps, $_op.getContext());
762- }]
763- >,
764- InterfaceMethod<
765- /*desc=*/[{
766- Hook to provide a custom AffineMap used to construct the
767- hyperrectangular loop iteration space given all the operand subshapes.
768- This is used to answer the question:
769- "Given a list of operand ranges, what is the subportion of the iteration
770- space involved in the computation".
771- This is the inverse problem of `getLoopsToShapesMap`.
772- Return the empty AffineMap when such an AffineMap cannot be constructed.
773- The default behavior is based on a very simple inference procedure that
774- only works with permutation affine maps.
775- A more advanced Tensor-Comprehension like inference is possible but has
776- proven to be ambiguous in unfavorable case.
777- A safer and more robust alternative is to allow each op to define
778- its own AffineMap.
779- }],
780- /*retTy=*/"AffineMap",
781- /*methodName=*/"getShapesToLoopsMap",
782- /*args=*/(ins),
783- /*methodBody=*/"",
784- /*defaultImplementation=*/[{
785- return inversePermutation(getLoopsToShapesMap());
786- }]
787- >,
788675 InterfaceMethod<
789676 /*desc=*/[{
790677 Checks if the given operands can be dropped, and the remaining
@@ -798,39 +685,30 @@ def LinalgStructuredInterface
798685 return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
799686 }]
800687 >,
688+ //===------------------------------------------------------------------===//
689+ // IndexingMapOpInterface interface methods implementation.
690+ //===------------------------------------------------------------------===//
801691 InterfaceMethod<
802692 /*desc=*/[{
803- Like `getShape`, but only returns statically-known information, without
804- generating any new IR. For each shape dimension, returns >=0 if that
805- dimension is statically known, or ShapedType::kDynamic otherwise.
806- }],
807- /*retTy=*/"SmallVector<int64_t>",
808- /*methodName=*/"getStaticShape",
809- /*args=*/(ins),
810- /*methodBody=*/"",
811- /*defaultImplementation=*/[{
812- SmallVector<int64_t> res;
813- for (OpOperand &opOperand : this->getOperation()->getOpOperands())
814- llvm::append_range(res, getShape(&opOperand));
815- return res;
816- }]
817- >,
818- InterfaceMethod<
819- /*desc=*/[{
820- Returns the statically-known loop ranges. Composes
821- `getShapesToLoopsMap()` with the result of `getStaticShape`.
822- Returns ShapedType::kDynamic for non-statically-known loop ranges.
823- This is expected to be called by a valid Linalg op
693+ Return the `opOperand` shape or an empty vector for scalars or vectors
694+ not wrapped within a tensor or a memref.
824695 }],
825- /*retTy=*/"SmallVector <int64_t, 4 >",
826- /*methodName=*/"getStaticLoopRanges ",
827- /*args=*/(ins),
696+ /*retTy=*/"ArrayRef <int64_t>",
697+ /*methodName=*/"getShape ",
698+ /*args=*/(ins "OpOperand*":$opOperand ),
828699 /*methodBody=*/"",
829700 /*defaultImplementation=*/[{
830- SmallVector<int64_t> viewSizes = getStaticShape();
831- AffineMap invertedMap = getShapesToLoopsMap();
832- assert(invertedMap && "expected a valid Linalg op to call the method");
833- return invertedMap.compose(viewSizes);
701+ Type t = opOperand->get().getType();
702+ // A VectorType is an elemental type, do not consider its rank for the operand.
703+ if (isa<VectorType>(t))
704+ return {};
705+ if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
706+ // Failsafe.
707+ assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
708+ "expected a ranked tensor or memref in LinalgInterface::getRank");
709+ return shapedType.getShape();
710+ }
711+ return {};
834712 }]
835713 >,
836714 //===------------------------------------------------------------------===//
0 commit comments