Skip to content

Commit 9678002

Browse files
aartbikmemfrob
authored andcommitted
[mlir] [VectorOps] Correctly account for rank-0 affine-map result in vector.contract
Summary: Now that, thanks to ntv, we have the ability to parse and represent an affine map with rank-0 results, viz. (i,j) -> (), we can pay off some engineering debt in special casing the verification of such affine maps in dot-product flavored vector.contract operations. Reviewers: nicolasvasilache, andydavis1, rriddle Reviewed By: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D76028
1 parent c31c0ec commit 9678002

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

mlir/lib/Dialect/VectorOps/VectorOps.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,8 @@ static LogicalResult verify(ContractionOp op) {
302302
<< index << " to have no symbols";
303303
auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
304304
unsigned rank = vectorType ? vectorType.getShape().size() : 0;
305-
// Since (...) -> () is parsed into an empty map, we need to add
306-
// a special case for this situation: continue the verification
307-
// of an empty map if the resulting rank is indeed zero, i.e. this
308-
// is a reduction into a scalar.
309-
if (map.getNumDims() == 0 && map.getNumResults() == 0 && rank == 0)
310-
continue;
311305
// Verify that the map has the right number of inputs, outputs, and indices.
306+
// This also correctly accounts for (..) -> () for rank-0 results.
312307
if (map.getNumDims() != numIterators)
313308
return op.emitOpError("expected indexing map ")
314309
<< index << " to have " << numIterators << " number of inputs";

mlir/lib/Dialect/VectorOps/VectorTransforms.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,20 +1077,19 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
10771077
// Helper to construct an affine map with one index removed.
10781078
static AffineMap adjustMap(AffineMap map, int64_t index,
10791079
PatternRewriter &rewriter) {
1080+
auto *ctx = rewriter.getContext();
10801081
SmallVector<AffineExpr, 4> results;
10811082
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
10821083
int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
10831084
if (idx == index)
10841085
continue;
10851086
// Re-insert remaining indices, but renamed when occurring
10861087
// after the removed index.
1087-
auto targetExpr =
1088-
getAffineDimExpr(idx < index ? idx : idx - 1, rewriter.getContext());
1088+
auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
10891089
results.push_back(targetExpr);
10901090
}
1091-
// Since (...) -> () cannot be represented properly,
1092-
// we resort to an empty map when this situation happens.
1093-
return results.empty() ? AffineMap::get(rewriter.getContext())
1091+
// The (...) -> () affine map has its own factory method.
1092+
return results.empty() ? AffineMap::get(map.getNumDims() - 1, 0, ctx)
10941093
: AffineMap::get(map.getNumDims() - 1, 0, results);
10951094
}
10961095

0 commit comments

Comments
 (0)