@@ -42,10 +42,10 @@ def build_kernel(
     tile_m: int = 128,
     tile_n: int = 128,
     max_concurrent_steps: int = 2,
+    collective: bool = False,
 ):
   i1 = ir.IntegerType.get_signless(1)
   i32 = ir.IntegerType.get_signless(32)
-  f32 = ir.F32Type.get()
   index = ir.IndexType.get()
 
   swizzle = 128
@@ -64,32 +64,46 @@ def build_kernel(
   tma_tile_m = 128
   tma_tile_kn = 64
 
+  block_tile_m = tile_m
+  block_tile_n = tile_n
+  if collective:
+    tile_m *= 2
+    tile_n *= 2
+
   def kernel(ctx, a, b, d, smem):
     a_smem, b_smem, d_smem, barriers, mma_done_barrier, acc = smem
     (ab_full_barriers, ab_empty_barriers) = barriers
 
     warp_idx = mgpu.warp_idx(sync=True)
-    warp_leader = nvvm.elect_sync(i1)
+    is_warp_leader = nvvm.elect_sync(i1)
 
-    is_warp = lambda i: arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32))
+    is_leader_of = lambda i: arith.andi(arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader)
 
-    m_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_m, index))
-    n_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(tile_n, index))
+    m_start = arith.muli(gpu.cluster_id(gpu.Dimension.x), c(tile_m, index))
+    block_m_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(block_tile_m, index))
+    n_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_n, index))
+    is_leader_block = arith.cmpi(arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index))
 
-    with mgpu.when(arith.andi(is_warp(TMA_WARP), warp_leader)):
+    with mgpu.when(is_leader_of(TMA_WARP)):
       @mgpu.fori(c(k_loop_iter, index), None)
       def _tma_body(ki, _):
         slot = arith.remui(ki, c(max_concurrent_steps, index))
         # TODO(apaszke): Use a predicate instead of a conditional.
         with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))):
           ab_empty_barriers[slot].wait()
         full_barrier = ab_full_barriers[slot]
-        full_barrier.arrive_expect_tx(
-            bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype)
-        )
+        with mgpu.when(is_leader_block):
+          full_barrier.arrive_expect_tx(
+              bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype)
+          )
         k_start = arith.muli(ki, c(tile_k, index))
         common_args = dict(
-            swizzle=swizzle, barrier=full_barrier, arrive=False, uniform=False,
+            swizzle=swizzle,
+            barrier=full_barrier,
+            arrive=False,
+            uniform=False,
+            collective=gpu.Dimension.x,
+            partitioned=0,  # Non-contracting dim is always 0.
         )
         ctx.async_copy(
             src_ref=a,
@@ -109,66 +123,67 @@ def _tma_body(ki, _):
             **common_args,
         )
 
-    with mgpu.when(arith.andi(is_warp(MMA_WARP), warp_leader)):
-      with mgpu.when(warp_leader):
-        @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
-        def _mma_body(ki, accumulate):
-          slot = arith.remui(ki, c(max_concurrent_steps, index))
-          ab_full_barriers[slot].wait()
-          tcgen05.mma(
-              acc,
-              mgpu.memref_slice(a_smem, slot),
-              mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
-              a_swizzle=swizzle,
-              b_swizzle=swizzle,
-              accumulate=accumulate,
-          )
-          accumulate = arith.constant(i1, 1)
-          is_last_iter = arith.cmpi(
-              arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
-          )
-          barrier_ptr = arith.select(
-              is_last_iter,
-              mma_done_barrier.get_ptr(),
-              ab_empty_barriers[slot].get_ptr(),
-          )
-          tcgen05.commit_arrive(barrier_ptr)
-          return accumulate
+    with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)):
+      @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
+      def _mma_body(ki, accumulate):
+        slot = arith.remui(ki, c(max_concurrent_steps, index))
+        ab_full_barriers[slot].wait()
+        tcgen05.mma(
+            acc,
+            mgpu.memref_slice(a_smem, slot),
+            mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
+            a_swizzle=swizzle,
+            b_swizzle=swizzle,
+            accumulate=accumulate,
+            collective=collective,
+        )
+        accumulate = arith.constant(i1, 1)
+        is_last_iter = arith.cmpi(
+            arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
+        )
+        barrier_ptr = arith.select(
+            is_last_iter,
+            mma_done_barrier.get_ptr(),
+            ab_empty_barriers[slot].get_ptr(),
+        )
+        tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx)
+        return accumulate
 
     gpu.barrier()
     mma_done_barrier.wait(for_tensor_core=True)
 
     acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
     mgpu.commit_shared()
-    # TODO(apaszke): Free up TMEM?
     ctx.async_copy(
         src_ref=d_smem,
         dst_ref=d,
-        gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)),
+        gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)),
         gmem_transform=mgpu.TileTransform((128, 64)),
         swizzle=swizzle,
     )
+    # TODO(apaszke): Free up TMEM?
     ctx.await_async_copy(0)
 
   # TODO(apaszke): Use a union for output SMEM.
   smem = (
-      jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((tile_m, tile_k), (tma_tile_m, tma_tile_kn))), jnp.float16),
-      jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((tile_k, tile_n), (tma_tile_kn, tma_tile_kn))), jnp.float16),
-      jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
+      jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((block_tile_m, tile_k), (tma_tile_m, tma_tile_kn))), jnp.float16),
+      jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((tile_k, block_tile_n), (tma_tile_kn, tma_tile_kn))), jnp.float16),
+      jax.ShapeDtypeStruct(mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
       [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
       mgpu.Barrier(arrival_count=1),
-      mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D),
+      mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D, collective=collective),
   )
   return mgpu.as_gpu_kernel(
       kernel,
-      (n // tile_n, m // tile_m, 1),
+      (m // block_tile_m, n // tile_n, 1),
       (128, 1, 1),
       (
           jax.ShapeDtypeStruct((m, k), jnp.float16),
           jax.ShapeDtypeStruct((n, k), jnp.float16),
       ),
       jax.ShapeDtypeStruct((m, n), jnp.float16),
       smem,
+      cluster=(2 if collective else 1, 1, 1),
   )
 
 
@@ -188,8 +203,8 @@ def main(unused_argv):
     f = build_kernel(m, n, k, tile_m=m_tile, tile_n=n_tile)
     y = f(a, b).block_until_ready()
 
-    ref = np.asarray(a) @ np.asarray(b).T
-    np.testing.assert_allclose(y, ref, atol=1e-3, rtol=1e-3)
+    y_ref = jax.jit(lambda a, b: a @ b.T)(a, b)
+    np.testing.assert_allclose(y, y_ref, atol=1e-3, rtol=1e-3)
     print("OK!")
 
 
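For reference, a minimal sketch of how the new collective path could be exercised end to end. This is not part of the commit: it assumes the example module's own build_kernel plus standard jax/numpy imports, and problem sizes that divide the doubled 256x256 cluster tile (with tile_m = tile_n = 128 and collective=True, the grid becomes (m // 128, n // 256, 1) with a (2, 1, 1) cluster).

    import numpy as np
    import jax
    import jax.numpy as jnp

    m = n = k = 4096
    ka, kb = jax.random.split(jax.random.key(0), 2)
    a = jax.random.normal(key=ka, shape=(m, k), dtype=jnp.float16)
    b = jax.random.normal(key=kb, shape=(n, k), dtype=jnp.float16)

    # collective=True doubles tile_m/tile_n, so each 2-CTA cluster computes a
    # 256x256 output tile while each block still writes its own 128x256 slice.
    f = build_kernel(m, n, k, tile_m=128, tile_n=128, collective=True)
    y = f(a, b).block_until_ready()
    y_ref = jax.jit(lambda a, b: a @ b.T)(a, b)
    np.testing.assert_allclose(y, y_ref, atol=1e-3, rtol=1e-3)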