Commit 8a34c21

[AMD] Accumulate into offset instead of pointer for jit specialized tensor (#7939)
We canonicalize pointers that point to a small tensor in a different way. For example, given input like this:

    %p1 = tt.addptr %p0, %ofst
    ...
    %p2 = tt.addptr %p1, %ofst2

it will be canonicalized into the following:

    %p1 = tt.addptr %p0, %ofst
    ...
    %p2 = tt.addptr %p0, (%ofst2 + %ofst)

The rationale is two-fold: 1) to fix bugs like the one reported in issue-830, and 2) to expose this information to the buffer-op optimization: the buffer-op pass can easily spot a global memory op whose base pointer points to a small tensor and hence can safely convert it into a buffer op.
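For illustration, here is a minimal, hypothetical MLIR sketch of the rewrite on a scalar pointer chain (the function and value names are made up for this example; the pass also handles the tensor-of-pointer forms exercised by the tests below):

    // Hypothetical input: %p2's base is %p1, so a later pass cannot
    // directly tell that %p2 still addresses %p0's small allocation.
    tt.func @chain_before(%p0: !tt.ptr<f32>, %ofst: i32, %ofst2: i32) -> f32 {
      %p1 = tt.addptr %p0, %ofst : !tt.ptr<f32>, i32
      %p2 = tt.addptr %p1, %ofst2 : !tt.ptr<f32>, i32
      %v = tt.load %p2 : !tt.ptr<f32>
      tt.return %v : f32
    }

    // After canonicalization: the offsets accumulate in integer arithmetic
    // and every tt.addptr keeps the original base %p0, so a memory op's
    // base pointer is visibly the small-tensor pointer.
    tt.func @chain_after(%p0: !tt.ptr<f32>, %ofst: i32, %ofst2: i32) -> f32 {
      %acc = arith.addi %ofst2, %ofst : i32
      %p2 = tt.addptr %p0, %acc : !tt.ptr<f32>, i32
      %v = tt.load %p2 : !tt.ptr<f32>
      tt.return %v : f32
    }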
1 parent 6ec5e0c commit 8a34c21

3 files changed: +552, −239 lines

test/TritonGPU/amd/amd-canonicalize-pointers-dont-run-mlir-canonicalizer.mlir

Lines changed: 85 additions & 106 deletions
@@ -24,41 +24,33 @@ module attributes {"ttg.num-warps" = 4 : i32} {
   }
 }

-// CHECK-LABEL: tt.func @ifOpTwoYields(
-// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: i1) -> (tensor<1024xf32>, tensor<1024xf32>) {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1024 : i32
-// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
-// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32
-// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
-// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : i32 -> tensor<1024xi32>
-// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<1024xi32>
-// CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_3]] : i64 -> tensor<1024xi64>
-// CHECK: %[[VAL_11:.*]]:4 = scf.if %[[VAL_2]] -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>) {
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant dense<0> : tensor<1024xi32>
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_6]], %[[VAL_13]] : i32
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_7]] : tensor<1024xi32>
-// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_0]], %[[VAL_14]] : !tt.ptr<f32>, i32
-// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_15]] : tensor<1024xi32> to tensor<1024xi64>
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_10]] : tensor<1024xi64>
-// CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_16]], %[[VAL_18]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>
-// CHECK: } else {
-// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<f32>, i32
-// CHECK: scf.yield %[[VAL_19]], %[[VAL_10]], %[[VAL_19]], %[[VAL_10]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>
-// CHECK: }
-// CHECK: %[[VAL_20:.*]] = arith.trunci %[[VAL_21:.*]]#1 : tensor<1024xi64> to tensor<1024xi32>
-// CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_21]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_22]], %[[VAL_20]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_23]] : tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_25:.*]] = arith.trunci %[[VAL_21]]#3 : tensor<1024xi64> to tensor<1024xi32>
-// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_21]]#2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr<f32>>
-// CHECK: tt.return %[[VAL_24]], %[[VAL_28]] : tensor<1024xf32>, tensor<1024xf32>
-// CHECK: }
+// CHECK-LABEL: tt.func @ifOpTwoYields(
+// CHECK-SAME: %arg0: !tt.ptr<f32>,
+// CHECK-SAME: %arg1: tensor<1024xf32>,
+// CHECK-SAME: %arg2: i1) -> (tensor<1024xf32>, tensor<1024xf32>) {
+// CHECK: %[[const0:.*]] = arith.constant 0 : i64
+// CHECK: %[[C1024:.*]] = arith.constant 1024 : i32
+// CHECK: %[[PID:.*]] = tt.get_program_id x : i32
+// CHECK: %[[PID_time_1024:.*]] = arith.muli %[[PID]], %[[C1024]] : i32
+// CHECK: %[[MAKE_RANGE_1024:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+// CHECK: %[[CONST_ZERO_SPLAT:.*]] = tt.splat %[[const0]] : i64 -> tensor<1024xi64>
+// CHECK: %[[SCF:.*]]:4 = scf.if %arg2 -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>) {
+// CHECK: %[[ADDPTR1:.*]] = tt.addptr %arg0, %[[PID_time_1024]] : !tt.ptr<f32>, i32
+// CHECK: %[[EXT_RANGE:.*]] = arith.extsi %[[MAKE_RANGE_1024]] : tensor<1024xi32> to tensor<1024xi64>
+// CHECK: scf.yield %[[ADDPTR1]], %[[EXT_RANGE]], %[[ADDPTR1]], %[[EXT_RANGE]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>
+// } else {
+// CHECK: %[[ADDPTR2:.*]] = tt.addptr %arg0, %[[PID_time_1024]] : !tt.ptr<f32>, i32
+// CHECK: scf.yield %[[ADDPTR2]], %[[CONST_ZERO_SPLAT]], %[[ADDPTR2]], %[[CONST_ZERO_SPLAT]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>
+// }
+// CHECK: %[[dont_care_5:.*]] = arith.trunci %[[SCF]]#1 : tensor<1024xi64> to tensor<1024xi32>
+// CHECK: %[[dont_care_6:.*]] = tt.splat %[[SCF]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_7:.*]] = tt.addptr %[[dont_care_6]], %[[dont_care_5]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK: %[[dont_care_8:.*]] = tt.load %[[dont_care_7]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_9:.*]] = arith.trunci %[[SCF]]#3 : tensor<1024xi64> to tensor<1024xi32>
+// CHECK: %[[dont_care_10:.*]] = tt.splat %[[SCF]]#2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_11:.*]] = tt.addptr %[[dont_care_10]], %[[dont_care_9]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK: %[[dont_care_12:.*]] = tt.load %[[dont_care_11]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: tt.return %[[dont_care_8]], %[[dont_care_12]] : tensor<1024xf32>, tensor<1024xf32>

 // -----

@@ -76,7 +68,8 @@ module attributes {"ttg.num-warps" = 4 : i32} {
     scf.yield %8, %8, %0 : tensor<1024x!tt.ptr<f32>>, tensor<1024x!tt.ptr<f32>>, i32
   } else {
     %8 = tt.addptr %5, %3 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-    scf.yield %8, %8, %0 : tensor<1024x!tt.ptr<f32>>, tensor<1024x!tt.ptr<f32>>, i32
+    %9 = arith.muli %1, %1 : i32
+    scf.yield %8, %8, %9 : tensor<1024x!tt.ptr<f32>>, tensor<1024x!tt.ptr<f32>>, i32
   }
   %7 = tt.load %6#0 : tensor<1024x!tt.ptr<f32>>
   %8 = tt.load %6#1 : tensor<1024x!tt.ptr<f32>>
@@ -85,41 +78,33 @@ module attributes {"ttg.num-warps" = 4 : i32} {
 }

 // CHECK-LABEL: tt.func @ifOpTwoYieldsAndNonPtr(
-// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: i1) -> (tensor<1024xf32>, tensor<1024xf32>, i32) {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1024 : i32
-// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
-// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32
-// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
-// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : i32 -> tensor<1024xi32>
-// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<1024xi32>
-// CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_3]] : i64 -> tensor<1024xi64>
-// CHECK: %[[VAL_11:.*]]:5 = scf.if %[[VAL_2]] -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32) {
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant dense<0> : tensor<1024xi32>
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_6]], %[[VAL_13]] : i32
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_7]] : tensor<1024xi32>
-// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_0]], %[[VAL_14]] : !tt.ptr<f32>, i32
-// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_15]] : tensor<1024xi32> to tensor<1024xi64>
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_10]] : tensor<1024xi64>
-// CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_16]], %[[VAL_18]], %[[VAL_5]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32
-// CHECK: } else {
-// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<f32>, i32
-// CHECK: scf.yield %[[VAL_19]], %[[VAL_10]], %[[VAL_19]], %[[VAL_10]], %[[VAL_5]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32
-// CHECK: }
-// CHECK: %[[VAL_20:.*]] = arith.trunci %[[VAL_21:.*]]#1 : tensor<1024xi64> to tensor<1024xi32>
-// CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_21]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_22]], %[[VAL_20]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_23]] : tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_25:.*]] = arith.trunci %[[VAL_21]]#3 : tensor<1024xi64> to tensor<1024xi32>
-// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_21]]#2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr<f32>>
-// CHECK: tt.return %[[VAL_24]], %[[VAL_28]], %[[VAL_21]]#4 : tensor<1024xf32>, tensor<1024xf32>, i32
-// CHECK: }
-
+// CHECK-SAME: %arg0: !tt.ptr<f32>,
+// CHECK-SAME: %arg1: tensor<1024xf32>,
+// CHECK-SAME: %arg2: i1) -> (tensor<1024xf32>, tensor<1024xf32>, i32) {
+// CHECK-DAG: %c0_i64 = arith.constant 0 : i64
+// CHECK: %[[C1024:.*]] = arith.constant 1024 : i32
+// CHECK: %[[PID:.*]] = tt.get_program_id x : i32
+// CHECK: %[[PID_TIME_1024:.*]] = arith.muli %[[PID]], %[[C1024]] : i32
+// CHECK: %[[MK_RANGE:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+// CHECK: %[[CONST0_SPLAT:.*]] = tt.splat %c0_i64 : i64 -> tensor<1024xi64>
+// CHECK: %[[SCF_IF:.*]]:5 = scf.if %arg2 -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32) {
+// CHECK: %[[PTR_BASE_0:.*]] = tt.addptr %arg0, %[[PID_TIME_1024]] : !tt.ptr<f32>, i32
+// CHECK: %[[EXT_MK_RANGE:.*]] = arith.extsi %[[MK_RANGE]] : tensor<1024xi32> to tensor<1024xi64>
+// CHECK: scf.yield %[[PTR_BASE_0]], %[[EXT_MK_RANGE]], %[[PTR_BASE_0]], %[[EXT_MK_RANGE]], %[[PID]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32
+// } else {
+// CHECK: %[[BASE_PTR_1:.*]] = tt.addptr %arg0, %[[PID_TIME_1024]] : !tt.ptr<f32>, i32
+// CHECK: %[[OFST_2:.*]] = arith.muli %[[PID_TIME_1024]], %[[PID_TIME_1024]] : i32
+// scf.yield %[[BASE_PTR_1]], %[[CONST0_SPLAT]], %[[BASE_PTR_1]], %[[CONST0_SPLAT]], %[[OFST_2]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32
+// }
+// CHECK: %[[dont_care_5:.*]] = arith.trunci %[[SCF_IF]]#1 : tensor<1024xi64> to tensor<1024xi32>
+// CHECK: %[[dont_care_6:.*]] = tt.splat %[[SCF_IF]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_7:.*]] = tt.addptr %[[dont_care_6]], %[[dont_care_5]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK: %[[dont_care_8:.*]] = tt.load %[[dont_care_7]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_9:.*]] = arith.trunci %[[SCF_IF]]#3 : tensor<1024xi64> to tensor<1024xi32>
+// CHECK: %[[dont_care_10:.*]] = tt.splat %[[SCF_IF]]#2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_11:.*]] = tt.addptr %[[dont_care_10]], %[[dont_care_9]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK: %[[dont_care_12:.*]] = tt.load %[[dont_care_11]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: tt.return %[[dont_care_8]], %[[dont_care_12]], %[[SCF_IF]]#4 : tensor<1024xf32>, tensor<1024xf32>, i32

 // -----

@@ -137,7 +122,8 @@ module attributes {"ttg.num-warps" = 4 : i32} {
     scf.yield %8, %0, %8 : tensor<1024x!tt.ptr<f32>>, i32, tensor<1024x!tt.ptr<f32>>
   } else {
     %8 = tt.addptr %5, %3 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-    scf.yield %8, %0, %8 : tensor<1024x!tt.ptr<f32>>, i32, tensor<1024x!tt.ptr<f32>>
+    %9 = arith.muli %1, %1 : i32
+    scf.yield %8, %9, %8 : tensor<1024x!tt.ptr<f32>>, i32, tensor<1024x!tt.ptr<f32>>
   }
   %7 = tt.load %6#0 : tensor<1024x!tt.ptr<f32>>
   %8 = tt.load %6#2 : tensor<1024x!tt.ptr<f32>>
@@ -146,37 +132,30 @@ module attributes {"ttg.num-warps" = 4 : i32} {
 }

 // CHECK-LABEL: tt.func @ifOpTwoYieldsAndNonPtrReordered(
-// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: i1) -> (tensor<1024xf32>, tensor<1024xf32>, i32) {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1024 : i32
-// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
-// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32
-// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
-// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : i32 -> tensor<1024xi32>
-// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<1024xi32>
-// CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_3]] : i64 -> tensor<1024xi64>
-// CHECK: %[[VAL_11:.*]]:5 = scf.if %[[VAL_2]] -> (!tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>) {
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant dense<0> : tensor<1024xi32>
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_6]], %[[VAL_13]] : i32
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_7]] : tensor<1024xi32>
-// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_0]], %[[VAL_14]] : !tt.ptr<f32>, i32
-// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_15]] : tensor<1024xi32> to tensor<1024xi64>
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_10]] : tensor<1024xi64>
-// CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_5]], %[[VAL_16]], %[[VAL_18]] : !tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>
-// CHECK: } else {
-// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<f32>, i32
-// CHECK: scf.yield %[[VAL_19]], %[[VAL_10]], %[[VAL_5]], %[[VAL_19]], %[[VAL_10]] : !tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>
-// CHECK: }
-// CHECK: %[[VAL_20:.*]] = arith.trunci %[[VAL_21:.*]]#1 : tensor<1024xi64> to tensor<1024xi32>
-// CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_21]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_22]], %[[VAL_20]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_23]] : tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_25:.*]] = arith.trunci %[[VAL_21]]#4 : tensor<1024xi64> to tensor<1024xi32>
-// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_21]]#3 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr<f32>>
-// CHECK: tt.return %[[VAL_24]], %[[VAL_28]], %[[VAL_21]]#2 : tensor<1024xf32>, tensor<1024xf32>, i32
-// CHECK: }
+// CHECK-SAME: %arg0: !tt.ptr<f32>,
+// CHECK-SAME: %arg1: tensor<1024xf32>,
+// CHECK-SAME: %arg2: i1) -> (tensor<1024xf32>, tensor<1024xf32>, i32) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : i64
+// CHECK: %[[C1024:.*]] = arith.constant 1024 : i32
+// CHECK: %[[PID:.*]] = tt.get_program_id x : i32
+// CHECK: %[[PID_TIME_1024:.*]] = arith.muli %[[PID]], %[[C1024]] : i32
+// CHECK: %[[MK_RANGE_1024:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+// CHECK: %[[C0_SPLAT:.*]] = tt.splat %[[C0]] : i64 -> tensor<1024xi64>
+// CHECK: %[[SCF_IF:.*]]:5 = scf.if %arg2 -> (!tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>) {
+// CHECK: %[[PTR_BASE_1:.*]] = tt.addptr %arg0, %[[PID_TIME_1024]] : !tt.ptr<f32>, i32
+// CHECK: %[[EXT_MK_RANGE:.*]] = arith.extsi %[[MK_RANGE_1024]] : tensor<1024xi32> to tensor<1024xi64>
+// CHECK: scf.yield %[[PTR_BASE_1]], %[[EXT_MK_RANGE]], %[[PID]], %[[PTR_BASE_1]], %[[EXT_MK_RANGE]] : !tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>
+// } else {
+// CHECK: %[[PTR_BASE_2:.*]] = tt.addptr %arg0, %[[PID_TIME_1024]] : !tt.ptr<f32>, i32
+// CHECK: %[[EXT_MK_RANGE:.*]] = arith.muli %[[PID_TIME_1024]], %[[PID_TIME_1024]] : i32
+// CHECK: scf.yield %[[PTR_BASE_2]], %[[C0_SPLAT]], %[[EXT_MK_RANGE]], %[[PTR_BASE_2]], %[[C0_SPLAT]] : !tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>
+// }
+// CHECK: %[[dont_care_5:.*]] = arith.trunci %[[SCF_IF]]#1 : tensor<1024xi64> to tensor<1024xi32>
+// CHECK: %[[dont_care_6:.*]] = tt.splat %[[SCF_IF]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_7:.*]] = tt.addptr %[[dont_care_6]], %[[dont_care_5]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK: %[[dont_care_8:.*]] = tt.load %[[dont_care_7]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_9:.*]] = arith.trunci %[[SCF_IF]]#4 : tensor<1024xi64> to tensor<1024xi32>
+// CHECK: %[[dont_care_10:.*]] = tt.splat %[[SCF_IF]]#3 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[dont_care_11:.*]] = tt.addptr %[[dont_care_10]], %[[dont_care_9]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK: %[[dont_care_12:.*]] = tt.load %[[dont_care_11]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: tt.return %[[dont_care_8]], %[[dont_care_12]], %[[SCF_IF]]#2 : tensor<1024xf32>, tensor<1024xf32>, i32

0 commit comments
