@@ -6,23 +6,33 @@ use std::time::Instant;
66
77const B : usize = 1 ;
88const C : usize = 1 ;
9- const M : usize = 128 ;
10- const K : usize = 128 ;
11- const K_SIZE : usize = 3 ;
129
13- fn run ( input : Tensor , weight : Tensor , bias : Tensor , config : Conv2dConfig ) {
14- Conv2d :: new ( weight, Some ( bias) , config)
15- . forward ( & input)
16- . unwrap ( ) ;
10+ fn run ( input : Tensor , weight : Tensor , bias : Option < Tensor > , config : Conv2dConfig ) {
11+ Conv2d :: new ( weight, bias, config) . forward ( & input) . unwrap ( ) ;
1712}
1813
19- fn run_conv2d_benchmark ( c : & mut Criterion , device : & Device , dtype : DType , name : & str ) {
20- let weight = Tensor :: ones ( ( 1 , 1 , K_SIZE , K_SIZE ) , dtype, device)
14+ fn run_conv2d_benchmark (
15+ c : & mut Criterion ,
16+ device : & Device ,
17+ dtype : DType ,
18+ k_size : usize ,
19+ m : usize ,
20+ bias : bool ,
21+ ) {
22+ let weight = Tensor :: ones ( ( 1 , C , k_size, k_size) , dtype, device)
2123 . unwrap ( )
2224 . to_dtype ( dtype)
2325 . unwrap ( ) ;
24- let bias = Tensor :: zeros ( K , dtype, device) . unwrap ( ) ;
25- let input = Tensor :: ones ( ( B , C , M , K ) , dtype, device) . unwrap ( ) ;
26+ let bias_t = if bias {
27+ Some ( Tensor :: zeros ( m, dtype, device) . unwrap ( ) )
28+ } else {
29+ None
30+ } ;
31+ let input = Tensor :: ones ( ( B , C , m, m) , dtype, device) . unwrap ( ) ;
32+ let name = format ! (
33+ "conv2d_{dtype:?}_i{m}_k{k_size}x{k_size}_{}" ,
34+ if bias { "b" } else { "nb" }
35+ ) ;
2636
2737 let mut group = c. benchmark_group ( device. bench_name ( name) ) ;
2838 group. bench_function ( "iter" , move |b| {
@@ -32,7 +42,7 @@ fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
3242 run (
3343 black_box ( input. clone ( ) ) ,
3444 black_box ( weight. clone ( ) ) ,
35- black_box ( bias . clone ( ) ) ,
45+ black_box ( bias_t . clone ( ) ) ,
3646 Default :: default ( ) ,
3747 ) ;
3848 }
@@ -46,8 +56,14 @@ fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
4656fn criterion_benchmark ( c : & mut Criterion ) {
4757 let device = BenchDeviceHandler :: new ( ) . unwrap ( ) ;
4858 for d in device. devices {
49- run_conv2d_benchmark ( c, & d, DType :: F32 , "conv2d_f32" ) ;
50- run_conv2d_benchmark ( c, & d, DType :: F16 , "conv2d_f16" ) ;
59+ run_conv2d_benchmark ( c, & d, DType :: F32 , 3 , 128 , true ) ;
60+ run_conv2d_benchmark ( c, & d, DType :: F32 , 1 , 128 , false ) ;
61+ run_conv2d_benchmark ( c, & d, DType :: F32 , 5 , 128 , false ) ;
62+ run_conv2d_benchmark ( c, & d, DType :: F32 , 3 , 512 , false ) ;
63+ run_conv2d_benchmark ( c, & d, DType :: F16 , 3 , 128 , true ) ;
64+ run_conv2d_benchmark ( c, & d, DType :: F16 , 1 , 128 , false ) ;
65+ run_conv2d_benchmark ( c, & d, DType :: F16 , 5 , 128 , false ) ;
66+ run_conv2d_benchmark ( c, & d, DType :: F16 , 5 , 512 , false ) ;
5167 }
5268}
5369
0 commit comments