Skip to content

Commit 384a5ca

Browse files
committed
add back assert
1 parent 7df6355 commit 384a5ca

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6171,12 +6171,13 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
61716171
}
61726172

61736173
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6174+
VectorType outputType = transpose.getResultVectorType();
61746175

61756176
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
61766177
bool inputIsScalar = !inputType;
61776178
if (inputIsScalar) {
6178-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6179-
transpose, transpose.getResultVectorType(), transpose.getVector());
6179+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6180+
transpose.getVector());
61806181
return success();
61816182
}
61826183

@@ -6210,11 +6211,13 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62106211
// the check (impossible to have just 1 non-locally bound group).
62116212

62126213
// The preceding logic also ensures that at this point, the output of the
6213-
// transpose is definitely broadcastable from the input shape, so we don't
6214-
// need to check vector::isBroadcastableTo now.
6214+
// transpose is definitely broadcastable from the input shape, assert so:
6215+
assert(vector::isBroadcastableTo(inputType, outputType) ==
6216+
vector::BroadcastableToResult::Success &&
6217+
"not broadcastable directly to transpose output");
62156218

6216-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6217-
transpose, transpose.getResultVectorType(), transpose.getVector());
6219+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6220+
transpose.getVector());
62186221

62196222
return success();
62206223
}

0 commit comments

Comments
 (0)