@@ -6171,12 +6171,13 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
61716171 }
61726172
61736173 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType ());
6174+ VectorType outputType = transpose.getResultVectorType ();
61746175
61756176 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
61766177 bool inputIsScalar = !inputType;
61776178 if (inputIsScalar) {
6178- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
6179- transpose, transpose. getResultVectorType (), transpose.getVector ());
6179+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(transpose, outputType,
6180+ transpose.getVector ());
61806181 return success ();
61816182 }
61826183
@@ -6210,11 +6211,13 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62106211 // the check (impossible to have just 1 non-locally bound group).
62116212
62126213 // The preceding logic also ensures that at this point, the output of the
6213- // transpose is definitely broadcastable from the input shape, so we don't
6214- // need to check vector::isBroadcastableTo now.
6214+ // transpose is definitely broadcastable from the input shape, assert so:
6215+ assert (vector::isBroadcastableTo (inputType, outputType) ==
6216+ vector::BroadcastableToResult::Success &&
6217+ " not broadcastable directly to transpose output" );
62156218
6216- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
6217- transpose, transpose. getResultVectorType (), transpose.getVector ());
6219+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(transpose, outputType,
6220+ transpose.getVector ());
62186221
62196222 return success ();
62206223 }
0 commit comments