@@ -50,66 +50,40 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
5050 return inversePermutation (concatAffineMaps (indexingMaps)) != AffineMap ();
5151}
5252
53- // Returns true if all loops of the linalgOp are parallel
54- static bool isAllParallel (LinalgOp op) {
55- return op.getNumParallelLoops () == op.getNumLoops ();
56- }
57-
58- // Returns true if and only if linalgOp takes one input and one init.
59- static bool isSingleInputOutput (LinalgOp op) {
60- return op.getNumDpsInputs () == 1 && op.getNumDpsInits () == 1 ;
61- }
62- // Returns true if genericOp body is just a yieldOp that yields
63- // input operand as result.
64- static bool isSingleYieldOp (GenericOp op) {
65- if (op.getNumDpsInputs () != 1 || op.getNumDpsInits () != 1 )
66- return false ;
67-
68- Block *body = op.getBody ();
69- if (body->getOperations ().size () != 1 )
70- return false ;
71-
72- auto yieldOp = dyn_cast<linalg::YieldOp>(body->back ());
73- if (!yieldOp || yieldOp.getNumOperands () != 1 ||
74- yieldOp->getOperand (0 ) != body->getArgument (0 ))
75- return false ;
76- return true ;
77- }
78-
7953// ===----------------------------------------------------------------------===//
8054// CopyOpInterface implementation
8155// ===----------------------------------------------------------------------===//
8256
83- bool linalg::isaCopyOpInterface (LinalgOp linalgOp ) {
84- // Structural and operands
85- if (!isAllParallel (linalgOp ) || !isSingleInputOutput (linalgOp ))
57+ bool linalg::isaCopyOpInterface (LinalgOp op ) {
58+ // Check all loops are parallel and linalgOp is single input and output.
59+ if (!op. isAllParallelLoops ( ) || !op. isSingleInputOutput ())
8660 return false ;
8761
88- auto mapRange = linalgOp .getIndexingMapsArray ();
62+ auto mapRange = op .getIndexingMapsArray ();
8963 if (mapRange.size () != 2 || !mapRange.front ().isIdentity () ||
9064 !mapRange.back ().isIdentity ()) {
9165 return false ;
9266 }
9367 // Region.
94- return llvm::hasSingleElement (linalgOp .getBlock ()->getOperations ());
68+ return llvm::hasSingleElement (op .getBlock ()->getOperations ());
9569}
9670
9771// ===----------------------------------------------------------------------===//
9872// FillOpInterface implementation
9973// ===----------------------------------------------------------------------===//
100- std::optional<Value> linalg::isaFillOpInterface (GenericOp genericOp ) {
74+ std::optional<Value> linalg::isaFillOpInterface (GenericOp op ) {
10175 // Structural.
102- if (!isAllParallel (genericOp ) || !isSingleInputOutput (genericOp ) ||
103- !isSingleYieldOp (genericOp ))
76+ if (!op. isAllParallelLoops ( ) || !op. isSingleInputOutput () ||
77+ !op. isSingleYieldOp ())
10478 return std::nullopt ;
10579
10680 // Input should be referenced and init should not.
107- if (!genericOp .payloadUsesValueFromOperand (genericOp .getDpsInputOperand (0 )) ||
108- genericOp .payloadUsesValueFromOperand (genericOp .getDpsInitOperand (0 )))
81+ if (!op .payloadUsesValueFromOperand (op .getDpsInputOperand (0 )) ||
82+ op .payloadUsesValueFromOperand (op .getDpsInitOperand (0 )))
10983 return std::nullopt ;
11084
111- OpOperand *value = genericOp .getDpsInputOperand (0 );
112- if (!genericOp .isScalar (value))
85+ OpOperand *value = op .getDpsInputOperand (0 );
86+ if (!op .isScalar (value))
11387 return std::nullopt ;
11488 return value->get ();
11589}
@@ -118,27 +92,30 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
11892// BroadcastOpInterface implementation
11993// ===----------------------------------------------------------------------===//
12094std::optional<SmallVector<int64_t >>
121- linalg::isaBroadcastOpInterface (GenericOp genericOp ) {
95+ linalg::isaBroadcastOpInterface (GenericOp op ) {
12296 // Structural.
123- if (!isAllParallel (genericOp ) || !isSingleInputOutput (genericOp ) ||
124- !isSingleYieldOp (genericOp ))
97+ if (!op. isAllParallelLoops ( ) || !op. isSingleInputOutput () ||
98+ !op. isSingleYieldOp ())
12599 return std::nullopt ;
126100
127- auto t0 = genericOp .getDpsInputOperand (0 )->get ().getType ();
128- auto t1 = genericOp .getDpsInitOperand (0 )->get ().getType ();
129- if (!isa<MemRefType, RankedTensorType>(t0 ) ||
130- !isa<MemRefType, RankedTensorType>(t1 ))
101+ auto srcTy = op .getDpsInputOperand (0 )->get ().getType ();
102+ auto dstTy = op .getDpsInitOperand (0 )->get ().getType ();
103+ if (!isa<MemRefType, RankedTensorType>(srcTy ) ||
104+ !isa<MemRefType, RankedTensorType>(dstTy ))
131105 return std::nullopt ;
132106
133- // Check output is identity map. Injective function could also be
134- // a permutation of indices and expressible in linalg.generic but
135- // is not expressible for named broadcast op.
136- auto dstMap = genericOp .getIndexingMapsArray ()[1 ];
107+ // Check output is identity map. Broadcast could additionally be
108+ // employing permutation of indices and that would be expressible
109+ // in linalg.generic but is not expressible for named broadcast op.
110+ auto dstMap = op .getIndexingMapsArray ()[1 ];
137111 if (!dstMap.isIdentity ())
138112 return std::nullopt ;
139113
140114 SmallVector<int64_t > position;
141- auto srcMap = genericOp.getIndexingMapsArray ()[0 ];
115+ auto srcMap = op.getIndexingMapsArray ()[0 ];
116+
117+ if (srcMap.getResults ().size () >= dstMap.getResults ().size ())
118+ return std::nullopt ;
142119
143120 // Check input map is monotonically increasing DimIds.
144121 for (unsigned i = 0 ; i < srcMap.getNumResults (); ++i) {
@@ -153,6 +130,7 @@ linalg::isaBroadcastOpInterface(GenericOp genericOp) {
153130
154131 SmallVector<int64_t > broadcastedDims;
155132 auto numDims = srcMap.getNumDims ();
133+ // This is quadratic but number of items is generally small.
156134 for (auto dim : llvm::seq<int64_t >(0 , numDims)) {
157135 if (!llvm::is_contained (position, dim))
158136 broadcastedDims.push_back (dim);
@@ -164,86 +142,92 @@ linalg::isaBroadcastOpInterface(GenericOp genericOp) {
164142// TranposeOpInterface implementation
165143// ===----------------------------------------------------------------------===//
166144std::optional<SmallVector<int64_t >>
167- linalg::isaTransposeOpInterface (GenericOp genericOp) {
168- // Structural.
169- if (!isAllParallel (genericOp) || !isSingleInputOutput (genericOp) ||
170- !isSingleYieldOp (genericOp))
145+ linalg::isaTransposeOpInterface (GenericOp op) {
146+ // To specialize as a transpose op, the genericOp must be
147+ // all parallel loops, single input, single output, and its body
148+ // should be just a yield op, yielding input as output as is (no compute).
149+ if (!op.isAllParallelLoops () || !op.isSingleInputOutput () ||
150+ !op.isSingleYieldOp ())
171151 return std::nullopt ;
172152
173- // mapping checks.
174- auto mapRange = genericOp.getIndexingMapsArray ();
175- if (mapRange.size () != 2 || !mapRange.back ().isIdentity () ||
176- !mapRange.front ().isPermutation ())
153+ auto mapRange = op.getIndexingMapsArray ();
154+ if (mapRange.size () != 2 )
177155 return std::nullopt ;
178156
179- SmallVector<int64_t > permutation;
180- auto map = mapRange.front ();
181- for (unsigned i = 0 ; i < map.getNumResults (); ++i) {
182- auto expr = llvm::cast<AffineDimExpr>(map.getResults ()[i]);
183- permutation.push_back (expr.getPosition ());
157+ auto mapOfInput = mapRange.front ();
158+ auto mapOfResult = mapRange.back ();
159+
160+ // linalg.transpose permutes the dimensions of input using this
161+ // rule: dim(result, i) = dim(input, permutation[i])
162+ if (!mapOfResult.isIdentity () || !mapOfInput.isPermutation ())
163+ return std::nullopt ;
164+
165+ SmallVector<int64_t > permutation (mapOfInput.getNumDims ());
166+ for (unsigned i = 0 ; i < mapOfInput.getNumDims (); ++i) {
167+ auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults ()[i]);
168+ permutation[expr.getPosition ()] = i;
184169 }
185170 return permutation;
186171}
187172
188173// ===----------------------------------------------------------------------===//
189174// Elementwise Single Unary/Binary-OpInterface implementation
190175// ===----------------------------------------------------------------------===//
191- static bool
192- isaElemwiseSingleUnaryOrBinaryOpInterface (linalg::GenericOp genericOp,
193- unsigned arity) {
176+ static bool isaElemwiseSingleUnaryOrBinaryOpInterface (linalg::GenericOp op,
177+ unsigned arity) {
194178 // Check all loops are parallel.
195- if (!isAllParallel (genericOp ) || genericOp .getNumLoops () < 1 )
179+ if (!op. isAllParallelLoops ( ) || op .getNumLoops () < 1 )
196180 return false ;
197181
198182 // Check there are arity-inputs, 1-output and all are identity-maps.
199- if (genericOp .getNumDpsInputs () != arity || genericOp .getNumDpsInits () != 1 ||
200- !llvm::all_of (genericOp .getIndexingMapsArray (),
183+ if (op .getNumDpsInputs () != arity || op .getNumDpsInits () != 1 ||
184+ !llvm::all_of (op .getIndexingMapsArray (),
201185 [](AffineMap map) { return map.isIdentity (); }))
202186 return false ;
203187
204188 // Init should not be referenced for elementwise operations.
205- if (genericOp .payloadUsesValueFromOperand (genericOp .getDpsInitOperand (0 )))
189+ if (op .payloadUsesValueFromOperand (op .getDpsInitOperand (0 )))
206190 return false ;
207191
208192 // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
209193 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
210194 // the body, where the first is the elementwise single op and the second a
211195 // yield.
212- Block *body = genericOp .getBody ();
196+ Block *body = op .getBody ();
213197 if (body->getOperations ().size () != 2 )
214198 return false ;
215199
216- Operation *op = &body->front ();
217- if (op ->getNumOperands () != arity || op ->getNumResults () != 1 )
200+ Operation *oper = &body->front ();
201+ if (oper ->getNumOperands () != arity || oper ->getNumResults () != 1 )
218202 return false ;
219203
220204 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back ());
221205 if (!yieldOp || yieldOp.getNumOperands () != 1 ||
222- yieldOp->getOperand (0 ).getDefiningOp () != op )
206+ yieldOp->getOperand (0 ).getDefiningOp () != oper )
223207 return false ;
224208 return true ;
225209}
226210
227- bool linalg::isaElemwiseSingleUnaryOpInterface (linalg::GenericOp genericOp ) {
211+ bool linalg::isaElemwiseSingleUnaryOpInterface (linalg::GenericOp op ) {
228212 // All basic elemwise checks.
229- if (!isaElemwiseSingleUnaryOrBinaryOpInterface (genericOp , 1 ))
213+ if (!isaElemwiseSingleUnaryOrBinaryOpInterface (op , 1 ))
230214 return false ;
231215
232216 // Check input is actully used.
233- if (!genericOp .payloadUsesValueFromOperand (genericOp .getDpsInputOperand (0 )))
217+ if (!op .payloadUsesValueFromOperand (op .getDpsInputOperand (0 )))
234218 return false ;
235219 return true ;
236220}
237221
238- bool linalg::isaElemwiseSingleBinaryOpInterface (linalg::GenericOp genericOp ) {
239- if (!isaElemwiseSingleUnaryOrBinaryOpInterface (genericOp , 2 ))
222+ bool linalg::isaElemwiseSingleBinaryOpInterface (linalg::GenericOp op ) {
223+ if (!isaElemwiseSingleUnaryOrBinaryOpInterface (op , 2 ))
240224 return false ;
241225
242226 // Check both inputs are used (elementwise).
243- OpOperand *inputOpOperand0 = genericOp .getDpsInputOperand (0 );
244- OpOperand *inputOpOperand1 = genericOp .getDpsInputOperand (1 );
245- if (!genericOp .payloadUsesValueFromOperand (inputOpOperand0) ||
246- !genericOp .payloadUsesValueFromOperand (inputOpOperand1))
227+ OpOperand *inputOpOperand0 = op .getDpsInputOperand (0 );
228+ OpOperand *inputOpOperand1 = op .getDpsInputOperand (1 );
229+ if (!op .payloadUsesValueFromOperand (inputOpOperand0) ||
230+ !op .payloadUsesValueFromOperand (inputOpOperand1))
247231 return false ;
248232 return true ;
249233}
0 commit comments