Skip to content

Commit 47b5fe8

Browse files
authored
Fix various clippy lints (tracel-ai#3766)
1 parent 3a2a1a8 commit 47b5fe8

File tree

15 files changed

+96
-113
lines changed

15 files changed

+96
-113
lines changed

crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,6 @@ struct Col2ImArgs {
233233
}
234234

235235
#[cube(launch_unchecked)]
236-
#[allow(unknown_lints, reason = "manual_is_multiple_of is from Rust 1.89.0")]
237-
#[expect(
238-
clippy::manual_is_multiple_of,
239-
reason = "cubecl cannot expand is_multiple_of"
240-
)]
241236
fn col2im_kernel<E: Numeric>(
242237
columns: &Tensor<E>,
243238
bias: &Tensor<E>,
@@ -277,7 +272,8 @@ fn col2im_kernel<E: Numeric>(
277272
for col_x in x_col_start..x_col_end {
278273
let kernel_x = im_x - col_x * args.stride_w;
279274

280-
if kernel_y % args.dilation_h == 0 && kernel_x % args.dilation_w == 0 {
275+
if kernel_y.is_multiple_of(args.dilation_h) && kernel_x.is_multiple_of(args.dilation_w)
276+
{
281277
let kernel_y = kernel_y / args.dilation_h;
282278
let kernel_x = kernel_x / args.dilation_w;
283279

crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,6 @@ struct ConvArgs {
2424
}
2525

2626
#[cube(launch)]
27-
#[allow(unknown_lints, reason = "manual_is_multiple_of is from Rust 1.89.0")]
28-
#[expect(
29-
clippy::manual_is_multiple_of,
30-
reason = "cubecl cannot expand is_multiple_of"
31-
)]
3227
fn conv_transpose2d_direct_kernel<E: Numeric>(
3328
input: &Tensor<E>,
3429
weight: &Tensor<E>,
@@ -87,7 +82,7 @@ fn conv_transpose2d_direct_kernel<E: Numeric>(
8782
let numerator_tmp = in_y * args.conv_stride_0;
8883
let numerator_h = numerator_h_base - numerator_tmp;
8984

90-
if numerator_h_base >= numerator_tmp && numerator_h % args.dilation_0 == 0 {
85+
if numerator_h_base >= numerator_tmp && numerator_h.is_multiple_of(args.dilation_0) {
9186
let kernel_y = numerator_h / args.dilation_0;
9287
let idx_input_y = in_y * input.stride(2);
9388
let idx_weight_ky = kernel_y * weight.stride(2);
@@ -96,7 +91,9 @@ fn conv_transpose2d_direct_kernel<E: Numeric>(
9691
let numerator_tmp = in_x * args.conv_stride_1;
9792
let numerator_w = numerator_w_base - numerator_tmp;
9893

99-
if numerator_w_base >= numerator_tmp && numerator_w % args.dilation_1 == 0 {
94+
if numerator_w_base >= numerator_tmp
95+
&& numerator_w.is_multiple_of(args.dilation_1)
96+
{
10097
let kernel_x = numerator_w / args.dilation_1;
10198
let idx_input_x = in_x * input.stride(3);
10299
let idx_weight_kx = kernel_x * weight.stride(3);

crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ struct ConvArgs {
2727
}
2828

2929
#[cube(launch)]
30-
#[allow(unknown_lints, reason = "manual_is_multiple_of is from Rust 1.89.0")]
31-
#[expect(
32-
clippy::manual_is_multiple_of,
33-
reason = "cubecl cannot expand is_multiple_of"
34-
)]
3530
fn conv_transpose3d_kernel<E: Numeric>(
3631
input: &Tensor<E>,
3732
weight: &Tensor<E>,
@@ -98,7 +93,7 @@ fn conv_transpose3d_kernel<E: Numeric>(
9893
let numerator_tmp = in_z * args.conv_stride_0;
9994
let numerator_d = numerator_d_base - numerator_tmp;
10095

101-
if numerator_d_base >= numerator_tmp && numerator_d % args.dilation_0 == 0 {
96+
if numerator_d_base >= numerator_tmp && numerator_d.is_multiple_of(args.dilation_0) {
10297
let kernel_z = numerator_d / args.dilation_0;
10398
let index_input_z = in_z * input.stride(2);
10499
let index_weight_kz = kernel_z * weight.stride(2);
@@ -107,7 +102,9 @@ fn conv_transpose3d_kernel<E: Numeric>(
107102
let numerator_tmp = in_y * args.conv_stride_1;
108103
let numerator_h = numerator_h_base - numerator_tmp;
109104

110-
if numerator_h_base >= numerator_tmp && numerator_h % args.dilation_1 == 0 {
105+
if numerator_h_base >= numerator_tmp
106+
&& numerator_h.is_multiple_of(args.dilation_1)
107+
{
111108
let kernel_y = numerator_h / args.dilation_1;
112109
let index_input_y = in_y * input.stride(3);
113110
let index_weight_ky = kernel_y * weight.stride(3);
@@ -117,7 +114,7 @@ fn conv_transpose3d_kernel<E: Numeric>(
117114
let numerator_w = numerator_w_base - numerator_tmp;
118115

119116
if numerator_w_base >= numerator_tmp
120-
&& numerator_w % args.dilation_2 == 0
117+
&& numerator_w.is_multiple_of(args.dilation_2)
121118
{
122119
let kernel_x = numerator_w / args.dilation_2;
123120
let index_input_x = in_x * input.stride(4);

crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,6 @@ struct DeformConv2dCol2ImgCoordArgs<F: Float> {
270270

271271
#[expect(clippy::collapsible_if)]
272272
#[cube(launch_unchecked)]
273-
#[allow(unknown_lints, reason = "manual_is_multiple_of is from Rust 1.89.0")]
274-
#[expect(
275-
clippy::manual_is_multiple_of,
276-
reason = "cubecl cannot expand is_multiple_of"
277-
)]
278273
fn deform_col2img_coord_kernel<F: Float>(
279274
image: &Tensor<F>,
280275
offset: &Tensor<F>,
@@ -328,7 +323,7 @@ fn deform_col2img_coord_kernel<F: Float>(
328323
let mask_base_idx = (b * n_offset_groups + offset_group) * kernel_h * kernel_w * out_h * out_w;
329324

330325
let offset_c = c - offset_group * 2 * kernel_h * kernel_w;
331-
let is_y_direction = offset_c % 2 == 0;
326+
let is_y_direction = offset_c.is_multiple_of(2);
332327

333328
let c_bound = channels_per_offset_group * kernel_h * kernel_w;
334329

crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,10 @@ fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32
6464
(output_size_index * input_size) / output_size
6565
}
6666

67-
#[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
68-
#[allow(clippy::manual_div_ceil)]
6967
#[cube]
7068
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
7169
let index = (output_size_index + 1) * input_size;
72-
let index = (index + output_size - 1) / output_size;
70+
let index = index.div_ceil(output_size);
7371

7472
if input_size < index {
7573
input_size

crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,10 @@ fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32
7272
(output_size_index * input_size) / output_size
7373
}
7474

75-
#[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
76-
#[allow(clippy::manual_div_ceil)]
7775
#[cube]
7876
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
7977
let index = (output_size_index + 1) * input_size;
80-
let index = (index + output_size - 1) / output_size;
78+
let index = index.div_ceil(output_size);
8179

8280
if input_size < index {
8381
input_size

crates/burn-store/src/filter.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ mod tests {
465465
fn container_predicates() {
466466
// Filter that matches only Linear module weights
467467
let linear_weights = PathFilter::new().with_predicate(|path, container_path| {
468-
container_path.split('.').last() == Some("Linear") && path.ends_with(".weight")
468+
container_path.split('.').next_back() == Some("Linear") && path.ends_with(".weight")
469469
});
470470

471471
assert!(linear_weights.matches_with_container("layer1.weight", "Linear"));
@@ -474,7 +474,7 @@ mod tests {
474474

475475
// Filter for specific container types
476476
let conv_only = PathFilter::new().with_predicate(|_path, container_path| {
477-
let last = container_path.split('.').last();
477+
let last = container_path.split('.').next_back();
478478
last == Some("Conv2d") || last == Some("ConvTranspose2d")
479479
});
480480

@@ -486,7 +486,7 @@ mod tests {
486486
let combined = PathFilter::new()
487487
.with_predicate(|path, _container_path| path.starts_with("encoder."))
488488
.with_predicate(|_path, container_path| {
489-
container_path.split('.').last() == Some("BatchNorm2d")
489+
container_path.split('.').next_back() == Some("BatchNorm2d")
490490
});
491491

492492
// Should match either condition (OR logic)
@@ -503,7 +503,8 @@ mod tests {
503503
let filter = PathFilter::new()
504504
.with_regex(r"^encoder\..*")
505505
.with_predicate(|path, container_path| {
506-
container_path.split('.').last() == Some("Linear") && path.contains(".bias")
506+
container_path.split('.').next_back() == Some("Linear")
507+
&& path.contains(".bias")
507508
});
508509

509510
// Matches due to regex
@@ -607,7 +608,7 @@ mod tests {
607608
// Only weights in Linear layers that are inside blocks
608609
path.ends_with(".weight")
609610
&& container_path.contains("Block")
610-
&& container_path.split('.').last() == Some("Linear")
611+
&& container_path.split('.').next_back() == Some("Linear")
611612
});
612613

613614
assert!(

crates/burn-store/src/safetensors/tests/adapter.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,17 @@ fn pytorch_to_burn_adapter_linear_transpose() {
3434

3535
// Load with PyTorchToBurn adapter (will transpose back)
3636
let mut load_store = SafetensorsStore::from_bytes(None).with_from_adapter(PyTorchToBurnAdapter);
37-
if let SafetensorsStore::Memory(ref mut p) = load_store {
38-
if let SafetensorsStore::Memory(ref p_save) = save_store {
39-
p.set_data(p_save.data().unwrap().as_ref().clone());
40-
}
37+
if let SafetensorsStore::Memory(ref mut p) = load_store
38+
&& let SafetensorsStore::Memory(ref p_save) = save_store
39+
{
40+
p.set_data(p_save.data().unwrap().as_ref().clone());
4141
}
4242

4343
let mut model2 = TestModel::<TestBackend>::new(&device);
4444
let result = model2.apply_from(&mut load_store).unwrap();
4545

4646
// Should successfully load all tensors
47-
assert!(result.applied.len() > 0);
47+
assert!(!result.applied.is_empty());
4848

4949
// Verify the linear weights are the same after round-trip
5050
let weight1 = model.linear.weight.val().to_data();
@@ -93,17 +93,17 @@ fn pytorch_to_burn_adapter_norm_rename() {
9393

9494
// Load with PyTorchToBurn adapter (will rename weight->gamma, bias->beta)
9595
let mut load_store = SafetensorsStore::from_bytes(None).with_from_adapter(PyTorchToBurnAdapter);
96-
if let SafetensorsStore::Memory(ref mut p) = load_store {
97-
if let SafetensorsStore::Memory(ref p_save) = save_store {
98-
p.set_data(p_save.data().unwrap().as_ref().clone());
99-
}
96+
if let SafetensorsStore::Memory(ref mut p) = load_store
97+
&& let SafetensorsStore::Memory(ref p_save) = save_store
98+
{
99+
p.set_data(p_save.data().unwrap().as_ref().clone());
100100
}
101101

102102
let mut model2 = NormModel::<TestBackend>::new(&device);
103103
let result = model2.apply_from(&mut load_store).unwrap();
104104

105105
// Should load successfully
106-
assert!(result.applied.len() > 0);
106+
assert!(!result.applied.is_empty());
107107

108108
// Verify data is preserved
109109
let gamma1 = model.norm_gamma.val().to_data().to_vec::<f32>().unwrap();
@@ -126,17 +126,17 @@ fn no_adapter_preserves_original() {
126126

127127
// Load without adapter
128128
let mut load_store = SafetensorsStore::from_bytes(None);
129-
if let SafetensorsStore::Memory(ref mut p) = load_store {
130-
if let SafetensorsStore::Memory(ref p_save) = save_store {
131-
p.set_data(p_save.data().unwrap().as_ref().clone());
132-
}
129+
if let SafetensorsStore::Memory(ref mut p) = load_store
130+
&& let SafetensorsStore::Memory(ref p_save) = save_store
131+
{
132+
p.set_data(p_save.data().unwrap().as_ref().clone());
133133
}
134134

135135
let mut model2 = TestModel::<TestBackend>::new(&device);
136136
let result = model2.apply_from(&mut load_store).unwrap();
137137

138138
assert!(result.is_success());
139-
assert!(result.applied.len() > 0);
139+
assert!(!result.applied.is_empty());
140140

141141
// Verify data is exactly the same
142142
let weight1 = model.linear.weight.val().to_data();
@@ -187,5 +187,5 @@ fn adapter_with_pytorch_import() {
187187

188188
// Should load some tensors (fc1 if it exists in the file)
189189
// This mainly tests that the adapter works with real PyTorch files
190-
assert!(result.applied.len() > 0 || result.missing.len() > 0);
190+
assert!(!result.applied.is_empty() || !result.missing.is_empty());
191191
}

crates/burn-store/src/safetensors/tests/error_handling.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ fn shape_mismatch_errors() {
2323

2424
// Load without validation - should return errors in the result
2525
let mut load_store = SafetensorsStore::from_bytes(None).validate(false); // Disable validation to get errors in result
26-
if let SafetensorsStore::Memory(ref mut p) = load_store {
27-
if let SafetensorsStore::Memory(ref p_save) = save_store {
28-
// Get Arc and extract data
29-
let data_arc = p_save.data().unwrap();
30-
p.set_data(data_arc.as_ref().clone());
31-
}
26+
if let SafetensorsStore::Memory(ref mut p) = load_store
27+
&& let SafetensorsStore::Memory(ref p_save) = save_store
28+
{
29+
// Get Arc and extract data
30+
let data_arc = p_save.data().unwrap();
31+
p.set_data(data_arc.as_ref().clone());
3232
}
3333

3434
let result = incompatible_module.apply_from(&mut load_store).unwrap();
@@ -38,12 +38,12 @@ fn shape_mismatch_errors() {
3838

3939
// Try again with validation enabled - should return Err
4040
let mut load_store_with_validation = SafetensorsStore::from_bytes(None).validate(true);
41-
if let SafetensorsStore::Memory(ref mut p) = load_store_with_validation {
42-
if let SafetensorsStore::Memory(ref p_save) = save_store {
43-
// Get Arc and extract data
44-
let data_arc = p_save.data().unwrap();
45-
p.set_data(data_arc.as_ref().clone());
46-
}
41+
if let SafetensorsStore::Memory(ref mut p) = load_store_with_validation
42+
&& let SafetensorsStore::Memory(ref p_save) = save_store
43+
{
44+
// Get Arc and extract data
45+
let data_arc = p_save.data().unwrap();
46+
p.set_data(data_arc.as_ref().clone());
4747
}
4848

4949
let validation_result = incompatible_module.apply_from(&mut load_store_with_validation);

crates/burn-store/src/safetensors/tests/filtering.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ fn filtered_export_import() {
1717

1818
// Import filtered tensors - need to allow partial since we only saved encoder tensors
1919
let mut load_store = SafetensorsStore::from_bytes(None).allow_partial(true);
20-
if let SafetensorsStore::Memory(ref mut p) = load_store {
21-
if let SafetensorsStore::Memory(ref p_save) = save_store {
22-
// Get Arc and extract data
23-
let data_arc = p_save.data().unwrap();
24-
p.set_data(data_arc.as_ref().clone());
25-
}
20+
if let SafetensorsStore::Memory(ref mut p) = load_store
21+
&& let SafetensorsStore::Memory(ref p_save) = save_store
22+
{
23+
// Get Arc and extract data
24+
let data_arc = p_save.data().unwrap();
25+
p.set_data(data_arc.as_ref().clone());
2626
}
2727
let result = module2.apply_from(&mut load_store).unwrap();
2828

2929
assert!(result.is_success());
3030
assert_eq!(result.applied.len(), 3); // encoder.weight, encoder.bias, encoder.norm
31-
assert!(result.missing.len() > 0); // decoder and layers tensors are missing
31+
assert!(!result.missing.is_empty()); // decoder and layers tensors are missing
3232
}
3333

3434
#[test]

0 commit comments

Comments
 (0)