@@ -24,41 +24,33 @@ module attributes {"ttg.num-warps" = 4 : i32} {
2424 }
2525}
2626
27- // CHECK-LABEL: tt.func @ifOpTwoYields(
28- // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
29- // CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>,
30- // CHECK-SAME: %[[VAL_2:.*]]: i1) -> (tensor<1024xf32>, tensor<1024xf32>) {
31- // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
32- // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1024 : i32
33- // CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
34- // CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32
35- // CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
36- // CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : i32 -> tensor<1024xi32>
37- // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<1024xi32>
38- // CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_3]] : i64 -> tensor<1024xi64>
39- // CHECK: %[[VAL_11:.*]]:4 = scf.if %[[VAL_2]] -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>) {
40- // CHECK-DAG: %[[VAL_12:.*]] = arith.constant dense<0> : tensor<1024xi32>
41- // CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0 : i32
42- // CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_6]], %[[VAL_13]] : i32
43- // CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_7]] : tensor<1024xi32>
44- // CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_0]], %[[VAL_14]] : !tt.ptr<f32>, i32
45- // CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_15]] : tensor<1024xi32> to tensor<1024xi64>
46- // CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_10]] : tensor<1024xi64>
47- // CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_16]], %[[VAL_18]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>
48- // CHECK: } else {
49- // CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<f32>, i32
50- // CHECK: scf.yield %[[VAL_19]], %[[VAL_10]], %[[VAL_19]], %[[VAL_10]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>
51- // CHECK: }
52- // CHECK: %[[VAL_20:.*]] = arith.trunci %[[VAL_21:.*]]#1 : tensor<1024xi64> to tensor<1024xi32>
53- // CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_21]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
54- // CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_22]], %[[VAL_20]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
55- // CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_23]] : tensor<1024x!tt.ptr<f32>>
56- // CHECK: %[[VAL_25:.*]] = arith.trunci %[[VAL_21]]#3 : tensor<1024xi64> to tensor<1024xi32>
57- // CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_21]]#2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
58- // CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
59- // CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr<f32>>
60- // CHECK: tt.return %[[VAL_24]], %[[VAL_28]] : tensor<1024xf32>, tensor<1024xf32>
61- // CHECK: }
27+ // CHECK-LABEL: tt.func @ifOpTwoYields(
28+ // CHECK-SAME: %arg0: !tt.ptr<f32>,
29+ // CHECK-SAME: %arg1: tensor<1024xf32>,
30+ // CHECK-SAME: %arg2: i1) -> (tensor<1024xf32>, tensor<1024xf32>) {
31+ // CHECK: %[[const0:.*]] = arith.constant 0 : i64
32+ // CHECK: %[[C1024:.*]] = arith.constant 1024 : i32
33+ // CHECK: %[[PID:.*]] = tt.get_program_id x : i32
34+ // CHECK: %[[PID_time_1024:.*]] = arith.muli %[[PID]], %[[C1024]] : i32
35+ // CHECK: %[[MAKE_RANGE_1024:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
36+ // CHECK: %[[CONST_ZERO_SPLAT:.*]] = tt.splat %[[const0]] : i64 -> tensor<1024xi64>
37+ // CHECK: %[[SCF:.*]]:4 = scf.if %arg2 -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>) {
38+ // CHECK: %[[ADDPTR1:.*]] = tt.addptr %arg0, %[[PID_time_1024]] : !tt.ptr<f32>, i32
39+ // CHECK: %[[EXT_RANGE:.*]] = arith.extsi %[[MAKE_RANGE_1024]] : tensor<1024xi32> to tensor<1024xi64>
40+ // CHECK: scf.yield %[[ADDPTR1]], %[[EXT_RANGE]], %[[ADDPTR1]], %[[EXT_RANGE]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>
41+ // } else {
42+ // CHECK: %[[ADDPTR2:.*]] = tt.addptr %arg0, %[[PID_time_1024]] : !tt.ptr<f32>, i32
43+ // CHECK: scf.yield %[[ADDPTR2]], %[[CONST_ZERO_SPLAT]], %[[ADDPTR2]], %[[CONST_ZERO_SPLAT]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>
44+ // }
45+ // CHECK: %[[dont_care_5:.*]] = arith.trunci %[[SCF]]#1 : tensor<1024xi64> to tensor<1024xi32>
46+ // CHECK: %[[dont_care_6:.*]] = tt.splat %[[SCF]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
47+ // CHECK: %[[dont_care_7:.*]] = tt.addptr %[[dont_care_6]], %[[dont_care_5]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
48+ // CHECK: %[[dont_care_8:.*]] = tt.load %[[dont_care_7]] : tensor<1024x!tt.ptr<f32>>
49+ // CHECK: %[[dont_care_9:.*]] = arith.trunci %[[SCF]]#3 : tensor<1024xi64> to tensor<1024xi32>
50+ // CHECK: %[[dont_care_10:.*]] = tt.splat %[[SCF]]#2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
51+ // CHECK: %[[dont_care_11:.*]] = tt.addptr %[[dont_care_10]], %[[dont_care_9]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
52+ // CHECK: %[[dont_care_12:.*]] = tt.load %[[dont_care_11]] : tensor<1024x!tt.ptr<f32>>
53+ // CHECK: tt.return %[[dont_care_8]], %[[dont_care_12]] : tensor<1024xf32>, tensor<1024xf32>
6254
6355// -----
6456
@@ -76,7 +68,8 @@ module attributes {"ttg.num-warps" = 4 : i32} {
7668 scf.yield %8 , %8 , %0 : tensor <1024 x!tt.ptr <f32 >>, tensor <1024 x!tt.ptr <f32 >>, i32
7769 } else {
7870 %8 = tt.addptr %5 , %3 : tensor <1024 x!tt.ptr <f32 >>, tensor <1024 xi32 >
79- scf.yield %8 , %8 , %0 : tensor <1024 x!tt.ptr <f32 >>, tensor <1024 x!tt.ptr <f32 >>, i32
71+ %9 = arith.muli %1 , %1 : i32
72+ scf.yield %8 , %8 , %9 : tensor <1024 x!tt.ptr <f32 >>, tensor <1024 x!tt.ptr <f32 >>, i32
8073 }
8174 %7 = tt.load %6#0 : tensor <1024 x!tt.ptr <f32 >>
8275 %8 = tt.load %6#1 : tensor <1024 x!tt.ptr <f32 >>
@@ -85,41 +78,33 @@ module attributes {"ttg.num-warps" = 4 : i32} {
8578}
8679
8780// CHECK-LABEL: tt.func @ifOpTwoYieldsAndNonPtr(
88- // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
89- // CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>,
90- // CHECK-SAME: %[[VAL_2:.*]]: i1) -> (tensor<1024xf32>, tensor<1024xf32>, i32) {
91- // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
92- // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1024 : i32
93- // CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
94- // CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32
95- // CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
96- // CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : i32 -> tensor<1024xi32>
97- // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<1024xi32>
98- // CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_3]] : i64 -> tensor<1024xi64>
99- // CHECK: %[[VAL_11:.*]]:5 = scf.if %[[VAL_2]] -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32) {
100- // CHECK-DAG: %[[VAL_12:.*]] = arith.constant dense<0> : tensor<1024xi32>
101- // CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0 : i32
102- // CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_6]], %[[VAL_13]] : i32
103- // CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_7]] : tensor<1024xi32>
104- // CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_0]], %[[VAL_14]] : !tt.ptr<f32>, i32
105- // CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_15]] : tensor<1024xi32> to tensor<1024xi64>
106- // CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_10]] : tensor<1024xi64>
107- // CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_16]], %[[VAL_18]], %[[VAL_5]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32
108- // CHECK: } else {
109- // CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<f32>, i32
110- // CHECK: scf.yield %[[VAL_19]], %[[VAL_10]], %[[VAL_19]], %[[VAL_10]], %[[VAL_5]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32
111- // CHECK: }
112- // CHECK: %[[VAL_20:.*]] = arith.trunci %[[VAL_21:.*]]#1 : tensor<1024xi64> to tensor<1024xi32>
113- // CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_21]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
114- // CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_22]], %[[VAL_20]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
115- // CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_23]] : tensor<1024x!tt.ptr<f32>>
116- // CHECK: %[[VAL_25:.*]] = arith.trunci %[[VAL_21]]#3 : tensor<1024xi64> to tensor<1024xi32>
117- // CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_21]]#2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
118- // CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
119- // CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr<f32>>
120- // CHECK: tt.return %[[VAL_24]], %[[VAL_28]], %[[VAL_21]]#4 : tensor<1024xf32>, tensor<1024xf32>, i32
121- // CHECK: }
122-
81+ // CHECK-SAME: %arg0: !tt.ptr<f32>,
82+ // CHECK-SAME: %arg1: tensor<1024xf32>,
83+ // CHECK-SAME: %arg2: i1) -> (tensor<1024xf32>, tensor<1024xf32>, i32) {
84+ // CHECK-DAG: %c0_i64 = arith.constant 0 : i64
85+ // CHECK: %[[C1024:.*]] = arith.constant 1024 : i32
86+ // CHECK: %[[PID:.*]] = tt.get_program_id x : i32
87+ // CHECK: %[[PID_TIME_1024:.*]] = arith.muli %[[PID]], %[[C1024]] : i32
88+ // CHECK: %[[MK_RANGE:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
89+ // CHECK: %[[CONST0_SPLAT:.*]] = tt.splat %c0_i64 : i64 -> tensor<1024xi64>
90+ // CHECK: %[[SCF_IF:.*]]:5 = scf.if %arg2 -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32) {
91+ // CHECK: %[[PTR_BASE_0:.*]] = tt.addptr %arg0, %[[PID_TIME_1024]] : !tt.ptr<f32>, i32
92+ // CHECK: %[[EXT_MK_RANGE:.*]] = arith.extsi %[[MK_RANGE]] : tensor<1024xi32> to tensor<1024xi64>
93+ // CHECK: scf.yield %[[PTR_BASE_0]], %[[EXT_MK_RANGE]], %[[PTR_BASE_0]], %[[EXT_MK_RANGE]], %[[PID]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32
94+ // } else {
95+ // CHECK: %[[BASE_PTR_1:.*]] = tt.addptr %arg0, %[[PID_TIME_1024]] : !tt.ptr<f32>, i32
96+ // CHECK: %[[OFST_2:.*]] = arith.muli %[[PID_TIME_1024]], %[[PID_TIME_1024]] : i32
97+ // scf.yield %[[BASE_PTR_1]], %[[CONST0_SPLAT]], %[[BASE_PTR_1]], %[[CONST0_SPLAT]], %[[OFST_2]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, i32
98+ // }
99+ // CHECK: %[[dont_care_5:.*]] = arith.trunci %[[SCF_IF]]#1 : tensor<1024xi64> to tensor<1024xi32>
100+ // CHECK: %[[dont_care_6:.*]] = tt.splat %[[SCF_IF]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
101+ // CHECK: %[[dont_care_7:.*]] = tt.addptr %[[dont_care_6]], %[[dont_care_5]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
102+ // CHECK: %[[dont_care_8:.*]] = tt.load %[[dont_care_7]] : tensor<1024x!tt.ptr<f32>>
103+ // CHECK: %[[dont_care_9:.*]] = arith.trunci %[[SCF_IF]]#3 : tensor<1024xi64> to tensor<1024xi32>
104+ // CHECK: %[[dont_care_10:.*]] = tt.splat %[[SCF_IF]]#2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
105+ // CHECK: %[[dont_care_11:.*]] = tt.addptr %[[dont_care_10]], %[[dont_care_9]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
106+ // CHECK: %[[dont_care_12:.*]] = tt.load %[[dont_care_11]] : tensor<1024x!tt.ptr<f32>>
107+ // CHECK: tt.return %[[dont_care_8]], %[[dont_care_12]], %[[SCF_IF]]#4 : tensor<1024xf32>, tensor<1024xf32>, i32
123108
124109// -----
125110
@@ -137,7 +122,8 @@ module attributes {"ttg.num-warps" = 4 : i32} {
137122 scf.yield %8 , %0 , %8 : tensor <1024 x!tt.ptr <f32 >>, i32 , tensor <1024 x!tt.ptr <f32 >>
138123 } else {
139124 %8 = tt.addptr %5 , %3 : tensor <1024 x!tt.ptr <f32 >>, tensor <1024 xi32 >
140- scf.yield %8 , %0 , %8 : tensor <1024 x!tt.ptr <f32 >>, i32 , tensor <1024 x!tt.ptr <f32 >>
125+ %9 = arith.muli %1 , %1 : i32
126+ scf.yield %8 , %9 , %8 : tensor <1024 x!tt.ptr <f32 >>, i32 , tensor <1024 x!tt.ptr <f32 >>
141127 }
142128 %7 = tt.load %6#0 : tensor <1024 x!tt.ptr <f32 >>
143129 %8 = tt.load %6#2 : tensor <1024 x!tt.ptr <f32 >>
@@ -146,37 +132,30 @@ module attributes {"ttg.num-warps" = 4 : i32} {
146132}
147133
148134// CHECK-LABEL: tt.func @ifOpTwoYieldsAndNonPtrReordered(
149- // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
150- // CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>,
151- // CHECK-SAME: %[[VAL_2:.*]]: i1) -> (tensor<1024xf32>, tensor<1024xf32>, i32) {
152- // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
153- // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1024 : i32
154- // CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
155- // CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32
156- // CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
157- // CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : i32 -> tensor<1024xi32>
158- // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<1024xi32>
159- // CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_3]] : i64 -> tensor<1024xi64>
160- // CHECK: %[[VAL_11:.*]]:5 = scf.if %[[VAL_2]] -> (!tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>) {
161- // CHECK-DAG: %[[VAL_12:.*]] = arith.constant dense<0> : tensor<1024xi32>
162- // CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0 : i32
163- // CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_6]], %[[VAL_13]] : i32
164- // CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_7]] : tensor<1024xi32>
165- // CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_0]], %[[VAL_14]] : !tt.ptr<f32>, i32
166- // CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_15]] : tensor<1024xi32> to tensor<1024xi64>
167- // CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_10]] : tensor<1024xi64>
168- // CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_5]], %[[VAL_16]], %[[VAL_18]] : !tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>
169- // CHECK: } else {
170- // CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<f32>, i32
171- // CHECK: scf.yield %[[VAL_19]], %[[VAL_10]], %[[VAL_5]], %[[VAL_19]], %[[VAL_10]] : !tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>
172- // CHECK: }
173- // CHECK: %[[VAL_20:.*]] = arith.trunci %[[VAL_21:.*]]#1 : tensor<1024xi64> to tensor<1024xi32>
174- // CHECK: %[[VAL_22:.*]] = tt.splat %[[VAL_21]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
175- // CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_22]], %[[VAL_20]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
176- // CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_23]] : tensor<1024x!tt.ptr<f32>>
177- // CHECK: %[[VAL_25:.*]] = arith.trunci %[[VAL_21]]#4 : tensor<1024xi64> to tensor<1024xi32>
178- // CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_21]]#3 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
179- // CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
180- // CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr<f32>>
181- // CHECK: tt.return %[[VAL_24]], %[[VAL_28]], %[[VAL_21]]#2 : tensor<1024xf32>, tensor<1024xf32>, i32
182- // CHECK: }
135+ // CHECK-SAME: %arg0: !tt.ptr<f32>,
136+ // CHECK-SAME: %arg1: tensor<1024xf32>,
137+ // CHECK-SAME: %arg2: i1) -> (tensor<1024xf32>, tensor<1024xf32>, i32) {
138+ // CHECK: %[[C0:.*]] = arith.constant 0 : i64
139+ // CHECK: %[[C1024:.*]] = arith.constant 1024 : i32
140+ // CHECK: %[[PID:.*]] = tt.get_program_id x : i32
141+ // CHECK: %[[PID_TIME_1024:.*]] = arith.muli %[[PID]], %[[C1024]] : i32
142+ // CHECK: %[[MK_RANGE_1024:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
143+ // CHECK: %[[C0_SPLAT:.*]] = tt.splat %[[C0]] : i64 -> tensor<1024xi64>
144+ // CHECK: %[[SCF_IF:.*]]:5 = scf.if %arg2 -> (!tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>) {
145+ // CHECK: %[[PTR_BASE_1:.*]] = tt.addptr %arg0, %[[PID_TIME_1024]] : !tt.ptr<f32>, i32
146+ // CHECK: %[[EXT_MK_RANGE:.*]] = arith.extsi %[[MK_RANGE_1024]] : tensor<1024xi32> to tensor<1024xi64>
147+ // CHECK: scf.yield %[[PTR_BASE_1]], %[[EXT_MK_RANGE]], %[[PID]], %[[PTR_BASE_1]], %[[EXT_MK_RANGE]] : !tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>
148+ // } else {
149+ // CHECK: %[[PTR_BASE_2:.*]] = tt.addptr %arg0, %[[PID_TIME_1024]] : !tt.ptr<f32>, i32
150+ // CHECK: %[[EXT_MK_RANGE:.*]] = arith.muli %[[PID_TIME_1024]], %[[PID_TIME_1024]] : i32
151+ // CHECK: scf.yield %[[PTR_BASE_2]], %[[C0_SPLAT]], %[[EXT_MK_RANGE]], %[[PTR_BASE_2]], %[[C0_SPLAT]] : !tt.ptr<f32>, tensor<1024xi64>, i32, !tt.ptr<f32>, tensor<1024xi64>
152+ // }
153+ // CHECK: %[[dont_care_5:.*]] = arith.trunci %[[SCF_IF]]#1 : tensor<1024xi64> to tensor<1024xi32>
154+ // CHECK: %[[dont_care_6:.*]] = tt.splat %[[SCF_IF]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
155+ // CHECK: %[[dont_care_7:.*]] = tt.addptr %[[dont_care_6]], %[[dont_care_5]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
156+ // CHECK: %[[dont_care_8:.*]] = tt.load %[[dont_care_7]] : tensor<1024x!tt.ptr<f32>>
157+ // CHECK: %[[dont_care_9:.*]] = arith.trunci %[[SCF_IF]]#4 : tensor<1024xi64> to tensor<1024xi32>
158+ // CHECK: %[[dont_care_10:.*]] = tt.splat %[[SCF_IF]]#3 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
159+ // CHECK: %[[dont_care_11:.*]] = tt.addptr %[[dont_care_10]], %[[dont_care_9]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
160+ // CHECK: %[[dont_care_12:.*]] = tt.load %[[dont_care_11]] : tensor<1024x!tt.ptr<f32>>
161+ // CHECK: tt.return %[[dont_care_8]], %[[dont_care_12]], %[[SCF_IF]]#2 : tensor<1024xf32>, tensor<1024xf32>, i32
0 commit comments