@@ -2634,6 +2634,80 @@ class CoreMapWGTest(
26342634 ...
26352635
26362636
class PrettyPrintingTest(PallasTest):
  """Builds jaxprs for several kernels and converts each one to a string.

  Every test discards the resulting string; a test passes as long as tracing
  the kernel and stringifying its jaxpr does not raise.
  """

  def test_load(self):
    def kernel(x_ref, o_ref):
      # Plain Python loop, so both row loads are recorded while tracing.
      for row in range(2):
        loaded = plgpu.load(x_ref, (row,))
        o_ref[row, ...] = loaded

    wrapped = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32),
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
    )

    operand = jax.ShapeDtypeStruct((2, 128), jnp.float32)
    _ = str(jax.make_jaxpr(wrapped)(operand))

  def test_copy_primitives(self):
    num_steps = 4

    def kernel_body(_, x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    def kernel(x_gmem, o_gmem):
      # ``plgpu.emit_pipeline`` is implemented in terms of async copy and
      # synchronization primitives.
      pipeline = plgpu.emit_pipeline(
          kernel_body,
          in_specs=[pl.BlockSpec((64, 64), lambda step: (0, step))],
          out_specs=[pl.BlockSpec((64, 64), lambda step: (0, step))],
          grid=(num_steps,),
          max_concurrent_steps=2,
      )
      pipeline(x_gmem, o_gmem)

    wrapped = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32),
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    )

    operand = jax.ShapeDtypeStruct((64, 64), jnp.float32)
    _ = str(jax.make_jaxpr(wrapped)(operand))

  def test_wgmma(self):
    # Tiling and swizzle transforms are only attached under Lane lowering
    # semantics; otherwise the operand specs carry no transforms.
    transforms = (
        (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
        if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane
        else ()
    )

    def kernel(a_ref, b_ref, o_ref):
      def accumulate(acc_ref):
        plgpu.wgmma(acc_ref, a_ref[...], b_ref)
        return acc_ref[...]

      acc_type = plgpu.ACC((64, 192), jnp.float32)
      o_ref[...] = pl.run_scoped(accumulate, acc_type)

    wrapped = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
        in_specs=[
            plgpu.GPUBlockSpec(transforms=transforms),
            plgpu.GPUBlockSpec(transforms=transforms),
        ],
    )

    lhs = jax.ShapeDtypeStruct((64, 128), jnp.float16)
    rhs = jax.ShapeDtypeStruct((128, 192), jnp.float16)
    _ = str(jax.make_jaxpr(wrapped)(lhs, rhs))
2709+
2710+
26372711class ExamplesTest (PallasTest ):
26382712
26392713 # Basic
0 commit comments