@@ -489,19 +489,12 @@ def get_packed_shape(strides, shape):
489489
490490class WGMMALayoutTest (TestCase ):
491491
492- @parameterized .product (dtype = [jnp .float16 , jnp .float32 ],
493- transposed_smem = [False , True ])
494- def test_store_untiled (self , dtype , transposed_smem ):
492+ @parameterized .product (dtype = [jnp .float16 , jnp .float32 ])
493+ def test_store_untiled (self , dtype ):
495494 def kernel (ctx , out , _ ):
496495 del ctx
497- if transposed_smem :
498- out = memref_transpose (out , (1 , 0 ))
499- iota_tensor (64 , 64 , dtype ).store_untiled (
500- out , vector_store = not transposed_smem
501- )
496+ iota_tensor (64 , 64 , dtype ).store_untiled (out , optimized = False )
502497 expected = np .arange (64 * 64 , dtype = dtype ).reshape (64 , 64 )
503- if transposed_smem :
504- expected = expected .T
505498 iota = mgpu .as_gpu_kernel (
506499 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), expected , ()
507500 )()
@@ -749,7 +742,7 @@ def kernel(ctx, lhs, rhs, out, scratch):
749742 acc = mgpu .wgmma (init_acc , lhs_smem , rhs_smem , swizzle = swizzle )
750743 nvvm .wgmma_commit_group_sync_aligned ()
751744 nvvm .wgmma_wait_group_sync_aligned (0 )
752- acc .value .store_untiled (out )
745+ acc .value .store_untiled (out , optimized = False )
753746
754747 def quantize (x ):
755748 # Quantize the input to avoid rounding when feeding the WGMMA
@@ -821,7 +814,7 @@ def kernel(ctx, rhs, out, rhs_smem):
821814 acc = mgpu .wgmma (init_acc , lhs_regs , rhs_smem , swizzle = swizzle )
822815 nvvm .wgmma_commit_group_sync_aligned ()
823816 nvvm .wgmma_wait_group_sync_aligned (0 )
824- acc .value .store_untiled (out )
817+ acc .value .store_untiled (out , optimized = False )
825818
826819 y_shape = (n , k ) if rhs_transpose else (k , n )
827820 y = self .prng .uniform (- 1 , 1 , y_shape ).astype (dtype )
@@ -881,7 +874,7 @@ def kernel(ctx, rhs, out, smem):
881874 acc = mgpu .wgmma (init_acc , lhs_regs , rhs_smem , swizzle = swizzle )
882875 nvvm .wgmma_commit_group_sync_aligned ()
883876 nvvm .wgmma_wait_group_sync_aligned (0 )
884- acc .value .store_untiled (out )
877+ acc .value .store_untiled (out , optimized = False )
885878
886879 jax_dtype = jnp .float16
887880 y_shape = (n , k ) if rhs_transpose else (k , n )
@@ -1042,7 +1035,7 @@ def kernel(ctx, lhs, rhs, out, scratch):
10421035 )
10431036 tcgen05 .commit_arrive (barriers [2 ])
10441037 barriers [2 ].wait (for_tensor_core = True )
1045- acc [:].store_untiled (out )
1038+ acc [:].store_untiled (out , optimized = False )
10461039
10471040 x_shape = (k , m ) if lhs_transpose else (m , k )
10481041 x = self .prng .uniform (- 1 , 1 , x_shape ).astype (in_jax_dtype )
@@ -1145,7 +1138,7 @@ def kernel(ctx, lhs, rhs, out, scratch):
11451138 tcgen05 .commit_arrive (barriers [2 ], collective = True , ctx = ctx )
11461139 barriers [2 ].wait (for_tensor_core = True )
11471140 m_slice = ds (arith .muli (block_id , c (m_block_tile , index )), m_block_tile )
1148- acc [:].store_untiled (memref_slice (out , m_slice ))
1141+ acc [:].store_untiled (memref_slice (out , m_slice ), optimized = False )
11491142
11501143 in_finfo = jnp .finfo (in_jax_dtype )
11511144 exponent_bits , mantissa_bits = in_finfo .nexp , in_finfo .nmant
@@ -1198,7 +1191,7 @@ def kernel(ctx, dst, scratch):
11981191 final_arr = arr + mgpu .FragmentedArray .load_strided (
11991192 tmp , is_signed = False
12001193 )
1201- final_arr .store_untiled (memref_slice (dst , 0 ))
1194+ final_arr .store_untiled (memref_slice (dst , 0 ), optimized = False )
12021195 scf .yield_ ([])
12031196 with ir .InsertionPoint (scf .IfOp (is_second_wg ).then_block ):
12041197 barriers [0 ].wait ()
@@ -1209,7 +1202,7 @@ def kernel(ctx, dst, scratch):
12091202 barriers [2 ].wait () # Synchronize this warpgroup before we overwrite tmp.
12101203 arr .store_untiled (tmp )
12111204 barriers [1 ].arrive () # Signal that tmp is ready.
1212- final_arr .store_untiled (memref_slice (dst , 1 ))
1205+ final_arr .store_untiled (memref_slice (dst , 1 ), optimized = False )
12131206 scf .yield_ ([])
12141207 out_shape = jax .ShapeDtypeStruct ((2 , 128 ), jnp .int32 )
12151208 y = mgpu .as_gpu_kernel (
@@ -1670,7 +1663,7 @@ def kernel(ctx, dst, _):
16701663 mlir_dtype = utils .dtype_to_ir_type (dtype )
16711664 iota = iota_tensor (m , n , dtype )
16721665 rhs = iota if scalar_rhs is None else c (scalar_rhs , mlir_dtype )
1673- op (iota , rhs ).store_untiled (dst )
1666+ op (iota , rhs ).store_untiled (dst , optimized = False )
16741667 out_shape = jax .ShapeDtypeStruct ((m , n ), dtype )
16751668 result = mgpu .as_gpu_kernel (
16761669 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), out_shape , ()
@@ -1716,7 +1709,7 @@ def test_division(self, op, dtype, m=64, n=32):
17161709
17171710 def kernel (ctx , dst , _ ):
17181711 iota = iota_tensor (m , n , dtype )
1719- op (dtype (4.2 ).item () * iota , iota + 1 ).store_untiled (dst )
1712+ op (dtype (4.2 ).item () * iota , iota + 1 ).store_untiled (dst , optimized = False )
17201713
17211714 out_shape = jax .ShapeDtypeStruct ((m , n ), dtype )
17221715 result = mgpu .as_gpu_kernel (
@@ -1746,14 +1739,14 @@ def kernel(ctx, dst, _):
17461739 rhs = 0 if rhs_is_literal else iota + 1
17471740 res = op (iota , rhs )
17481741 assert not res .is_signed
1749- res .astype (i8 , is_signed = False ).store_untiled (dst )
1742+ res .astype (i8 , is_signed = False ).store_untiled (dst , optimized = False )
17501743
17511744 out_shape = jax .ShapeDtypeStruct ((m , n ), jnp .int8 )
17521745 result = mgpu .as_gpu_kernel (
17531746 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), out_shape , ()
17541747 )()
17551748 iota = np .arange (m * n , dtype = dtype ).reshape (m , n )
1756- rhs = rhs = 0 if rhs_is_literal else iota + 1
1749+ rhs = 0 if rhs_is_literal else iota + 1
17571750 np .testing .assert_array_equal (result , op (iota , rhs ).astype (jnp .int8 ))
17581751
17591752 def test_foreach_wgmma_row_array (self ):
@@ -1784,22 +1777,25 @@ def _(v, idx):
17841777 def test_foreach (self ):
17851778 dtype = jnp .int32
17861779 swizzle = 128
1787- tile = 64 , swizzle // jnp .dtype (dtype ).itemsize
1780+ tiling = ( 8 , swizzle // jnp .dtype (dtype ).itemsize )
17881781 shape = 128 , 192
1789- tiled_shape = mgpu .tile_shape (shape , tile )
17901782 mlir_dtype = utils .dtype_to_ir_type (dtype )
17911783 cst = 9999
17921784 def causal (val , idx ):
17931785 row , col = idx
17941786 mask = arith .cmpi (arith .CmpIPredicate .uge , row , col )
17951787 return arith .select (mask , val , c (cst , mlir_dtype ))
17961788
1797- tiling = mgpu .TileTransform (tile )
17981789 def kernel (ctx , dst , smem ):
17991790 x = iota_tensor (shape [0 ], shape [1 ], dtype )
1800- x .foreach (causal , create_array = True , is_signed = False ).store_untiled (smem )
1791+ x .foreach (causal , create_array = True , is_signed = False ).store_tiled (smem , swizzle = 128 )
18011792 mgpu .commit_shared ()
1802- ctx .async_copy (src_ref = smem , dst_ref = dst )
1793+ ctx .async_copy (
1794+ src_ref = smem ,
1795+ dst_ref = dst ,
1796+ gmem_transform = mgpu .TileTransform (tiling ),
1797+ swizzle = 128 ,
1798+ )
18031799 ctx .await_async_copy (0 )
18041800
18051801 iota = np .arange (np .prod (shape ), dtype = dtype ).reshape (* shape )
@@ -1809,7 +1805,7 @@ def kernel(ctx, dst, smem):
18091805 (128 , 1 , 1 ),
18101806 (),
18111807 jax .ShapeDtypeStruct (shape = shape , dtype = dtype ),
1812- jax .ShapeDtypeStruct (shape = shape , dtype = dtype ),
1808+ jax .ShapeDtypeStruct (shape = mgpu . tile_shape ( shape , tiling ) , dtype = dtype ),
18131809 )()
18141810 expected = jnp .tril (iota ) + jnp .triu (jnp .ones (shape ), k = 1 ) * cst
18151811 np .testing .assert_array_equal (result , expected )
@@ -1821,7 +1817,7 @@ def kernel(ctx, dst, smem):
18211817 def test_bitwise (self , op , dtype , m = 64 , n = 8 ):
18221818 def kernel (ctx , dst , _ ):
18231819 iota = iota_tensor (m , n , dtype )
1824- op (iota , iota + 1 ).store_untiled (dst )
1820+ op (iota , iota + 1 ).store_untiled (dst , optimized = False )
18251821
18261822 out_shape = jax .ShapeDtypeStruct ((m , n ), dtype )
18271823 result = mgpu .as_gpu_kernel (
@@ -1845,7 +1841,7 @@ def test_unary(self, ops, dtype, m=64, n=32):
18451841
18461842 def kernel (ctx , dst , _ ):
18471843 iota = iota_tensor (m , n , dtype )
1848- op (iota ).store_untiled (dst )
1844+ op (iota ).store_untiled (dst , optimized = False )
18491845
18501846 out_shape = jax .ShapeDtypeStruct ((m , n ), dtype )
18511847 result = mgpu .as_gpu_kernel (
@@ -1858,7 +1854,7 @@ def test_select(self, m=64, n=32):
18581854
18591855 def kernel (ctx , dst , _ ):
18601856 iota = iota_tensor (m , n , jnp .int32 )
1861- (iota < 16 ).select (iota * 2 , iota * 3 ).store_untiled (dst )
1857+ (iota < 16 ).select (iota * 2 , iota * 3 ).store_untiled (dst , optimized = False )
18621858
18631859 out_shape = jax .ShapeDtypeStruct ((m , n ), jnp .int32 )
18641860 result = mgpu .as_gpu_kernel (
@@ -1881,7 +1877,7 @@ def test_math(self, ops, approx, m=64, n=32):
18811877 op , np_op = ops
18821878 def kernel (ctx , dst , _ ):
18831879 iota = iota_tensor (m , n , jnp .float32 )
1884- op (iota ).store_untiled (dst )
1880+ op (iota ).store_untiled (dst , optimized = False )
18851881 out_shape = jax .ShapeDtypeStruct ((m , n ), jnp .float32 )
18861882 result = mgpu .as_gpu_kernel (
18871883 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), out_shape , ()
@@ -1902,7 +1898,7 @@ def kernel(ctx, src, dst, scratch):
19021898 src , is_signed = utils .is_signed (dtype )
19031899 )
19041900 acc = src .reduce_sum (scratch ).broadcast ((m ,))
1905- acc .store_untiled (dst )
1901+ acc .store_untiled (dst , optimized = False )
19061902
19071903 in_shape = jax .ShapeDtypeStruct ((m , n ), dtype )
19081904 out_shape = jax .ShapeDtypeStruct ((m ,), dtype )
@@ -1930,7 +1926,7 @@ def kernel(ctx, dst, _):
19301926 is_signed = utils .is_signed (dtype ),
19311927 )
19321928 acc = src .reduce_sum ().broadcast ((m ,))
1933- acc .store_untiled (dst )
1929+ acc .store_untiled (dst , optimized = False )
19341930
19351931 kernel_fn = mgpu .as_gpu_kernel (
19361932 kernel ,
@@ -1950,7 +1946,7 @@ def kernel(ctx, dst, _):
19501946 def test_reduce (self , op , m = 64 , n = 32 ):
19511947 def kernel (ctx , dst , _ ):
19521948 iota = iota_tensor (m , n , jnp .float32 )
1953- iota .reduce (op , axis = 1 ).broadcast_minor (n ).store_untiled (dst )
1949+ iota .reduce (op , axis = 1 ).broadcast_minor (n ).store_untiled (dst , optimized = False )
19541950 out_shape = jax .ShapeDtypeStruct ((m , n ), jnp .float32 )
19551951 result = mgpu .as_gpu_kernel (
19561952 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), out_shape , ()
@@ -1971,7 +1967,7 @@ def kernel(ctx, dst, _):
19711967 cte = c (1 , iota .mlir_dtype )
19721968 cte_arr = mgpu .FragmentedArray .splat (cte , ())
19731969 cte_arr = cte_arr .reshape ((1 , 1 )).broadcast ((m , n ))
1974- (iota + cte_arr ).store_untiled (dst )
1970+ (iota + cte_arr ).store_untiled (dst , optimized = False )
19751971 out_shape = jax .ShapeDtypeStruct ((m , n ), jnp .float32 )
19761972 result = mgpu .as_gpu_kernel (
19771973 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), out_shape , ()
@@ -1986,7 +1982,7 @@ def kernel(ctx, dst, _):
19861982 t = mgpu .FragmentedArray .splat (
19871983 v , (128 ,), mgpu .WGMMA_ROW_LAYOUT
19881984 )
1989- t .broadcast_minor (32 ).store_untiled (dst )
1985+ t .broadcast_minor (32 ).store_untiled (dst , optimized = False )
19901986 out_shape = jax .ShapeDtypeStruct ((128 , 32 ), jnp .float32 )
19911987 result = mgpu .as_gpu_kernel (
19921988 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), out_shape , ()
@@ -2005,7 +2001,7 @@ def kernel(ctx, src, dst, _):
20052001 assert isinstance (pi_arr_sq .layout , mgpu .WGStridedFragLayout )
20062002 pi_arr_cube = pi_splat .broadcast (pi_arr .shape ) * pi_arr_sq
20072003 assert isinstance (pi_arr_cube .layout , mgpu .WGStridedFragLayout )
2008- (pi_arr == pi_arr ).select (pi_splat , pi_arr_cube ).store_untiled (dst )
2004+ (pi_arr == pi_arr ).select (pi_splat , pi_arr_cube ).store_untiled (dst , optimized = False )
20092005
20102006 out_shape = jax .ShapeDtypeStruct ((128 , 32 ), jnp .float32 )
20112007 inp = jnp .ones_like (out_shape ) * 3.14
@@ -2077,7 +2073,7 @@ def kernel(ctx, gmem_input, gmem_output, _):
20772073 t = mgpu .FragmentedArray .load_untiled (
20782074 gmem_input , layout = mgpu .WGMMA_COL_LAYOUT , optimized = False
20792075 )
2080- t .broadcast_major (m ).store_untiled (gmem_output )
2076+ t .broadcast_major (m ).store_untiled (gmem_output , optimized = False )
20812077
20822078 inp = self .prng .uniform (- 1 , 1 , (n ,)).astype (jnp .float16 )
20832079 out_shape = jax .ShapeDtypeStruct ((m , n ), jnp .float16 )
@@ -2114,7 +2110,7 @@ def kernel(ctx, inp, out, smem):
21142110 del ctx , smem
21152111 arr = mgpu .FragmentedArray .load_strided (inp , is_signed = True )
21162112 assert ir .VectorType (arr .registers .flat [0 ].type ).shape == [reg_length ]
2117- arr .astype (mlir_dtype_to ).store_untiled (out )
2113+ arr .astype (mlir_dtype_to ).store_untiled (out , optimized = False )
21182114
21192115 x = jnp .arange (- 128 , 128 , dtype = jax_dtype_from )
21202116 x = jnp .tile (x , reg_length // 2 )
@@ -2190,7 +2186,7 @@ def test_convert_bool_to_u8(self):
21902186 def kernel (ctx , dst , _ ):
21912187 i8 = ir .IntegerType .get_signless (8 )
21922188 iota = iota_tensor (m , n , jnp .uint8 )
2193- (iota > 10 ).astype (i8 , is_signed = False ).store_untiled (dst )
2189+ (iota > 10 ).astype (i8 , is_signed = False ).store_untiled (dst , optimized = False )
21942190
21952191 out_shape = jax .ShapeDtypeStruct ((m , n ), jnp .int8 )
21962192 result = mgpu .as_gpu_kernel (
@@ -2318,7 +2314,7 @@ def kernel(ctx, dst, _):
23182314 )
23192315 self .assertEqual (tiled .shape , shape )
23202316 self .assertEqual (tiled .mlir_dtype , iota .mlir_dtype )
2321- tiled .store_untiled (dst )
2317+ tiled .store_untiled (dst , optimized = False )
23222318 ty = jax .ShapeDtypeStruct (shape , dtype )
23232319 f = mgpu .as_gpu_kernel (kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), ty , ())
23242320 expected = np .arange (math .prod (shape ), dtype = dtype ).reshape (shape )
0 commit comments