Skip to content

Commit d3281b5

Browse files
committed
fma workaround
1 parent a65f072 commit d3281b5

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,10 @@ class ExtractOpFromElementwise final
10661066
PatternRewriter &rewriter) const override {
10671067
Operation *eltwise = op.getVector().getDefiningOp();
10681068

1069-
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise))
1069+
// TODO: vector::FMAOp is not ElemetwiseMappable eve if it claims to be, as
1070+
// it doesn't support scalars.
1071+
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
1072+
isa<vector::FMAOp>(eltwise))
10701073
return rewriter.notifyMatchFailure(op, "not an elementwise op");
10711074

10721075
if (eltwise->getNumResults() != 1)

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,15 @@ func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
501501
%1 = vector.extract %0[1] : i64 from vector<4xi64>
502502
return %1 : i64
503503
}
504+
505+
// CHECK-LABEL: @negative_extract_vec_fma
506+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<4xf32>)
507+
func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> f32 {
508+
// `vector.fma` doesn't suppport scalars.
509+
// CHECK: %[[FMA:.*]] = vector.fma %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<4xf32>
510+
// CHECK: %[[RES:.*]] = vector.extract %[[FMA]][1] : f32 from vector<4xf32>
511+
// CHECK: return %[[RES]] : f32
512+
%0 = vector.fma %arg0, %arg1, %arg2: vector<4xf32>
513+
%1 = vector.extract %0[1] : f32 from vector<4xf32>
514+
return %1 : f32
515+
}

0 commit comments

Comments
 (0)