Skip to content

Commit a964bb4

Browse files
[WIP] Generic to named Conv op support
Signed-off-by: Abhishek Varma <[email protected]>
1 parent e9972de commit a964bb4

File tree

1 file changed

+158
-0
lines changed

1 file changed

+158
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,159 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238
}
239239

240+
static bool matchingIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
241+
ArrayRef<mlir::utils::IteratorType> expectedIteratorTypes) {
242+
if (iteratorTypes.size() != expectedIteratorTypes.size()) return false;
243+
for (auto [orig, expected] : llvm::zip_equal(iteratorTypes, expectedIteratorTypes)) {
244+
if (orig != expected) return false;
245+
}
246+
return true;
247+
}
248+
249+
static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
250+
uint32_t mapIndex, uint32_t dimIndex) {
251+
auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
252+
// uint32_t nResults = affineMap.getNumResults();
253+
// llvm::outs()<<affineMap<<"\n";
254+
// llvm::outs()<<"Total result = "<<affineMap.getNumResults()<<"\n";
255+
// llvm::outs()<<"N = "<<nResults<<", dimIndex = "<<dimIndex<<"\n";
256+
// llvm::outs().flush();
257+
return affineMap.getResult(dimIndex);
258+
}
259+
260+
static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
261+
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
262+
SmallVector<utils::IteratorType> expectedIteratorTypes = {
263+
utils::IteratorType::parallel, utils::IteratorType::reduction
264+
};
265+
266+
if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
267+
return "linalg.conv_1d";
268+
return "";
269+
}
270+
271+
static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
272+
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
273+
if (indexingMaps.size() != 3) return "";
274+
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
275+
// Conv 1D
276+
// depthwise_conv_1d_ncw_cw
277+
// depthwise_conv_1d_nwc_wc
278+
// ["parallel", "parallel", "parallel", "reduction"]
279+
SmallVector<utils::IteratorType> expectedIteratorTypes = {
280+
utils::IteratorType::parallel, utils::IteratorType::parallel,
281+
utils::IteratorType::parallel, utils::IteratorType::reduction
282+
};
283+
// inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
284+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
285+
if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
286+
if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
287+
return "linalg.depthwise_conv_1d_ncw_cw";
288+
else if (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))
289+
return "linalg.depthwise_conv_1d_nwc_wc";
290+
}
291+
292+
//
293+
expectedIteratorTypes[2] = utils::IteratorType::reduction;
294+
if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
295+
return "linalg.conv_2d";
296+
}
297+
return "";
298+
}
299+
300+
static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
301+
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
302+
if (indexingMaps.size() != 3) return "";
303+
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
304+
// "parallel", "parallel", "parallel", "reduction", "reduction"]
305+
SmallVector<utils::IteratorType> expectedIteratorTypes = {
306+
utils::IteratorType::parallel, utils::IteratorType::parallel,
307+
utils::IteratorType::parallel, utils::IteratorType::parallel,
308+
utils::IteratorType::reduction
309+
};
310+
if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
311+
return "linalg.depthwise_conv_1d_nwc_wcm";
312+
313+
expectedIteratorTypes[3] = utils::IteratorType::reduction;
314+
// inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
315+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
316+
if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
317+
if (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))
318+
return "linalg.conv_1d_nwc_wcf";
319+
else if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
320+
return "linalg.conv_1d_ncw_fcw";
321+
}
322+
return "";
323+
}
324+
325+
static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
326+
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
327+
SmallVector<utils::IteratorType> expectedIteratorTypes = {
328+
utils::IteratorType::parallel, utils::IteratorType::reduction
329+
};
330+
if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
331+
return "linalg.conv_1d";
332+
return "";
333+
}
334+
335+
static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
336+
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
337+
SmallVector<utils::IteratorType> expectedIteratorTypes = {
338+
utils::IteratorType::parallel, utils::IteratorType::reduction
339+
};
340+
if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
341+
return "linalg.conv_1d";
342+
return "";
343+
}
344+
345+
static std::string inferConvolutionKind(GenericOp genericOp) {
346+
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
347+
unsigned totalIterators = iteratorTypes.size();
348+
switch(totalIterators) {
349+
case 2:
350+
return inferBasedOnRank2ConvIteratorTypes(genericOp);
351+
case 4:
352+
return inferBasedOnRank4ConvIteratorTypes(genericOp);
353+
case 5:
354+
return inferBasedOnRank5ConvIteratorTypes(genericOp);
355+
case 7:
356+
return inferBasedOnRank7ConvIteratorTypes(genericOp);
357+
case 8:
358+
return inferBasedOnRank8ConvIteratorTypes(genericOp);
359+
}
360+
return "";
361+
}
362+
363+
// Converts linalg.generic to named linalg.*conv* where possible.
364+
static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
365+
GenericOp genericOp) {
366+
std::string convKind = inferConvolutionKind(genericOp);
367+
if (convKind == "") return failure();
368+
SmallVector<Value> inputs = genericOp.getDpsInputs();
369+
ValueRange outputs = genericOp.getDpsInits();
370+
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
371+
SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
372+
? TypeRange(ValueRange(outputs))
373+
: TypeRange{};
374+
LinalgOp namedOp;
375+
if (convKind == "linalg.conv_1d") {
376+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DOp>(genericOp, resultTypes, inputs, outputs);
377+
} else if (convKind == "linalg.conv_1d_nwc_wcf") {
378+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNwcWcfOp>(genericOp, resultTypes, inputs, outputs);
379+
} else if (convKind == "linalg.conv_1d_ncw_fcw") {
380+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNcwFcwOp>(genericOp, resultTypes, inputs, outputs);
381+
} else if (convKind == "linalg.depthwise_conv_1d_ncw_cw") {
382+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNcwCwOp>(genericOp, resultTypes, inputs, outputs);
383+
} else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") {
384+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(genericOp, resultTypes, inputs, outputs);
385+
} else if (convKind == "linalg.conv_2d") {
386+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DOp>(genericOp, resultTypes, inputs, outputs);
387+
}
388+
return namedOp;
389+
390+
return failure();
391+
}
392+
240393
} // namespace
241394

242395
//===----------------------------------------------------------------------===//
@@ -316,6 +469,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
316469
if (isaContractionOpInterface(genericOp)) {
317470
return specializeLinalgContractions(rewriter, genericOp);
318471
}
472+
473+
// Convolution - e.g. *conv*
474+
if (isaConvolutionOpInterface(genericOp)) {
475+
return specializeLinalgConvolutions(rewriter, genericOp);
476+
}
319477
return failure();
320478
}
321479

0 commit comments

Comments
 (0)