@@ -15,7 +15,9 @@ def constructAndPrintInModule(f):
1515 module = Module .create ()
1616 with InsertionPoint (module .body ):
1717 f ()
18+
1819 print (module )
20+ module .operation .verify ()
1921 return f
2022
2123
@@ -89,3 +91,132 @@ def my_inline_ptx(a, b, c, d):
8991 arith .addf (a , b )
9092 arith .addi (c , d )
9193 arith .addf (wo0 , wo1 )
94+
95+ @constructAndPrintInModule
96+ def test_barriers ():
97+ i32 = T .i32 ()
98+ f32 = T .f32 ()
99+
100+ @func .FuncOp .from_py_func (i32 , i32 , f32 )
101+ def barriers (mask , vi32 , vf32 ):
102+ c0 = arith .constant (T .i32 (), 0 )
103+ cffff = arith .constant (T .i32 (), 0xFFFF )
104+ res = nvvm .barrier (
105+ res = i32 ,
106+ barrier_id = c0 ,
107+ number_of_threads = cffff ,
108+ )
109+
110+ for reduction in (
111+ nvvm .BarrierReduction .AND ,
112+ nvvm .BarrierReduction .OR ,
113+ nvvm .BarrierReduction .POPC ,
114+ ):
115+ res = nvvm .barrier (
116+ res = i32 ,
117+ reduction_op = reduction ,
118+ reduction_predicate = res ,
119+ )
120+
121+ nvvm .barrier0 ()
122+ nvvm .bar_warp_sync (mask )
123+ nvvm .cluster_arrive ()
124+ nvvm .cluster_arrive (aligned = True )
125+ nvvm .cluster_arrive_relaxed ()
126+ nvvm .cluster_arrive_relaxed (aligned = True )
127+ nvvm .cluster_wait ()
128+ nvvm .cluster_wait (aligned = True )
129+ nvvm .fence_mbarrier_init ()
130+ nvvm .bar_warp_sync (mask )
131+ return res
132+
133+
134+ # CHECK-LABEL: func.func @barriers(
135+ # CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 {
136+ # CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
137+ # CHECK: %[[CONSTANT_1:.*]] = arith.constant 65535 : i32
138+ # CHECK: %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32
139+ # CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32
140+ # CHECK: %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32
141+ # CHECK: %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32
142+ # CHECK: nvvm.barrier0
143+ # CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32
144+ # CHECK: nvvm.cluster.arrive
145+ # CHECK: nvvm.cluster.arrive {aligned}
146+ # CHECK: nvvm.cluster.arrive.relaxed
147+ # CHECK: nvvm.cluster.arrive.relaxed {aligned}
148+ # CHECK: nvvm.cluster.wait
149+ # CHECK: nvvm.cluster.wait {aligned}
150+ # CHECK: nvvm.fence.mbarrier.init
151+ # CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32
152+ # CHECK: return %[[BARRIER_3]] : i32
153+ # CHECK: }
154+
155+
156+ @constructAndPrintInModule
157+ def test_reductions ():
158+ i32 = T .i32 ()
159+ f32 = T .f32 ()
160+
161+ @func .FuncOp .from_py_func (i32 , i32 , f32 )
162+ def reductions (mask , vi32 , vf32 ):
163+ for abs in (True , False ):
164+ for nan in (True , False ):
165+ for kind in (
166+ nvvm .ReduxKind .AND ,
167+ nvvm .ReduxKind .MAX ,
168+ nvvm .ReduxKind .MIN ,
169+ nvvm .ReduxKind .OR ,
170+ nvvm .ReduxKind .UMAX ,
171+ nvvm .ReduxKind .UMIN ,
172+ nvvm .ReduxKind .XOR ,
173+ ):
174+ nvvm .redux_sync (i32 , vi32 , kind , vi32 )
175+
176+ for kind in (
177+ nvvm .ReduxKind .FMIN ,
178+ nvvm .ReduxKind .FMAX ,
179+ ):
180+ nvvm .redux_sync (f32 , vf32 , kind , vi32 , abs = abs , nan = nan )
181+
182+
183+ # CHECK-LABEL: func.func @reductions(
184+ # CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) {
185+ # CHECK: %[[REDUX_0:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
186+ # CHECK: %[[REDUX_1:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
187+ # CHECK: %[[REDUX_2:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
188+ # CHECK: %[[REDUX_3:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
189+ # CHECK: %[[REDUX_4:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
190+ # CHECK: %[[REDUX_5:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
191+ # CHECK: %[[REDUX_6:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
192+ # CHECK: %[[REDUX_7:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32
193+ # CHECK: %[[REDUX_8:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32
194+ # CHECK: %[[REDUX_9:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
195+ # CHECK: %[[REDUX_10:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
196+ # CHECK: %[[REDUX_11:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
197+ # CHECK: %[[REDUX_12:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
198+ # CHECK: %[[REDUX_13:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
199+ # CHECK: %[[REDUX_14:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
200+ # CHECK: %[[REDUX_15:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
201+ # CHECK: %[[REDUX_16:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32
202+ # CHECK: %[[REDUX_17:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32
203+ # CHECK: %[[REDUX_18:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
204+ # CHECK: %[[REDUX_19:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
205+ # CHECK: %[[REDUX_20:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
206+ # CHECK: %[[REDUX_21:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
207+ # CHECK: %[[REDUX_22:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
208+ # CHECK: %[[REDUX_23:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
209+ # CHECK: %[[REDUX_24:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
210+ # CHECK: %[[REDUX_25:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32
211+ # CHECK: %[[REDUX_26:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32
212+ # CHECK: %[[REDUX_27:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32
213+ # CHECK: %[[REDUX_28:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32
214+ # CHECK: %[[REDUX_29:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32
215+ # CHECK: %[[REDUX_30:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32
216+ # CHECK: %[[REDUX_31:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32
217+ # CHECK: %[[REDUX_32:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32
218+ # CHECK: %[[REDUX_33:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32
219+ # CHECK: %[[REDUX_34:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] : f32 -> f32
220+ # CHECK: %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32
221+ # CHECK: return
222+ # CHECK: }
0 commit comments