1- from typing import Callable , NamedTuple
1+ from typing import Callable , NamedTuple , Union , Iterable
22from timeit import Timer
33
44import torch
@@ -29,7 +29,11 @@ def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchm
2929
3030
3131def benchmark_conv (
32- signal : Tensor , kernel : Tensor , bias : Tensor , padding : int = 0 , stride : int = 1
32+ signal : Tensor ,
33+ kernel : Tensor ,
34+ bias : Tensor ,
35+ padding : Union [int , Iterable [int ]] = 0 ,
36+ stride : Union [int , Iterable [int ]] = 1 ,
3337):
3438 print (f"Signal size: { signal .shape } " )
3539 print (f"Kernel size: { kernel .shape } " )
@@ -54,24 +58,27 @@ def benchmark_conv(
5458
5559print ("\n --- 1D Convolution ---" )
5660benchmark_conv (
57- signal = torch .randn (3 , 3 , 4096 ),
61+ signal = torch .randn (3 , 3 , 4091 ),
5862 kernel = torch .randn (2 , 3 , 1025 ),
5963 bias = torch .randn (2 ),
6064 padding = 512 ,
65+ stride = 3 ,
6166)
6267
6368print ("\n --- 2D Convolution ---" )
6469benchmark_conv (
65- signal = torch .randn (3 , 3 , 256 , 256 ),
66- kernel = torch .randn (2 , 3 , 21 , 21 ),
70+ signal = torch .randn (3 , 3 , 256 , 235 ),
71+ kernel = torch .randn (2 , 3 , 19 , 21 ),
6772 bias = torch .randn (2 ),
68- padding = 10 ,
73+ padding = (9 , 10 ),
74+ stride = (2 , 3 ),
6975)
7076
7177print ("\n --- 3D Convolution ---" )
7278benchmark_conv (
73- signal = torch .randn (3 , 3 , 64 , 64 , 64 ),
74- kernel = torch .randn (2 , 3 , 9 , 9 , 9 ),
79+ signal = torch .randn (3 , 3 , 64 , 72 , 61 ),
80+ kernel = torch .randn (2 , 3 , 5 , 7 , 9 ),
7581 bias = torch .randn (2 ),
76- padding = 4 ,
82+ padding = (2 , 3 , 4 ),
83+ stride = (1 , 2 , 3 )
7784)
0 commit comments