Skip to content

Commit 4370bba

Browse files
authored
Creating an affine if with results needs an else (#300)
1 parent d2e4c05 commit 4370bba

File tree

2 files changed

+215
-2
lines changed

2 files changed

+215
-2
lines changed

lib/polygeist/Ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3551,8 +3551,9 @@ struct AffineIfSimplification : public OpRewritePattern<AffineIfOp> {
35513551

35523552
auto newIf =
35533553
rewriter.create<AffineIfOp>(op.getLoc(), op.getResultTypes(), iset,
3554-
op.getOperands(), /*hasElse*/ false);
3554+
op.getOperands(), /*hasElse*/ true);
35553555
rewriter.eraseBlock(newIf.getThenBlock());
3556+
rewriter.eraseBlock(newIf.getElseBlock());
35563557
rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(),
35573558
newIf.getThenRegion().begin());
35583559
rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(),
@@ -3627,8 +3628,9 @@ struct CombineAffineIfs : public OpRewritePattern<AffineIfOp> {
36273628

36283629
AffineIfOp combinedIf = rewriter.create<AffineIfOp>(
36293630
nextIf.getLoc(), mergedTypes, prevIf.getIntegerSet(),
3630-
prevIf.getOperands(), /*hasElse=*/false);
3631+
prevIf.getOperands(), /*hasElse=*/true);
36313632
rewriter.eraseBlock(&combinedIf.getThenRegion().back());
3633+
rewriter.eraseBlock(&combinedIf.getElseRegion().back());
36323634

