Skip to content

Commit 5e021b9

Browse files
sjw36liuyunqi20
authored andcommitted
[Backend] Copy attributes to new loop in RewriteTensorPointer (#4848)
This fixes tl.num_stages lost in translation.
1 parent 5ddb8c1 commit 5e021b9

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ class RewriteTensorPointerPass
413413
auto newForOp = builder.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
414414
op.getUpperBound(), op.getStep(),
415415
newIterOperands);
416+
newForOp->setAttrs(op->getAttrs());
416417

417418
// Create value mapping. Note that for tensor pointers, we use identity
418419
// mapping. It may refer to a value in the old loop, but we will rewrite it

test/Triton/rewrite-tensor-pointer.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ tt.func public @rewrite_for(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>) {
111111
%4 = arith.addf %arg3, %3 : tensor<128x32xf16>
112112
%5 = tt.advance %arg4, [%c32_i32, %c0_i32] : !tt.ptr<tensor<128x32xf16>>
113113
scf.yield %4, %5 : tensor<128x32xf16>, !tt.ptr<tensor<128x32xf16>>
114-
}
114+
} {tt.num_stages = 3 : i32}
115115
%2 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>>
116116
tt.store %2, %1#0 : tensor<128x32x!tt.ptr<f16>>
117117
tt.return
@@ -138,6 +138,7 @@ tt.func public @rewrite_for(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>) {
138138
// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64
139139
// CHECK: %[[ADDI1:.*]] = arith.addi %[[ARG5]], %[[EXTSI3]] : i64
140140
// CHECK: scf.yield %{{.*}}, %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64
141+
// CHECK: tt.num_stages = 3
141142

142143
// -----
143144
tt.func public @rewrite_if(%arg0: !tt.ptr<f16>, %arg1: i1) -> tensor<128x32xf16> {

0 commit comments

Comments
 (0)