
Commit 0ce5d77

Restore maxNumImpreciseAcc guard for AddFOp (triton-lang#8056)
This restores the legacy guard from Combine.td: we only fold `addf(dot, bias)` into `dot(..., C=bias)` when `maxNumImpreciseAcc == 0`. Without this, `use_fast_accum=False` kernels were silently rewritten into the fast-accum form, causing accuracy drift in FP8 tests. This change ensures precise accumulation semantics are preserved while keeping the optimization enabled when imprecise accumulation is allowed.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because existing tests should cover it.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 6f06595 commit 0ce5d77
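
For orientation, here is a minimal before/after sketch of the rewrite this guard constrains, written in the same IR as the tests below. The operand names (`%a`, `%b`, `%bias`) and the shapes are illustrative assumptions, not taken from the change: the combine pass folds a zero-accumulator `tt.dot` followed by `arith.addf` into a single `tt.dot` that accumulates into the bias, and with this patch the fold is skipped whenever `maxNumImpreciseAcc` is nonzero.

```mlir
// Illustrative only: %a, %b, and %bias are assumed f32 tensors defined elsewhere.

// Before: a zero-initialized dot followed by a bias add.
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
%acc = tt.dot %a, %b, %zero {maxNumImpreciseAcc = 0 : i32}
    : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res = arith.addf %acc, %bias : tensor<128x128xf32>

// After the fold (applied only because maxNumImpreciseAcc == 0):
// the bias becomes the dot accumulator and the addf disappears.
%res = tt.dot %a, %b, %bias {maxNumImpreciseAcc = 0 : i32}
    : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
```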

File tree

2 files changed: +41 −0 lines


lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 5 additions & 0 deletions
@@ -252,6 +252,11 @@ class CombineDotAddPattern : public mlir::OpRewritePattern<OpTy> {
     }
     if (!isZero(dotOp.getC()))
       return failure();
+    if constexpr (std::is_same_v<OpTy, arith::AddFOp>) {
+      if (dotOp.getMaxNumImpreciseAcc() != 0) {
+        return failure();
+      }
+    }
     rewriter.modifyOpInPlace(dotOp, [&] {
       dotOp.getCMutable().assign(isDotLHS ? addOp.getRhs() : addOp.getLhs());
       dotOp->moveBefore(addOp);

test/Triton/combine.mlir

Lines changed: 36 additions & 0 deletions
@@ -413,3 +413,39 @@ tt.func @test_rank_reduce_desc_load(%0: !tt.tensordesc<tensor<1x128x64xf16>>) ->
     %r = tt.reshape %l : tensor<1x128x64xf16> -> tensor<128x64xf16>
     tt.return %r : tensor<128x64xf16>
 }
+
+// CHECK-LABEL: @test_combine_dot_add_no_fold_when_imprecise_allowed
+tt.func @test_combine_dot_add_no_fold_when_imprecise_allowed() -> (tensor<128x128xf32>) {
+    // CHECK-DAG: %[[D:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
+    %a = arith.constant dense<1.0> : tensor<128x128xf32>
+    %b = arith.constant dense<2.0> : tensor<128x128xf32>
+    %zero = arith.constant dense<0.0> : tensor<128x128xf32>
+    %d = arith.constant dense<3.0> : tensor<128x128xf32>
+
+    %dot_out = tt.dot %a, %b, %zero {maxNumImpreciseAcc = 1 : i32}
+        : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+
+    // CHECK: arith.addf %{{.*}}, %[[D]] : tensor<128x128xf32>
+    // CHECK-NEXT: tt.return %{{.*}} : tensor<128x128xf32>
+    %res = arith.addf %dot_out, %d : tensor<128x128xf32>
+    tt.return %res : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: @test_combine_dot_add_fold_when_precise_required
+tt.func @test_combine_dot_add_fold_when_precise_required() -> (tensor<128x128xf32>) {
+    // CHECK-DAG: %[[D:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
+    // CHECK-DAG: %[[B:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
+    // CHECK-DAG: %[[A:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
+    %a = arith.constant dense<1.0> : tensor<128x128xf32>
+    %b = arith.constant dense<2.0> : tensor<128x128xf32>
+    %zero = arith.constant dense<0.0> : tensor<128x128xf32>
+    %d = arith.constant dense<3.0> : tensor<128x128xf32>
+
+    %dot_out = tt.dot %a, %b, %zero {maxNumImpreciseAcc = 0 : i32}
+        : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+
+    // CHECK-NEXT: %[[RES:.*]] = tt.dot %[[A]], %[[B]], %[[D]] : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+    // CHECK-NEXT: tt.return %[[RES]] : tensor<128x128xf32>
+    %res = arith.addf %dot_out, %d : tensor<128x128xf32>
+    tt.return %res : tensor<128x128xf32>
+}
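
Usage note: these FileCheck cases are driven by combine.mlir's existing lit RUN line, which this diff does not show. The invocation sketched below is an assumption about what that line looks like, based on the standard `triton-opt` test driver; the exact pass flags may differ.

```mlir
// Assumed RUN line at the top of test/Triton/combine.mlir (not part of this diff):
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
```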
