Skip to content

Commit 2c7071e

Browse files
committed
chore: update burn to 0.20.0-pre.4 and refactor dims pattern usage
- Bumped `burn` and `burn-import` dependencies to `0.20.0-pre.4`. - Replaced direct tensor references with `.dims()` for consistent shape contract checks. - Updated examples `resnet_finetune` and `swin_tiny` to align with the new `burn` API. - Modified `Cargo.lock` to include updated dependencies and checksums.
1 parent 607fe7b commit 2c7071e

File tree

22 files changed

+1399
-437
lines changed

22 files changed

+1399
-437
lines changed

Cargo.lock

Lines changed: 1305 additions & 363 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ doc_markdown = "warn"
2424

2525

2626
[workspace.dependencies]
27-
burn = "^0.19.1"
28-
burn-import = "^0.19.1"
27+
burn = "^0.20.0-pre.4"
28+
burn-import = "^0.20.0-pre.4"
2929
dirs = "^6.0.0"
3030

3131
bimm-contracts = "^0.19.2"

crates/bimm-firehose-image/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ mod tests {
100100
"name": "data",
101101
"description": "TensorData representation of the image.",
102102
"data_type": {
103-
"type_name": "burn_tensor::tensor::data::TensorData"
103+
"type_name": "burn_backend::data::tensor::TensorData"
104104
}
105105
}
106106
],

