@@ -17,8 +17,8 @@ def _silu_and_mul_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    stride_input_m = stride_input_m.to(tl.int64)
-    stride_output_m = stride_output_m.to(tl.int64)
+    stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
+    stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)
 
     tid = tl.program_id(0)
     input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)
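Note on the hunk above: Triton keeps offset arithmetic in int32 unless an operand is explicitly widened, so `row * stride` can wrap for large tensors; the cast to `tl.int64` (rewritten here from the `.to()` method to the equivalent `tl.cast(...)` form) guards against that. A minimal NumPy sketch of the failure mode (values are illustrative, not from the PR):

```python
import numpy as np

# 32-bit offset arithmetic wraps once row * stride exceeds 2**31 - 1.
row = np.array([40000], dtype=np.int32)       # hypothetical row index
stride_m = np.array([60000], dtype=np.int32)  # hypothetical row stride
print(row * stride_m)                   # [-1894967296] (wrapped)
print(row.astype(np.int64) * stride_m)  # [2400000000]  (correct in int64)
```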
@@ -53,7 +53,7 @@ def _silu_and_mul_kernel(
     )
 
 
-def silu_and_mul_fwd(input, output):
+def silu_and_mul_fwd(input: torch.Tensor, output):
     stride_input_m = input.stride(0)
     stride_input_n = input.stride(1)
     stride_output_m = output.stride(0)
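The hunk below updates the test for `silu_and_mul_fwd`; its reference, `torch_silu_and_mul`, appears only by signature in this excerpt. A plausible reconstruction (an assumption, but consistent with the `(M, N // 2)` output buffer the test allocates):

```python
import torch

def torch_silu_and_mul(input: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half: SiLU-activate the first half and gate it
    # with the second, yielding the (M, N // 2) shape the test expects.
    x, gate = input.chunk(2, dim=-1)
    return torch.nn.functional.silu(x) * gate
```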
@@ -88,13 +88,13 @@ def torch_silu_and_mul(input: torch.Tensor):
 def test_silu_and_mul(M, N, dtype, device="cuda"):
     # create data
     X = torch.randn((M, N), dtype=dtype, device=device)
-
+    y_tri = torch.empty((M, N // 2), dtype=dtype, device=device)
     # run
-    y_tri = silu_and_mul_fwd(X)
+    silu_and_mul_fwd(X, y_tri)
     y_ref = torch_silu_and_mul(X)
 
     # compare
     print("type:", y_tri.dtype, y_ref.dtype)
     print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
-    assert torch.allclose(y_tri, y_ref, atol=1e-6, rtol=0)
+    assert torch.allclose(y_tri, y_ref, atol=1e-5, rtol=0)
     return
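After this change the test pre-allocates `y_tri` and passes it in, matching the out-parameter signature `silu_and_mul_fwd(input, output)`; the tolerance is also loosened to `1e-5`, presumably because `1e-6` is tighter than half-precision arithmetic reliably meets. An illustrative driver (shapes and dtype are assumptions, not from the PR):

```python
import torch

M, N = 128, 2048
X = torch.randn((M, N), dtype=torch.float16, device="cuda")
out = torch.empty((M, N // 2), dtype=torch.float16, device="cuda")
silu_and_mul_fwd(X, out)  # kernel writes into the caller-provided buffer
```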