@@ -662,56 +662,6 @@ def kernel(ctx, inp, out, smem):
662662 f = mgpu .as_gpu_kernel (kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), x , y , (x , y ))
663663 np .testing .assert_array_equal (f (x ), y )
664664
665- @parameterized .parameters (
666- (jnp .float32 , jnp .float8_e4m3fn ),
667- (jnp .bfloat16 , jnp .float8_e4m3fn )
668- )
669- def test_f8_conversions (self , jax_dtype_from , jax_dtype_to ):
670- mlir_dtype_to = utils .dtype_to_ir_type (jax_dtype_to )
671- def kernel (ctx , inp , out , smem ):
672- del ctx
673- smem_from , smem_to = smem
674- copy (inp , smem_from , swizzle = 128 )
675- t = mgpu .FragmentedArray .load_tiled (
676- smem_from ,
677- swizzle = 128 ,
678- is_signed = None ,
679- layout = fa .WGMMA_LAYOUT ,
680- )
681- t = t .astype (mlir_dtype_to , is_signed = utils .is_signed (jax_dtype_to ))
682- t .store_tiled (smem_to , swizzle = 128 )
683- copy (smem_to , out , swizzle = 128 )
684-
685- # These generative shenanigans are to ensure that we don't generate values
686- # that are too large for the target type. That is because the saturation
687- # behavior of the conversion is different between XLA and Mosaic GPU here
688- # (to use the NVIDIA internal, we allow Mosaic GPU to use the .satfinite
689- # modifier, which saturates to the largest finite value---while XLA would
690- # give us NaNs in this case).
691- max_finite_val = 0b111_1110
692-
693- expected = jax .lax .bitcast_convert_type (
694- jax .random .randint (
695- jax .random .key (42 ),
696- (1 , 1 , 64 , 128 ),
697- - max_finite_val ,
698- max_finite_val + 1 ,
699- dtype = jnp .uint8 ,
700- ),
701- jax_dtype_to ,
702- )
703- x = expected .astype (jax_dtype_from )
704-
705- res = mgpu .as_gpu_kernel (
706- kernel ,
707- (1 , 1 , 1 ),
708- (128 , 1 , 1 ),
709- x ,
710- expected ,
711- (x , expected ),
712- )(x )
713- np .testing .assert_array_equal (res , expected )
714-
715665 @parameterized .product (
716666 jax_dtype_from_to = (
717667 (jnp .int8 , jnp .bfloat16 ),
@@ -3473,30 +3423,33 @@ def kernel(ctx, dst, _):
34733423 np .testing .assert_array_equal (result , op (iota , rhs ).astype (jnp .int8 ))
34743424
34753425 @parameterized .product (
3476- # TODO(apaszke): Add float16, float8_e5m2
3477- jax_dtype_from = (jnp .float32 , jnp .bfloat16 , jnp .float8_e4m3fn , jnp .float8_e8m0fnu ),
3478- jax_dtype_to = (jnp .float32 , jnp .bfloat16 , jnp .float8_e4m3fn , jnp .float8_e8m0fnu ),
3479- # Test different vector lengths.
3426+ # TODO(apaszke): Add float16
3427+ jax_dtype_from = (jnp .float32 , jnp .bfloat16 , jnp .float8_e5m2 , jnp .float8_e4m3fn , jnp .float8_e8m0fnu ),
3428+ jax_dtype_to = (jnp .float32 , jnp .bfloat16 , jnp .float8_e5m2 , jnp .float8_e4m3fn , jnp .float8_e8m0fnu ),
34803429 vec_len = (1 , 2 , 4 , 8 ),
34813430 )
34823431 def test_conversion_f8_ (self , jax_dtype_from , jax_dtype_to , vec_len ):
34833432 from_bitwidth = jnp .finfo (jax_dtype_from ).bits
34843433 to_bitwidth = jnp .finfo (jax_dtype_to ).bits
34853434 if from_bitwidth > 8 and to_bitwidth > 8 :
34863435 self .skipTest ("At least one of the types should be 8-bit" )
3487- if from_bitwidth == to_bitwidth == 8 :
3488- self .skipTest ("f8 <-> f8 conversions unimplemented " )
3436+ if jax_dtype_from == jax_dtype_to :
3437+ self .skipTest ("Identical types, so nothing to test " )
34893438 if jnp .float8_e8m0fnu in {
34903439 jax_dtype_from ,
34913440 jax_dtype_to ,
34923441 } and not jtu .is_cuda_compute_capability_at_least ("10.0" ):
34933442 self .skipTest ("f8e8m0fnu not supported on pre-Blackwell GPUs" )
3494- unimplemented = [
3495- (jnp .float8_e4m3fn , jnp .bfloat16 ),
3496- (jnp .float8_e4m3fn , jnp .float32 ),
3497- (jnp .bfloat16 , jnp .float8_e8m0fnu ),
3498- ]
3499- if (jax_dtype_from , jax_dtype_to ) in unimplemented :
3443+ if from_bitwidth == to_bitwidth == 8 and {jax_dtype_from , jax_dtype_to } != {
3444+ jnp .float8_e4m3fn , jnp .float8_e5m2 ,
3445+ }:
3446+ self .skipTest ("An unimplemented f8 <-> f8 conversion" )
3447+ unimplemented = {
3448+ frozenset ((jnp .float8_e4m3fn , jnp .bfloat16 )),
3449+ frozenset ((jnp .float8_e5m2 , jnp .bfloat16 )),
3450+ frozenset ((jnp .float8_e8m0fnu , jnp .float16 )),
3451+ }
3452+ if {jax_dtype_from , jax_dtype_to } in unimplemented :
35003453 self .skipTest ("Unimplemented" )
35013454 layout = fa .tmem_native_layout (vec_len )
35023455 mlir_dtype_to = utils .dtype_to_ir_type (jax_dtype_to )
@@ -3516,7 +3469,10 @@ def kernel(ctx, inp, out, smem):
35163469 bits = self .prng .integers (
35173470 low = sample_iinfo .min , high = sample_iinfo .max , size = (m , n ), dtype = np .int32
35183471 ).astype (int_sample_dtype )
3519- values = jax .lax .bitcast_convert_type (bits , narrow_type ).astype (jax_dtype_from )
3472+ values = jax .lax .bitcast_convert_type (bits , narrow_type )
3473+ # A bunch of conversions are only supported for finite values.
3474+ values = values .at [jnp .isinf (values )].set (jnp .finfo (narrow_type ).max )
3475+ values = values .astype (jax_dtype_from )
35203476
35213477 expected = values .astype (jax_dtype_to )
35223478 res = mgpu .as_gpu_kernel (
@@ -3860,44 +3816,44 @@ def kernel(ctx, dst, _):
38603816 )
38613817 @jtu .thread_unsafe_test ()
38623818 def test_max (self , vec_size , dtype ):
3863- def kernel (ctx , src , src2 , dst , _ ):
3864- is_signed = utils .is_signed (dtype )
3865- src = fa .FragmentedArray .load_strided (src , vec_size = vec_size , is_signed = is_signed )
3866- src2 = fa .FragmentedArray .load_strided (src2 , vec_size = vec_size , is_signed = is_signed )
3867- src .max (src2 ).store_untiled (dst )
3868- x = self .prng .uniform (- 1 , 1 , (12 * 128 ,)).astype (dtype )
3869- y = self .prng .uniform (- 1 , 1 , (12 * 128 ,)).astype (dtype )
3870- f = mgpu .as_gpu_kernel (
3819+ def kernel (ctx , src , src2 , dst , _ ):
3820+ is_signed = utils .is_signed (dtype )
3821+ src = fa .FragmentedArray .load_strided (src , vec_size = vec_size , is_signed = is_signed )
3822+ src2 = fa .FragmentedArray .load_strided (src2 , vec_size = vec_size , is_signed = is_signed )
3823+ src .max (src2 ).store_untiled (dst )
3824+ x = self .prng .uniform (- 1 , 1 , (12 * 128 ,)).astype (dtype )
3825+ y = self .prng .uniform (- 1 , 1 , (12 * 128 ,)).astype (dtype )
3826+ f = mgpu .as_gpu_kernel (
38713827 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (x , y ), x , ()
38723828 )
3873- with jtu .set_env (MOSAIC_GPU_DUMP_PTX = "1" ), self .capture_stdout () as ptx :
3874- z = f (x , y ).block_until_ready ()
3875- if dtype == jnp .float32 :
3876- dtype_short = "f32"
3877- elif dtype == jnp .float16 :
3878- dtype_short = "f16"
3879- elif dtype == jnp .bfloat16 :
3880- dtype_short = "bf16"
3881- elif jnp .issubdtype (dtype , jnp .signedinteger ):
3882- dtype_short = f"s{ dtypes .itemsize_bits (dtype )} "
3883- elif jnp .issubdtype (dtype , jnp .unsignedinteger ):
3884- dtype_short = f"u{ dtypes .itemsize_bits (dtype )} "
3885- else :
3886- raise NotImplementedError (f"Unsupported dtype: { dtype } " )
3887- ptx = ptx ()
3888- nan_modifier = ".NaN" if jnp .issubdtype (dtype , jnp .floating ) else ""
3889- instr = f"max{ nan_modifier } .{ dtype_short } "
3890- instr_double = f"max{ nan_modifier } .{ dtype_short } x2 "
3891- single_converts = ptx .count (instr )
3892- double_converts = ptx .count (instr_double )
3893- self .assertEqual (128 * (single_converts + 2 * double_converts ), 12 * 128 )
3894- if vec_size % 2 :
3895- self .assertGreater (single_converts , 0 )
3896- elif dtypes .itemsize_bits (dtype ) < 32 :
3897- # This, together with the assertion above, implies that all converts
3898- # happened through doubled operations.
3899- self .assertEqual (single_converts , 0 )
3900- np .testing .assert_array_equal (z , np .maximum (x , y ))
3829+ with jtu .set_env (MOSAIC_GPU_DUMP_PTX = "1" ), self .capture_stdout () as ptx :
3830+ z = f (x , y ).block_until_ready ()
3831+ if dtype == jnp .float32 :
3832+ dtype_short = "f32"
3833+ elif dtype == jnp .float16 :
3834+ dtype_short = "f16"
3835+ elif dtype == jnp .bfloat16 :
3836+ dtype_short = "bf16"
3837+ elif jnp .issubdtype (dtype , jnp .signedinteger ):
3838+ dtype_short = f"s{ dtypes .itemsize_bits (dtype )} "
3839+ elif jnp .issubdtype (dtype , jnp .unsignedinteger ):
3840+ dtype_short = f"u{ dtypes .itemsize_bits (dtype )} "
3841+ else :
3842+ raise NotImplementedError (f"Unsupported dtype: { dtype } " )
3843+ ptx = ptx ()
3844+ nan_modifier = ".NaN" if jnp .issubdtype (dtype , jnp .floating ) else ""
3845+ instr = f"max{ nan_modifier } .{ dtype_short } "
3846+ instr_double = f"max{ nan_modifier } .{ dtype_short } x2 "
3847+ single_converts = ptx .count (instr )
3848+ double_converts = ptx .count (instr_double )
3849+ self .assertEqual (128 * (single_converts + 2 * double_converts ), 12 * 128 )
3850+ if vec_size % 2 :
3851+ self .assertGreater (single_converts , 0 )
3852+ elif dtypes .itemsize_bits (dtype ) < 32 :
3853+ # This, together with the assertion above, implies that all converts
3854+ # happened through doubled operations.
3855+ self .assertEqual (single_converts , 0 )
3856+ np .testing .assert_array_equal (z , np .maximum (x , y ))
39013857
39023858 def test_splat_layout (self ):
39033859 m , n = 64 , 8
0 commit comments