crates/bimm/src/layers/blocks/cna.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ impl<B: Backend> CNA2d<B> {
232232
"in_height" = "out_height" * "height_stride",
233233
"in_width" = "out_width" * "width_stride"
234234
],
235-
&input,
235+
&input.dims(),
236236
&["batch", "out_height", "out_width"],
237237
&[
238238
("in_channels", self.in_channels()),
@@ -244,7 +244,7 @@ impl<B: Backend> CNA2d<B> {
244244

245245
assert_shape_contract_periodically!(
246246
["batch", "out_channels", "out_height", "out_width"],
247-
&x,
247+
&x.dims(),
248248
&[
249249
("batch", batch),
250250
("out_channels", self.out_channels()),
@@ -261,7 +261,7 @@ impl<B: Backend> CNA2d<B> {
261261

262262
assert_shape_contract_periodically!(
263263
["batch", "out_channels", "out_height", "out_width"],
264-
&x,
264+
&x.dims(),
265265
&[
266266
("batch", batch),
267267
("out_channels", self.out_channels()),

crates/bimm/src/layers/blocks/conv_norm.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ impl<B: Backend> ConvNorm2d<B> {
128128
"in_height" = "out_height" * "height_stride",
129129
"in_width" = "out_width" * "width_stride"
130130
],
131-
&input,
131+
&input.dims(),
132132
&["batch", "out_height", "out_width"],
133133
&[
134134
("in_channels", self.in_channels()),
@@ -142,7 +142,7 @@ impl<B: Backend> ConvNorm2d<B> {
142142

143143
assert_shape_contract_periodically!(
144144
["batch", "out_channels", "out_height", "out_width"],
145-
&x,
145+
&x.dims(),
146146
&[
147147
("batch", batch),
148148
("out_channels", self.out_channels()),

crates/bimm/src/layers/drop/drop_block.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ fn drop_block_2d_drop_filter_<B: Backend>(
285285
kernel_shape: [usize; 2],
286286
partial_edge_blocks: bool,
287287
) -> Tensor<B, 4> {
288-
let [_, _, h, w] = unpack_shape_contract!(["b", "c", "h", "w"], &selected_blocks);
288+
let [_, _, h, w] = unpack_shape_contract!(["b", "c", "h", "w"], &selected_blocks.dims());
289289
let [kh, kw] = kernel_shape;
290290

291291
assert!(

crates/bimm/src/layers/patching/patch_embed.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ impl<B: Backend> PatchEmbed<B> {
189189
) -> Tensor<B, 3> {
190190
assert_shape_contract_periodically!(
191191
["batch", "d_input", "height", "width"],
192-
&x,
192+
&x.dims(),
193193
&[
194194
("d_input", self.d_input()),
195195
("height", self.input_height()),
@@ -202,7 +202,7 @@ impl<B: Backend> PatchEmbed<B> {
202202
let x = self.projection.forward(x);
203203
assert_shape_contract_periodically!(
204204
["batch", "d_output", "patches_height", "patches_width"],
205-
&x,
205+
&x.dims(),
206206
&[
207207
("batch", batch),
208208
("d_output", self.d_output()),
@@ -215,7 +215,7 @@ impl<B: Backend> PatchEmbed<B> {
215215
let x = x.swap_dims(1, 2);
216216
assert_shape_contract_periodically!(
217217
["batch", "num_patches", "d_output"],
218-
&x,
218+
&x.dims(),
219219
&[
220220
("batch", batch),
221221
("num_patches", self.num_patches()),
@@ -229,7 +229,7 @@ impl<B: Backend> PatchEmbed<B> {
229229
};
230230
assert_shape_contract_periodically!(
231231
["batch", "num_patches", "d_output"],
232-
&x,
232+
&x.dims(),
233233
&[
234234
("batch", batch),
235235
("num_patches", self.num_patches()),

crates/bimm/src/models/resnet/basic_block.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ impl<B: Backend> BasicBlock<B> {
321321
"in_height" = "out_height" * "stride",
322322
"in_width" = "out_width" * "stride"
323323
],
324-
&input,
324+
&input.dims(),
325325
&["batch", "out_height", "out_width"],
326326
&[("in_planes", self.in_planes()), ("stride", self.stride())],
327327
);
@@ -354,7 +354,7 @@ impl<B: Backend> BasicBlock<B> {
354354
#[cfg(debug_assertions)]
355355
bimm_contracts::assert_shape_contract_periodically!(
356356
["batch", "first_planes", "out_height", "out_width"],
357-
&x,
357+
&x.dims(),
358358
&[
359359
("batch", batch),
360360
("first_planes", self.first_planes()),
@@ -377,7 +377,7 @@ impl<B: Backend> BasicBlock<B> {
377377
});
378378

379379
#[cfg(debug_assertions)]
380-
bimm_contracts::assert_shape_contract_periodically!(OUT_CONTRACT, &x, &out_bindings);
380+
bimm_contracts::assert_shape_contract_periodically!(OUT_CONTRACT, &x.dims(), &out_bindings);
381381

382382
x
383383
}
@@ -482,7 +482,7 @@ mod tests {
482482

483483
assert_shape_contract!(
484484
["batch", "out_channels", "out_height", "out_width"],
485-
&output,
485+
&output.dims(),
486486
&[
487487
("batch", batch_size),
488488
("out_channels", out_planes),
@@ -520,7 +520,7 @@ mod tests {
520520

521521
assert_shape_contract!(
522522
["batch", "out_channels", "out_height", "out_width"],
523-
&output,
523+
&output.dims(),
524524
&[
525525
("batch", batch_size),
526526
("out_channels", out_planes),

crates/bimm/src/models/resnet/bottleneck_block.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ impl<B: Backend> BottleneckBlock<B> {
421421
"in_height" = "out_height" * "stride",
422422
"in_width" = "out_width" * "stride"
423423
],
424-
&input,
424+
&input.dims(),
425425
&["batch", "in_height", "out_height", "in_width", "out_width"],
426426
&[("in_planes", self.in_planes()), ("stride", self.stride())],
427427
);
@@ -441,13 +441,17 @@ impl<B: Backend> BottleneckBlock<B> {
441441
("out_height", out_height),
442442
("out_width", out_width),
443443
];
444-
bimm_contracts::assert_shape_contract_periodically!(OUT_CONTRACT, &identity, &out_bindings);
444+
bimm_contracts::assert_shape_contract_periodically!(
445+
OUT_CONTRACT,
446+
&identity.dims(),
447+
&out_bindings
448+
);
445449

446450
let x = self.cna1.forward(input);
447451

448452
bimm_contracts::assert_shape_contract_periodically!(
449453
["batch", "pinch_planes", "in_height", "in_width"],
450-
&x,
454+
&x.dims(),
451455
&[
452456
("batch", batch),
453457
("pinch_planes", self.planes()),
@@ -463,7 +467,7 @@ impl<B: Backend> BottleneckBlock<B> {
463467

464468
bimm_contracts::assert_shape_contract_periodically!(
465469
["batch", "width", "out_height", "out_width"],
466-
&x,
470+
&x.dims(),
467471
&[
468472
("batch", batch),
469473
("width", self.width()),
@@ -475,7 +479,11 @@ impl<B: Backend> BottleneckBlock<B> {
475479
// TODO: anti-aliasing
476480

477481
self.cna3.map_forward(x, |x| {
478-
bimm_contracts::assert_shape_contract_periodically!(OUT_CONTRACT, &x, &out_bindings);
482+
bimm_contracts::assert_shape_contract_periodically!(
483+
OUT_CONTRACT,
484+
&x.dims(),
485+
&out_bindings
486+
);
479487

480488
// TODO: attention
481489

@@ -591,7 +599,7 @@ mod tests {
591599

592600
assert_shape_contract!(
593601
["batch", "out_channels", "out_height", "out_width"],
594-
&output,
602+
&output.dims(),
595603
&[
596604
("batch", batch_size),
597605
("out_channels", out_planes),
@@ -631,7 +639,7 @@ mod tests {
631639

632640
assert_shape_contract!(
633641
["batch", "out_channels", "out_height", "out_width"],
634-
&output,
642+
&output.dims(),
635643
&[
636644
("batch", batch_size),
637645
("out_channels", out_planes),

crates/bimm/src/models/resnet/downsample.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ impl<B: Backend> ResNetDownsample<B> {
208208
) -> Tensor<B, 4> {
209209
let [batch, in_height, in_width] = unpack_shape_contract!(
210210
["batch", "in_channels", "in_height", "in_width",],
211-
&input,
211+
&input.dims(),
212212
&["batch", "in_height", "in_width"],
213213
&[("in_channels", self.in_channels()),]
214214
);
@@ -220,7 +220,7 @@ impl<B: Backend> ResNetDownsample<B> {
220220

221221
assert_shape_contract_periodically!(
222222
["batch", "out_channels", "out_height", "out_width"],
223-
&out,
223+
&out.dims(),
224224
&[
225225
("batch", batch),
226226
("out_channels", self.out_channels()),
@@ -275,7 +275,7 @@ mod tests {
275275

276276
assert_shape_contract!(
277277
["batch", "out_channels", "out_height", "out_width"],
278-
&out,
278+
&out.dims(),
279279
&[
280280
("batch", batch_size),
281281
("out_channels", out_channels),

0 commit comments

Comments
 (0)