@@ -3528,32 +3528,29 @@ def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1:
35283528
35293529@pytest .mark .interpreter
35303530@pytest .mark .parametrize ("dtype_str" , ["int32" , "int8" ])
3531- @pytest .mark .parametrize ("shape" , [(2 , 2 , 8 , 64 ), (4 , 4 , 4 , 4 )])
3531+ @pytest .mark .parametrize ("shape" , [(2 , 2 , 8 , 64 ), (4 , 4 , 4 , 16 )])
35323532@pytest .mark .parametrize ("perm" , list (itertools .permutations ([0 , 1 , 2 , 3 ])))
3533- def test_trans_4d (dtype_str , shape , perm , device ):
3533+ def test_trans_4d (dtype_str , shape , perm , device , with_allocator ):
35343534
35353535 @triton .jit
35363536 def kernel (In , Out , #
35373537 in_shape1 : tl .constexpr , in_shape2 : tl .constexpr , in_shape3 : tl .constexpr , in_shape4 : tl .constexpr ,
35383538 ou_shape1 : tl .constexpr , ou_shape2 : tl .constexpr , ou_shape3 : tl .constexpr , ou_shape4 : tl .constexpr ,
35393539 trans1 : tl .constexpr , trans2 : tl .constexpr , trans3 : tl .constexpr , trans4 : tl .constexpr ):
3540- in_ptr = tl .make_block_ptr (
3540+ in_desc = tl .make_tensor_descriptor (
35413541 base = In ,
3542- shape = (in_shape1 , in_shape2 , in_shape3 , in_shape4 ),
3543- strides = (in_shape4 * in_shape3 * in_shape2 , in_shape4 * in_shape3 , in_shape4 , 1 ),
3544- offsets = (0 , 0 , 0 , 0 ),
3545- block_shape = (in_shape1 , in_shape2 , in_shape3 , in_shape4 ),
3546- order = (3 , 2 , 1 , 0 ),
3542+ shape = [in_shape1 , in_shape2 , in_shape3 , in_shape4 ],
3543+ strides = [in_shape4 * in_shape3 * in_shape2 , in_shape4 * in_shape3 , in_shape4 , 1 ],
3544+ block_shape = [in_shape1 , in_shape2 , in_shape3 , in_shape4 ],
35473545 )
3548- out_ptr = tl .make_block_ptr (
3546+ out_desc = tl .make_tensor_descriptor (
35493547 base = Out ,
3550- shape = (ou_shape1 , ou_shape2 , ou_shape3 , ou_shape4 ),
3551- strides = (ou_shape4 * ou_shape3 * ou_shape2 , ou_shape4 * ou_shape3 , ou_shape4 , 1 ),
3552- offsets = (0 , 0 , 0 , 0 ),
3553- block_shape = (ou_shape1 , ou_shape2 , ou_shape3 , ou_shape4 ),
3554- order = (3 , 2 , 1 , 0 ),
3548+ shape = [ou_shape1 * ou_shape2 * ou_shape3 * ou_shape4 ],
3549+ strides = [1 ],
3550+ block_shape = [ou_shape1 * ou_shape2 * ou_shape3 * ou_shape4 ],
35553551 )
3556- tl .store (out_ptr , tl .load (in_ptr ).permute ((trans1 , trans2 , trans3 , trans4 )))
3552+ val = in_desc .load ([0 , 0 , 0 , 0 ]).permute ((trans1 , trans2 , trans3 , trans4 ))
3553+ out_desc .store ([0 ], val .reshape (out_desc .block_shape ))
35573554
35583555 input = torch .arange (math .prod (shape ), dtype = getattr (torch , dtype_str ), device = device ).reshape (shape )
35593556 expected = torch .permute (input , perm )
@@ -5145,7 +5142,7 @@ def kernel(ptr):
51455142 assert "Descriptor block shape must have at least 16 bytes" in str (e .value .__cause__ )
51465143
51475144
5148- def test_trans_reshape (device ):
5145+ def test_trans_reshape (device , with_allocator ):
51495146
51505147 @triton .jit
51515148 def kernel (in_base_ptr , out_base_ptr , IN_SHAPE0 : tl .constexpr , IN_SHAPE1 : tl .constexpr ):
0 commit comments