|
20 | 20 | #include "mlir/IR/Builders.h" |
21 | 21 | #include "mlir/IR/BuiltinAttributes.h" |
22 | 22 | #include "mlir/IR/BuiltinTypes.h" |
| 23 | +#include <cassert> |
23 | 24 | #include <utility> |
24 | 25 |
|
25 | 26 | namespace mlir { |
@@ -124,6 +125,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { |
124 | 125 | auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType()); |
125 | 126 | auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType()); |
126 | 127 |
|
| 128 | + if (!convOp.hasPureTensorSemantics()) |
| 129 | + return rewriter.notifyMatchFailure( |
| 130 | + convOp, "expected op to have pure tensor semantics"); |
| 131 | + |
127 | 132 | if (!filterType.hasStaticShape()) |
128 | 133 | return rewriter.notifyMatchFailure( |
129 | 134 | convOp, "expected a static shape for the filter"); |
@@ -155,10 +160,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { |
155 | 160 |
|
156 | 161 | Location loc = convOp.getLoc(); |
157 | 162 |
|
| 163 | + assert(isa<RankedTensorType>(filterType) && |
| 164 | + "expected filter type to be a ranked tensor"); |
| 165 | + auto tensorFilterType = cast<RankedTensorType>(filterType); |
| 166 | + |
158 | 167 | // Reshape output and filter to the LHS and result of a (B)MNK matmul. |
159 | 168 | SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}}; |
160 | 169 | auto reshapedFilterType = |
161 | | - RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType()); |
| 170 | + RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(), |
| 171 | + tensorFilterType.getEncoding()); |
162 | 172 | Value reshapedFilter = tensor::CollapseShapeOp::create( |
163 | 173 | rewriter, loc, reshapedFilterType, filter, filterReassocIndices); |
164 | 174 |
|
@@ -253,6 +263,10 @@ rewriteInIm2Col(RewriterBase &rewriter, |
253 | 263 | auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType()); |
254 | 264 | auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType()); |
255 | 265 |
|
| 266 | + if (!convOp.hasPureTensorSemantics()) |
| 267 | + return rewriter.notifyMatchFailure( |
| 268 | + convOp, "expected op to have pure tensor semantics"); |
| 269 | + |
256 | 270 | if (!filterType.hasStaticShape()) |
257 | 271 | return rewriter.notifyMatchFailure( |
258 | 272 | convOp, "expected a static shape for the filter"); |
@@ -404,6 +418,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { |
404 | 418 | auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType()); |
405 | 419 | auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType()); |
406 | 420 |
|
| 421 | + if (!convOp.hasPureTensorSemantics()) |
| 422 | + return rewriter.notifyMatchFailure( |
| 423 | + convOp, "expected op to have pure tensor semantics"); |
| 424 | + |
407 | 425 | if (!filterType.hasStaticShape()) |
408 | 426 | return rewriter.notifyMatchFailure( |
409 | 427 | convOp, "expected a static shape for the filter"); |
@@ -435,9 +453,14 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { |
435 | 453 | auto loc = convOp.getLoc(); |
436 | 454 | MLIRContext *context = rewriter.getContext(); |
437 | 455 |
|
| 456 | + assert(isa<RankedTensorType>(filterType) && |
| 457 | + "expected filter type to be a ranked tensor"); |
| 458 | + auto tensorFilterType = cast<RankedTensorType>(filterType); |
| 459 | + |
438 | 460 | SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}}; |
439 | 461 | auto reshapedFilterType = |
440 | | - RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType()); |
| 462 | + RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(), |
| 463 | + tensorFilterType.getEncoding()); |
441 | 464 | Value reshapedFilter = tensor::CollapseShapeOp::create( |
442 | 465 | rewriter, loc, reshapedFilterType, filter, filterReassocIndices); |
443 | 466 |
|
@@ -529,6 +552,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { |
529 | 552 | auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType()); |
530 | 553 | auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType()); |
531 | 554 |
|
| 555 | + if (!convOp.hasPureTensorSemantics()) |
| 556 | + return rewriter.notifyMatchFailure( |
| 557 | + convOp, "expected op to have pure tensor semantics"); |
| 558 | + |
532 | 559 | if (!filterType.hasStaticShape()) |
533 | 560 | return rewriter.notifyMatchFailure( |
534 | 561 | convOp, "expected a static shape for the filter"); |
@@ -560,11 +587,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { |
560 | 587 |
|
561 | 588 | Location loc = convOp.getLoc(); |
562 | 589 |
|
| 590 | + assert(isa<RankedTensorType>(filterType) && |
| 591 | + "expected filter type to be a ranked tensor"); |
| 592 | + auto tensorFilterType = cast<RankedTensorType>(filterType); |
| 593 | + |
563 | 594 | // Reshape output and filter to the LHS and result of a "row-wise" matrix |
564 | 595 | // multiplication. |
565 | 596 | SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}}; |
566 | 597 | auto reshapedFilterType = |
567 | | - RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType()); |
| 598 | + RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(), |
| 599 | + tensorFilterType.getEncoding()); |
568 | 600 | Value reshapedFilter = tensor::CollapseShapeOp::create( |
569 | 601 | rewriter, loc, reshapedFilterType, filter, filterReassocIndices); |
570 | 602 |
|
|
0 commit comments