Commit 0555f42

[Linalg] Add *Conv1D* matchers (#168050)
-- This commit is the second in the series adding matchers for linalg.*conv*/*pool*. Refer: #163724
-- In this commit, all variants of Conv1D convolution ops have been added.
-- For the sake of completing the infrastructure needed by ops that take no dilations/strides information at creation time, this commit also includes basic Conv2D and Conv3D ops as part of the lit test.

Signed-off-by: Abhishek Varma <[email protected]>
1 parent 0602678 commit 0555f42
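
For a quick sense of how the matchers added in this commit are meant to be used, the sketch below probes an op for the plain Conv1D form using the isaConvolutionOpOfType signature introduced in Utils.cpp. It is only an illustration: the call site, the `op` value, and what is done with the result are hypothetical and not part of this commit.

// Minimal caller sketch (hypothetical call site; `op` is assumed to be a
// LinalgOp that already implements ConvolutionOpInterface, since the matcher
// asserts this for non-named ops).
SmallVector<int64_t> dilations, strides;
if (isaConvolutionOpOfType<linalg::Conv1DOp>(op, &dilations, &strides)) {
  // `op` has the indexing maps and body of a 1-D convolution; `dilations[0]`
  // and `strides[0]` hold the constants recovered from the input map.
}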

3 files changed (+418, -7 lines)


mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 23 additions & 5 deletions
@@ -245,14 +245,22 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
                    ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
   SmallVector<Value> inputs = genericOp.getDpsInputs();
   ValueRange outputs = genericOp.getDpsInits();
-  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
   SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
                                       ? TypeRange(ValueRange(outputs))
                                       : TypeRange{};
-  Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
-  Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
-  LinalgOp namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
-      genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  LinalgOp namedOp;
+  // Ops with no dilations and no strides.
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
   return namedOp;
 }

@@ -265,9 +273,19 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
     return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations,      \
                                         strides);                            \
   // -----------------------------
+  // Convolution ops.
+  // -----------------------------
+  CONV_OP_SPECIALIZER(linalg::Conv1DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
+  CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv3DOp);
+  // -----------------------------
   // Depthwise Convolution ops.
   // -----------------------------
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
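
For context on the Specialize.cpp change above: the `if constexpr` branch exists because the plain Conv1DOp/Conv2DOp/Conv3DOp builders take no strides/dilations attributes. The macro body is only partially visible in this hunk, so the following expansion of one CONV_OP_SPECIALIZER entry is a sketch of the presumed behaviour, not the literal macro output:

// Hypothetical expansion of CONV_OP_SPECIALIZER(linalg::Conv1DOp): dispatch
// to the matcher and, on success, rewrite the generic op into the named op.
if (isaConvolutionOpOfType<linalg::Conv1DOp>(
        cast<LinalgOp>(genericOp.getOperation()), &dilations, &strides))
  return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, dilations,
                                              strides);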

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 299 additions & 1 deletion
@@ -390,7 +390,7 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
   unsigned inputMapIdx = 0, filterMapIdx = 1,
            outputMapIdx = indexingMaps.size() - 1;
   AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
-  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
   if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
     return false;

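The dyn_cast to dyn_cast_or_null switch in this hunk presumably guards against getAffineMapDim handing back a null AffineExpr (for example, when the map has no result at the requested dimension); dyn_cast asserts on a null input, while the null-tolerant form returns null and lets the existing `!addExpr` check reject the op. A small sketch of the guarded pattern, with the helper name taken from the hunk above and the null case an assumption:

// `inpExpr` may be null if the map has no result for `iDim` (assumption).
AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
// Returns null instead of asserting when `inpExpr` is null.
auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
  return false;
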
@@ -434,6 +434,263 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
   })));
 }
 
+// #inputMap = affine_map<(W, w) -> (W + w)>
+// #filterMap = affine_map<(W, w) -> (w)>
+// #outputMap = affine_map<(W, w) -> (W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
+                                              SmallVector<int64_t> *dilations,
+                                              SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv1DOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr W = getAffineDimExpr(0, context);
+  AffineExpr w = getAffineDimExpr(1, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{W * (*strides)[0] + w * (*dilations)[0]},
+           /*filterMap=*/{w},
+           /*outputMap=*/{W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
+// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)>
+// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv1DNwcWcfOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr F = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  AffineExpr c = getAffineDimExpr(4, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], c},
+           /*filterMap=*/{w, c, F},
+           /*outputMap=*/{N, W, F}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
+// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)>
+// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv1DNcwFcwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr w = getAffineDimExpr(4, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, c, W * (*strides)[0] + w * (*dilations)[0]},
+           /*filterMap=*/{F, c, w},
+           /*outputMap=*/{N, F, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)>
+// #filterMap = affine_map<(H, W, h, w) -> (h, w)>
+// #outputMap = affine_map<(H, W, h, w) -> (H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
+                                              SmallVector<int64_t> *dilations,
+                                              SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr H = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr h = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
+// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
+// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
+                                              SmallVector<int64_t> *dilations,
+                                              SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv3DOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(3, 1);
+  *strides = SmallVector<int64_t>(3, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr D = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr d = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: D * stride + d * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
+                         H * (*strides)[1] + h * (*dilations)[1],
+                         W * (*strides)[2] + w * (*dilations)[2]},
+           /*filterMap=*/{d, h, w},
+           /*outputMap=*/{D, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
+// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
+           /*filterMap=*/{C, w},
+           /*outputMap=*/{N, C, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
 // #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
 // #filterMap = affine_map<(N, W, C, w) -> (w, C)>
 // #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
@@ -474,6 +731,47 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
+// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
+// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr CM = getAffineDimExpr(3, context);
+  AffineExpr w = getAffineDimExpr(4, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
+           /*filterMap=*/{w, C, CM},
+           /*outputMap=*/{N, W, C, CM}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
 // #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
 // #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
