@@ -1740,47 +1740,34 @@ def kernel(X, Y, Z, N: tl.constexpr):
17401740
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
@pytest.mark.parametrize("constant_field", ["value", "mask"])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_store_constant(num_ctas, dtype_str, constant_field, device):
    """Check ``tl.store`` when either the stored value or the mask is a constant.

    Two cases are exercised, selected by ``constant_field``:

    * ``"value"``: a block filled with the constant ``1`` is stored under a
      per-element bounds mask; every output element is expected to become 1.
    * ``"mask"``: a computed block is stored under the scalar mask ``False``;
      the zero-initialised output is expected to stay untouched.
    """
    check_type_supported(dtype_str, device)

    @triton.jit
    def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        if CONSTANT_FIELD == "value":
            # Constant value; the block's dtype is taken from the constant itself.
            value = 1
            output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype)
            mask = offsets < n_elements
        elif CONSTANT_FIELD == "mask":
            # Constant (scalar) mask; the stored block is the bounds comparison.
            output = offsets < n_elements
            mask = False
        tl.store(output_ptr + offsets, output, mask=mask)

    block_size = 128
    torch_dtype = getattr(torch, dtype_str)
    ref = torch.ones([block_size], dtype=torch_dtype, device=device)
    output = torch.zeros([block_size], dtype=torch_dtype, device=device)

    kernel[(1, )](
        output,
        block_size,
        BLOCK_SIZE=block_size,
        num_ctas=num_ctas,
        CONSTANT_FIELD=constant_field,
    )

    if constant_field == "value":
        assert torch.all(output == ref)
    else:
        # mask=False: the buffer must still hold its initial zeros.
        assert torch.all(output == 0)
17841771
17851772
17861773def test_load_store_same_ptr (device ):
0 commit comments