Skip to content

Commit 41fa5f1

Browse files
authored
Add more conv2d bench cases to candle-nn benches (#3131)
1 parent bffa5e1 commit 41fa5f1

File tree

1 file changed

+30
-14
lines changed
  • candle-nn/benches/benchmarks

1 file changed

+30
-14
lines changed

candle-nn/benches/benchmarks/conv.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,33 @@ use std::time::Instant;
66

77
const B: usize = 1;
88
const C: usize = 1;
9-
const M: usize = 128;
10-
const K: usize = 128;
11-
const K_SIZE: usize = 3;
129

13-
fn run(input: Tensor, weight: Tensor, bias: Tensor, config: Conv2dConfig) {
14-
Conv2d::new(weight, Some(bias), config)
15-
.forward(&input)
16-
.unwrap();
10+
fn run(input: Tensor, weight: Tensor, bias: Option<Tensor>, config: Conv2dConfig) {
11+
Conv2d::new(weight, bias, config).forward(&input).unwrap();
1712
}
1813

19-
fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
20-
let weight = Tensor::ones((1, 1, K_SIZE, K_SIZE), dtype, device)
14+
fn run_conv2d_benchmark(
15+
c: &mut Criterion,
16+
device: &Device,
17+
dtype: DType,
18+
k_size: usize,
19+
m: usize,
20+
bias: bool,
21+
) {
22+
let weight = Tensor::ones((1, C, k_size, k_size), dtype, device)
2123
.unwrap()
2224
.to_dtype(dtype)
2325
.unwrap();
24-
let bias = Tensor::zeros(K, dtype, device).unwrap();
25-
let input = Tensor::ones((B, C, M, K), dtype, device).unwrap();
26+
let bias_t = if bias {
27+
Some(Tensor::zeros(m, dtype, device).unwrap())
28+
} else {
29+
None
30+
};
31+
let input = Tensor::ones((B, C, m, m), dtype, device).unwrap();
32+
let name = format!(
33+
"conv2d_{dtype:?}_i{m}_k{k_size}x{k_size}_{}",
34+
if bias { "b" } else { "nb" }
35+
);
2636

2737
let mut group = c.benchmark_group(device.bench_name(name));
2838
group.bench_function("iter", move |b| {
@@ -32,7 +42,7 @@ fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
3242
run(
3343
black_box(input.clone()),
3444
black_box(weight.clone()),
35-
black_box(bias.clone()),
45+
black_box(bias_t.clone()),
3646
Default::default(),
3747
);
3848
}
@@ -46,8 +56,14 @@ fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
4656
fn criterion_benchmark(c: &mut Criterion) {
4757
let device = BenchDeviceHandler::new().unwrap();
4858
for d in device.devices {
49-
run_conv2d_benchmark(c, &d, DType::F32, "conv2d_f32");
50-
run_conv2d_benchmark(c, &d, DType::F16, "conv2d_f16");
59+
run_conv2d_benchmark(c, &d, DType::F32, 3, 128, true);
60+
run_conv2d_benchmark(c, &d, DType::F32, 1, 128, false);
61+
run_conv2d_benchmark(c, &d, DType::F32, 5, 128, false);
62+
run_conv2d_benchmark(c, &d, DType::F32, 3, 512, false);
63+
run_conv2d_benchmark(c, &d, DType::F16, 3, 128, true);
64+
run_conv2d_benchmark(c, &d, DType::F16, 1, 128, false);
65+
run_conv2d_benchmark(c, &d, DType::F16, 5, 128, false);
66+
run_conv2d_benchmark(c, &d, DType::F16, 5, 512, false);
5167
}
5268
}
5369

0 commit comments

Comments
 (0)