36333635
rewriter.inlineRegionBefore(prevIf.getThenRegion(),
36343636
combinedIf.getThenRegion(),

test/polygeist-opt/affifcombine.mlir

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// RUN: polygeist-opt --canonicalize %s | FileCheck %s
2+
3+
#set0 = affine_set<(d0, d1) : (d0 + d1 * 512 == 0)>
4+
5+
module {
6+
func.func private @use(index)
7+
func.func @k(%636: index, %603: memref<?xf64>) {
8+
%c512_i32 = arith.constant 512 : i32
9+
affine.parallel (%arg7) = (0) to (symbol(%636)) {
10+
%706 = arith.index_cast %arg7 : index to i32
11+
%707 = arith.muli %706, %c512_i32 : i32
12+
affine.parallel (%arg8) = (0) to (512) {
13+
%708 = arith.index_cast %arg8 : index to i32
14+
%709 = arith.addi %707, %708 : i32
15+
%ifres = affine.if #set0(%arg8, %arg7) -> f64 {
16+
%712 = arith.sitofp %709 : i32 to f64
17+
func.call @use(%arg7) : (index) -> ()
18+
affine.yield %712 : f64
19+
} else {
20+
%712 = arith.sitofp %708 : i32 to f64
21+
func.call @use(%arg8) : (index) -> ()
22+
affine.yield %712 : f64
23+
}
24+
affine.if #set0(%arg8, %arg7) {
25+
func.call @use(%arg7) : (index) -> ()
26+
} else {
27+
func.call @use(%arg8) : (index) -> ()
28+
}
29+
affine.store %ifres, %603[0] : memref<?xf64>
30+
}
31+
}
32+
return
33+
}
34+
// CHECK-LABEL: func.func @k(
35+
// CHECK-SAME: %[[VAL_0:.*]]: index,
36+
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xf64>) {
37+
// CHECK: %[[VAL_2:.*]] = arith.constant 512 : i32
38+
// CHECK: affine.parallel (%[[VAL_3:.*]], %[[VAL_4:.*]]) = (0, 0) to (symbol(%[[VAL_0]]), 512) {
39+
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_3]] : index to i32
40+
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32
41+
// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_4]] : index to i32
42+
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i32
43+
// CHECK: %[[VAL_9:.*]] = affine.if #set(%[[VAL_4]], %[[VAL_3]]) -> f64 {
44+
// CHECK: %[[VAL_10:.*]] = arith.sitofp %[[VAL_8]] : i32 to f64
45+
// CHECK: func.call @use(%[[VAL_3]]) : (index) -> ()
46+
// CHECK: func.call @use(%[[VAL_3]]) : (index) -> ()
47+
// CHECK: affine.yield %[[VAL_10]] : f64
48+
// CHECK: } else {
49+
// CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_7]] : i32 to f64
50+
// CHECK: func.call @use(%[[VAL_4]]) : (index) -> ()
51+
// CHECK: func.call @use(%[[VAL_4]]) : (index) -> ()
52+
// CHECK: affine.yield %[[VAL_11]] : f64
53+
// CHECK: }
54+
// CHECK: affine.store %[[VAL_12:.*]], %[[VAL_1]][0] : memref<?xf64>
55+
// CHECK: }
56+
// CHECK: return
57+
58+
func.func @h(%636: index, %603: memref<?xf64>) {
59+
%c512_i32 = arith.constant 512 : i32
60+
affine.parallel (%arg7) = (0) to (symbol(%636)) {
61+
%706 = arith.index_cast %arg7 : index to i32
62+
%707 = arith.muli %706, %c512_i32 : i32
63+
affine.parallel (%arg8) = (0) to (512) {
64+
%708 = arith.index_cast %arg8 : index to i32
65+
%709 = arith.addi %707, %708 : i32
66+
%ifres = affine.if #set0(%arg8, %arg7) -> f64 {
67+
%712 = arith.sitofp %709 : i32 to f64
68+
func.call @use(%arg7) : (index) -> ()
69+
affine.yield %712 : f64
70+
} else {
71+
%712 = arith.sitofp %708 : i32 to f64
72+
func.call @use(%arg8) : (index) -> ()
73+
affine.yield %712 : f64
74+
}
75+
affine.if #set0(%arg8, %arg7) {
76+
func.call @use(%arg7) : (index) -> ()
77+
}
78+
affine.store %ifres, %603[0] : memref<?xf64>
79+
}
80+
}
81+
return
82+
}
83+
// CHECK-LABEL: func.func @h(
84+
// CHECK-SAME: %[[VAL_0:.*]]: index,
85+
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xf64>) {
86+
// CHECK: %[[VAL_2:.*]] = arith.constant 512 : i32
87+
// CHECK: affine.parallel (%[[VAL_3:.*]], %[[VAL_4:.*]]) = (0, 0) to (symbol(%[[VAL_0]]), 512) {
88+
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_3]] : index to i32
89+
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32
90+
// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_4]] : index to i32
91+
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i32
92+
// CHECK: %[[VAL_9:.*]] = affine.if #set(%[[VAL_4]], %[[VAL_3]]) -> f64 {
93+
// CHECK: %[[VAL_10:.*]] = arith.sitofp %[[VAL_8]] : i32 to f64
94+
// CHECK: func.call @use(%[[VAL_3]]) : (index) -> ()
95+
// CHECK: func.call @use(%[[VAL_3]]) : (index) -> ()
96+
// CHECK: affine.yield %[[VAL_10]] : f64
97+
// CHECK: } else {
98+
// CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_7]] : i32 to f64
99+
// CHECK: func.call @use(%[[VAL_4]]) : (index) -> ()
100+
// CHECK: affine.yield %[[VAL_11]] : f64
101+
// CHECK: }
102+
// CHECK: affine.store %[[VAL_12:.*]], %[[VAL_1]][0] : memref<?xf64>
103+
// CHECK: }
104+
// CHECK: return
105+
106+
func.func @g(%636: index, %603: memref<?xf64>) {
107+
%c512_i32 = arith.constant 512 : i32
108+
109+
affine.parallel (%arg7) = (0) to (symbol(%636)) {
110+
%706 = arith.index_cast %arg7 : index to i32
111+
%707 = arith.muli %706, %c512_i32 : i32
112+
affine.parallel (%arg8) = (0) to (512) {
113+
%708 = arith.index_cast %arg8 : index to i32
114+
%709 = arith.addi %707, %708 : i32
115+
affine.if #set0(%arg8, %arg7) {
116+
func.call @use(%arg7) : (index) -> ()
117+
}
118+
%ifres = affine.if #set0(%arg8, %arg7) -> f64 {
119+
%712 = arith.sitofp %709 : i32 to f64
120+
func.call @use(%arg7) : (index) -> ()
121+
affine.yield %712 : f64
122+
} else {
123+
%712 = arith.sitofp %708 : i32 to f64
124+
func.call @use(%arg8) : (index) -> ()
125+
affine.yield %712 : f64
126+
}
127+
affine.store %ifres, %603[0] : memref<?xf64>
128+
}
129+
}
130+
return
131+
}
132+
// CHECK-LABEL: func.func @g(
133+
// CHECK-SAME: %[[VAL_0:.*]]: index,
134+
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xf64>) {
135+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
136+
// CHECK: %[[VAL_3:.*]] = arith.constant 512 : i32
137+
// CHECK: affine.if #set1(){{\[}}%[[VAL_0]]] {
138+
// CHECK: func.call @use(%[[VAL_2]]) : (index) -> ()
139+
// CHECK: }
140+
// CHECK: affine.parallel (%[[VAL_4:.*]], %[[VAL_5:.*]]) = (0, 0) to (symbol(%[[VAL_0]]), 512) {
141+
// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_4]] : index to i32
142+
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_3]] : i32
143+
// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_5]] : index to i32
144+
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_7]], %[[VAL_8]] : i32
145+
// CHECK: %[[VAL_10:.*]] = affine.if #set(%[[VAL_5]], %[[VAL_4]]) -> f64 {
146+
// CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_9]] : i32 to f64
147+
// CHECK: func.call @use(%[[VAL_4]]) : (index) -> ()
148+
// CHECK: affine.yield %[[VAL_11]] : f64
149+
// CHECK: } else {
150+
// CHECK: %[[VAL_12:.*]] = arith.sitofp %[[VAL_8]] : i32 to f64
151+
// CHECK: func.call @use(%[[VAL_5]]) : (index) -> ()
152+
// CHECK: affine.yield %[[VAL_12]] : f64
153+
// CHECK: }
154+
// CHECK: affine.store %[[VAL_13:.*]], %[[VAL_1]][0] : memref<?xf64>
155+
// CHECK: }
156+
// CHECK: return
157+
158+
func.func @f(%636: index, %603: memref<?xf64>) {
159+
%c512_i32 = arith.constant 512 : i32
160+
161+
affine.parallel (%arg7) = (0) to (symbol(%636)) {
162+
%706 = arith.index_cast %arg7 : index to i32
163+
%707 = arith.muli %706, %c512_i32 : i32
164+
affine.parallel (%arg8) = (0) to (512) {
165+
%708 = arith.index_cast %arg8 : index to i32
166+
%709 = arith.addi %707, %708 : i32
167+
affine.if #set0(%arg8, %arg7) {
168+
func.call @use(%arg7) : (index) -> ()
169+
} else {
170+
func.call @use(%arg8) : (index) -> ()
171+
}
172+
%ifres = affine.if #set0(%arg8, %arg7) -> f64 {
173+
%712 = arith.sitofp %709 : i32 to f64
174+
func.call @use(%arg7) : (index) -> ()
175+
affine.yield %712 : f64
176+
} else {
177+
%712 = arith.sitofp %708 : i32 to f64
178+
func.call @use(%arg8) : (index) -> ()
179+
affine.yield %712 : f64
180+
}
181+
affine.store %ifres, %603[0] : memref<?xf64>
182+
}
183+
}
184+
return
185+
}
186+
// CHECK-LABEL: func.func @f(
187+
// CHECK-SAME: %[[VAL_0:.*]]: index,
188+
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xf64>) {
189+
// CHECK: %[[VAL_2:.*]] = arith.constant 512 : i32
190+
// CHECK: affine.parallel (%[[VAL_3:.*]], %[[VAL_4:.*]]) = (0, 0) to (symbol(%[[VAL_0]]), 512) {
191+
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_3]] : index to i32
192+
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32
193+
// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_4]] : index to i32
194+
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i32
195+
// CHECK: %[[VAL_9:.*]] = affine.if #set(%[[VAL_4]], %[[VAL_3]]) -> f64 {
196+
// CHECK: func.call @use(%[[VAL_3]]) : (index) -> ()
197+
// CHECK: %[[VAL_10:.*]] = arith.sitofp %[[VAL_8]] : i32 to f64
198+
// CHECK: func.call @use(%[[VAL_3]]) : (index) -> ()
199+
// CHECK: affine.yield %[[VAL_10]] : f64
200+
// CHECK: } else {
201+
// CHECK: func.call @use(%[[VAL_4]]) : (index) -> ()
202+
// CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_7]] : i32 to f64
203+
// CHECK: func.call @use(%[[VAL_4]]) : (index) -> ()
204+
// CHECK: affine.yield %[[VAL_11]] : f64
205+
// CHECK: }
206+
// CHECK: affine.store %[[VAL_12:.*]], %[[VAL_1]][0] : memref<?xf64>
207+
// CHECK: }
208+
// CHECK: return
209+
210+
211+
}

0 commit comments

Comments
 (0)