2222#include " llvm/ADT/SmallBitVector.h"
2323#include " llvm/ADT/SmallVector.h"
2424#include < algorithm>
25+ #include < numeric>
2526
2627using namespace mlir ;
2728using namespace mlir ::linalg;
@@ -49,18 +50,41 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
4950 return inversePermutation (concatAffineMaps (indexingMaps)) != AffineMap ();
5051}
5152
53+ // Returns true if all loops of the linalgOp are parallel
54+ static bool isAllParallel (LinalgOp op) {
55+ return op.getNumParallelLoops () == op.getNumLoops ();
56+ }
57+
58+ // Returns true if and only if linalgOp takes one input and one init.
59+ static bool isSingleInputOutput (LinalgOp op) {
60+ return op.getNumDpsInputs () == 1 && op.getNumDpsInits () == 1 ;
61+ }
62+ // Returns true if genericOp body is just a yieldOp that yields
63+ // input operand as result.
64+ static bool isSingleYieldOp (GenericOp op) {
65+ if (op.getNumDpsInputs () != 1 || op.getNumDpsInits () != 1 )
66+ return false ;
67+
68+ Block *body = op.getBody ();
69+ if (body->getOperations ().size () != 1 )
70+ return false ;
71+
72+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back ());
73+ if (!yieldOp || yieldOp.getNumOperands () != 1 ||
74+ yieldOp->getOperand (0 ) != body->getArgument (0 ))
75+ return false ;
76+ return true ;
77+ }
78+
5279// ===----------------------------------------------------------------------===//
5380// CopyOpInterface implementation
5481// ===----------------------------------------------------------------------===//
5582
5683bool linalg::isaCopyOpInterface (LinalgOp linalgOp) {
57- // Structural.
58- if (linalgOp. getNumParallelLoops () != linalgOp. getNumLoops ( ))
84+ // Structural and operands
85+ if (! isAllParallel (linalgOp) || ! isSingleInputOutput (linalgOp ))
5986 return false ;
6087
61- // Operands and maps.
62- if (linalgOp.getNumDpsInputs () != 1 || linalgOp.getNumDpsInits () != 1 )
63- return false ;
6488 auto mapRange = linalgOp.getIndexingMapsArray ();
6589 if (mapRange.size () != 2 || !mapRange.front ().isIdentity () ||
6690 !mapRange.back ().isIdentity ()) {
@@ -75,8 +99,8 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
7599// ===----------------------------------------------------------------------===//
76100std::optional<Value> linalg::isaFillOpInterface (GenericOp genericOp) {
77101 // Structural.
78- if (genericOp. getNumParallelLoops () != genericOp. getNumLoops ( ) ||
79- genericOp. getNumDpsInputs () != 1 || genericOp. getNumDpsInits () != 1 )
102+ if (! isAllParallel (genericOp) || ! isSingleInputOutput (genericOp ) ||
103+ ! isSingleYieldOp (genericOp) )
80104 return std::nullopt ;
81105
82106 // Input should be referenced and init should not.
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
87111 OpOperand *value = genericOp.getDpsInputOperand (0 );
88112 if (!genericOp.isScalar (value))
89113 return std::nullopt ;
114+ return value->get ();
115+ }
90116
91- Block *body = genericOp.getBody ();
92- if (body->getOperations ().size () != 1 )
117+ // ===----------------------------------------------------------------------===//
118+ // BroadcastOpInterface implementation
119+ // ===----------------------------------------------------------------------===//
120+ std::optional<SmallVector<int64_t >>
121+ linalg::isaBroadcastOpInterface (GenericOp genericOp) {
122+ // Structural.
123+ if (!isAllParallel (genericOp) || !isSingleInputOutput (genericOp) ||
124+ !isSingleYieldOp (genericOp))
93125 return std::nullopt ;
94126
95- auto yieldOp = dyn_cast<linalg::YieldOp>(body->back ());
96- if (!yieldOp || yieldOp.getNumOperands () != 1 ||
97- yieldOp->getOperand (0 ) != body->getArgument (0 ))
127+ auto t0 = genericOp.getDpsInputOperand (0 )->get ().getType ();
128+ auto t1 = genericOp.getDpsInitOperand (0 )->get ().getType ();
129+ if (!isa<MemRefType, RankedTensorType>(t0) ||
130+ !isa<MemRefType, RankedTensorType>(t1))
98131 return std::nullopt ;
99- return value->get ();
132+
133+ // Check output is identity map. Injective function could also be
134+ // a permutation of indices and expressible in linalg.generic but
135+ // is not expressible for named broadcast op.
136+ auto dstMap = genericOp.getIndexingMapsArray ()[1 ];
137+ if (!dstMap.isIdentity ())
138+ return std::nullopt ;
139+
140+ SmallVector<int64_t > position;
141+ auto srcMap = genericOp.getIndexingMapsArray ()[0 ];
142+
143+ // Check input map is monotonically increasing DimIds.
144+ for (unsigned i = 0 ; i < srcMap.getNumResults (); ++i) {
145+ auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults ()[i]);
146+ if (!expr)
147+ return std::nullopt ;
148+ int64_t pos = expr.getPosition ();
149+ if (i > 0 && pos <= position[i - 1 ])
150+ return std::nullopt ;
151+ position.push_back (expr.getPosition ());
152+ }
153+
154+ SmallVector<int64_t > broadcastedDims;
155+ auto numDims = srcMap.getNumDims ();
156+ for (auto dim : llvm::seq<int64_t >(0 , numDims)) {
157+ if (!llvm::is_contained (position, dim))
158+ broadcastedDims.push_back (dim);
159+ }
160+ return broadcastedDims;
161+ }
162+
163+ // ===----------------------------------------------------------------------===//
164+ // TransposeOpInterface implementation
165+ // ===----------------------------------------------------------------------===//
166+ std::optional<SmallVector<int64_t >>
167+ linalg::isaTransposeOpInterface (GenericOp genericOp) {
168+ // Structural.
169+ if (!isAllParallel (genericOp) || !isSingleInputOutput (genericOp) ||
170+ !isSingleYieldOp (genericOp))
171+ return std::nullopt ;
172+
173+ // mapping checks.
174+ auto mapRange = genericOp.getIndexingMapsArray ();
175+ if (mapRange.size () != 2 || !mapRange.back ().isIdentity () ||
176+ !mapRange.front ().isPermutation ())
177+ return std::nullopt ;
178+
179+ SmallVector<int64_t > permutation;
180+ auto map = mapRange.front ();
181+ for (unsigned i = 0 ; i < map.getNumResults (); ++i) {
182+ auto expr = llvm::cast<AffineDimExpr>(map.getResults ()[i]);
183+ permutation.push_back (expr.getPosition ());
184+ }
185+ return permutation;
100186}
101187
102188// ===----------------------------------------------------------------------===//
@@ -106,8 +192,7 @@ static bool
106192isaElemwiseSingleUnaryOrBinaryOpInterface (linalg::GenericOp genericOp,
107193 unsigned arity) {
108194 // Check all loops are parallel.
109- if (genericOp.getNumParallelLoops () != genericOp.getNumLoops () ||
110- genericOp.getNumLoops () < 1 )
195+ if (!isAllParallel (genericOp) || genericOp.getNumLoops () < 1 )
111196 return false ;
112197
113198 // Check there are arity-inputs, 1-output and all are identity-maps.
0 commit comments