@@ -1318,18 +1318,21 @@ def kernel(ctx, dst, _):
13181318 operator .ne ,
13191319 ],
13201320 dtype = [jnp .float32 , jnp .int32 , jnp .uint32 ],
1321+ rhs_is_literal = [False , True ]
13211322 )
1322- def test_comparison (self , op , dtype , m = 64 , n = 32 ):
1323+ def test_comparison (self , op , dtype , rhs_is_literal , m = 64 , n = 32 ):
13231324 def kernel (ctx , dst , _ ):
13241325 iota = iota_tensor (m , n , dtype )
1325- op (iota , iota + 1 ).store_untiled (dst )
1326+ rhs = 0 if rhs_is_literal else iota + 1
1327+ op (iota , rhs ).store_untiled (dst )
13261328
13271329 out_shape = jax .ShapeDtypeStruct ((m , n ), jnp .bool )
13281330 result = mgpu .as_gpu_kernel (
13291331 kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), (), out_shape , ()
13301332 )()
13311333 iota = np .arange (m * n , dtype = dtype ).reshape (m , n )
1332- np .testing .assert_array_equal (result , op (iota , iota + 1 ))
1334+ rhs = 0 if rhs_is_literal else iota + 1
1335+ np .testing .assert_array_equal (result , op (iota , rhs ))
13331336
13341337 @parameterized .product (
13351338 op = [operator .and_ , operator .or_ , operator .xor ],
0 commit comments