@@ -15,7 +15,9 @@ def constructAndPrintInModule(f):
1515 module = Module .create ()
1616 with InsertionPoint (module .body ):
1717 f ()
18+
1819 print (module )
20+ module .operation .verify ()
1921 return f
2022
2123
@@ -89,3 +91,133 @@ def my_inline_ptx(a, b, c, d):
8991 arith .addf (a , b )
9092 arith .addi (c , d )
9193 arith .addf (wo0 , wo1 )
94+
95+
96+ @constructAndPrintInModule
97+ def test_barriers ():
98+ i32 = T .i32 ()
99+ f32 = T .f32 ()
100+
101+ @func .FuncOp .from_py_func (i32 , i32 , f32 )
102+ def barriers (mask , vi32 , vf32 ):
103+ c0 = arith .constant (T .i32 (), 0 )
104+ cffff = arith .constant (T .i32 (), 0xFFFF )
105+ res = nvvm .barrier (
106+ res = i32 ,
107+ barrier_id = c0 ,
108+ number_of_threads = cffff ,
109+ )
110+
111+ for reduction in (
112+ nvvm .BarrierReduction .AND ,
113+ nvvm .BarrierReduction .OR ,
114+ nvvm .BarrierReduction .POPC ,
115+ ):
116+ res = nvvm .barrier (
117+ res = i32 ,
118+ reduction_op = reduction ,
119+ reduction_predicate = res ,
120+ )
121+
122+ nvvm .barrier0 ()
123+ nvvm .bar_warp_sync (mask )
124+ nvvm .cluster_arrive ()
125+ nvvm .cluster_arrive (aligned = True )
126+ nvvm .cluster_arrive_relaxed ()
127+ nvvm .cluster_arrive_relaxed (aligned = True )
128+ nvvm .cluster_wait ()
129+ nvvm .cluster_wait (aligned = True )
130+ nvvm .fence_mbarrier_init ()
131+ nvvm .bar_warp_sync (mask )
132+ return res
133+
134+
135+ # CHECK-LABEL: func.func @barriers(
136+ # CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 {
137+ # CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
138+ # CHECK: %[[CONSTANT_1:.*]] = arith.constant 65535 : i32
139+ # CHECK: %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32
140+ # CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32
141+ # CHECK: %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32
142+ # CHECK: %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32
143+ # CHECK: nvvm.barrier0
144+ # CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32
145+ # CHECK: nvvm.cluster.arrive
146+ # CHECK: nvvm.cluster.arrive {aligned}
147+ # CHECK: nvvm.cluster.arrive.relaxed
148+ # CHECK: nvvm.cluster.arrive.relaxed {aligned}
149+ # CHECK: nvvm.cluster.wait
150+ # CHECK: nvvm.cluster.wait {aligned}
151+ # CHECK: nvvm.fence.mbarrier.init
152+ # CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32
153+ # CHECK: return %[[BARRIER_3]] : i32
154+ # CHECK: }
155+
156+
157+ @constructAndPrintInModule
158+ def test_reductions ():
159+ i32 = T .i32 ()
160+ f32 = T .f32 ()
161+
162+ @func .FuncOp .from_py_func (i32 , i32 , f32 )
163+ def reductions (mask , vi32 , vf32 ):
164+ for abs in (True , False ):
165+ for nan in (True , False ):
166+ for kind in (
167+ nvvm .ReduxKind .AND ,
168+ nvvm .ReduxKind .MAX ,
169+ nvvm .ReduxKind .MIN ,
170+ nvvm .ReduxKind .OR ,
171+ nvvm .ReduxKind .UMAX ,
172+ nvvm .ReduxKind .UMIN ,
173+ nvvm .ReduxKind .XOR ,
174+ ):
175+ nvvm .redux_sync (i32 , vi32 , kind , vi32 )
176+
177+ for kind in (
178+ nvvm .ReduxKind .FMIN ,
179+ nvvm .ReduxKind .FMAX ,
180+ ):
181+ nvvm .redux_sync (f32 , vf32 , kind , vi32 , abs = abs , nan = nan )
182+
183+
184+ # CHECK-LABEL: func.func @reductions(
185+ # CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) {
186+ # CHECK: %[[REDUX_0:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
187+ # CHECK: %[[REDUX_1:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
188+ # CHECK: %[[REDUX_2:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
189+ # CHECK: %[[REDUX_3:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
190+ # CHECK: %[[REDUX_4:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
191+ # CHECK: %[[REDUX_5:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
192+ # CHECK: %[[REDUX_6:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
193+ # CHECK: %[[REDUX_7:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32
194+ # CHECK: %[[REDUX_8:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32
195+ # CHECK: %[[REDUX_9:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
196+ # CHECK: %[[REDUX_10:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
197+ # CHECK: %[[REDUX_11:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
198+ # CHECK: %[[REDUX_12:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
199+ # CHECK: %[[REDUX_13:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
200+ # CHECK: %[[REDUX_14:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
201+ # CHECK: %[[REDUX_15:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
202+ # CHECK: %[[REDUX_16:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32
203+ # CHECK: %[[REDUX_17:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32
204+ # CHECK: %[[REDUX_18:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
205+ # CHECK: %[[REDUX_19:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
206+ # CHECK: %[[REDUX_20:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
207+ # CHECK: %[[REDUX_21:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
208+ # CHECK: %[[REDUX_22:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
209+ # CHECK: %[[REDUX_23:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
210+ # CHECK: %[[REDUX_24:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
211+ # CHECK: %[[REDUX_25:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32
212+ # CHECK: %[[REDUX_26:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32
213+ # CHECK: %[[REDUX_27:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
214+ # CHECK: %[[REDUX_28:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
215+ # CHECK: %[[REDUX_29:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
216+ # CHECK: %[[REDUX_30:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
217+ # CHECK: %[[REDUX_31:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
218+ # CHECK: %[[REDUX_32:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
219+ # CHECK: %[[REDUX_33:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
220+ # CHECK: %[[REDUX_34:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] : f32 -> f32
221+ # CHECK: %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32
222+ # CHECK: return
223+ # CHECK: }
0 commit comments