Skip to content

Commit b6c4d4d

Browse files
committed
add fastmath propagation
1 parent 22e02b3 commit b6c4d4d

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,11 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
650650
rewriter.getMultiDimIdentityMap(rank));
651651
SmallVector<utils::IteratorType> iteratorTypes(rank,
652652
utils::IteratorType::parallel);
653+
654+
// Check 'fast-math'. If present, propagate it.
655+
auto fmfOpInterface =
656+
llvm::dyn_cast<arith::ArithFastMathInterface>(op.getOperation());
657+
653658
auto genericOp = linalg::GenericOp::create(
654659
rewriter, loc, tensorType,
655660
op->getOperands(), // inputs
@@ -658,12 +663,32 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
658663
[&](OpBuilder &builder, Location loc, ValueRange args) {
659664
Value res;
660665
if (args.size() == 2) {
661-
res =
662-
builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
663-
.getResult();
666+
if (fmfOpInterface) {
667+
auto attr = fmfOpInterface.getFastMathFlagsAttr();
668+
auto fmf = rewriter.getNamedAttr("fastmath", attr);
669+
res = builder
670+
.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]},
671+
fmf)
672+
.getResult();
673+
} else {
674+
res = builder
675+
.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
676+
.getResult();
677+
}
664678
} else if (args.size() == 3) {
665-
res = builder.create<OpTy>(loc, args[2].getType(),
666-
ValueRange{args[0], args[1]});
679+
if (fmfOpInterface) {
680+
auto attr = fmfOpInterface.getFastMathFlagsAttr();
681+
auto fmf = rewriter.getNamedAttr("fastmath", attr);
682+
res = builder
683+
.create<OpTy>(loc, args[2].getType(),
684+
ValueRange{args[0], args[1]}, fmf)
685+
.getResult();
686+
} else {
687+
res = builder
688+
.create<OpTy>(loc, args[2].getType(),
689+
ValueRange{args[0], args[1]})
690+
.getResult();
691+
}
667692
} else
668693
llvm_unreachable("did not expect ops other than nary and binary");
669694
linalg::YieldOp::create(builder, loc, res);

mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,35 @@ module attributes {transform.with_named_sequence} {
313313
transform.yield
314314
}
315315
}
316+
317+
// -----
318+
319+
// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
320+
// CHECK-LABEL: func @arith_binop_fastmath(
321+
// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
322+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
323+
// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
324+
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
325+
// CHECK: %[[GENERIC:.+]] = linalg.generic
326+
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
327+
// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
328+
// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
329+
// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] fastmath<fast> : f32
330+
// CHECK: linalg.yield %[[z]] : f32
331+
// CHECK: return %[[GENERIC]] : tensor<?xf32>
332+
333+
func.func @arith_binop_fastmath(%x : tensor<?xf32>, %y : tensor<?xf32>)
334+
-> tensor<?xf32> {
335+
%z = arith.addf %x, %y fastmath<fast> : tensor<?xf32>
336+
return %z : tensor<?xf32>
337+
}
338+
339+
module attributes {transform.with_named_sequence} {
340+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
341+
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
342+
: (!transform.any_op) -> !transform.any_op
343+
transform.structured.rewrite_in_destination_passing_style %0
344+
: (!transform.any_op) -> !transform.any_op
345+
transform.yield
346+
}
347+
}

0 commit comments

Comments
 (0)