@@ -200,8 +200,9 @@ def grid(meta):
200200 torch .manual_seed (0 )
201201 n , c , h , w = 4 , 3 , 224 , 224
202202 k , _ , r , s = 8 , c , 3 , 3
203- input = torch .randn (n , c , h , w , device = "cuda" )
204- filter = torch .randn (k , c , r , s , device = "cuda" )
203+ dtype = torch .float16
204+ input = torch .randn (n , c , h , w , dtype = dtype , device = "cuda" )
205+ filter = torch .randn (k , c , r , s , dtype = dtype , device = "cuda" )
205206 ninetoothed_output = conv2d (input , filter )
206207 torch_output = F .conv2d (input , filter )
207208 triton_output = triton_conv2d (input , filter )
@@ -233,8 +234,9 @@ def grid(meta):
233234 def benchmark (h , w , provider ):
234235 n , c , _ , _ = 64 , 3 , h , w
235236 k , _ , r , s = 64 , c , 3 , 3
236- input = torch .randn ((n , c , h , w ), device = "cuda" )
237- filter = torch .randn ((k , c , r , s ), device = "cuda" )
237+ dtype = torch .float16
238+ input = torch .randn ((n , c , h , w ), dtype = dtype , device = "cuda" )
239+ filter = torch .randn ((k , c , r , s ), dtype = dtype , device = "cuda" )
238240
239241 if provider == "ninetoothed" :
240242 ms = triton .testing .do_bench (lambda : conv2d (input , filter ))
0 commit comments