Skip to content

Commit 91fdc4d

Browse files
committed
torch
1 parent cce2960 commit 91fdc4d

File tree

23 files changed

+147
-55
lines changed

23 files changed

+147
-55
lines changed

crates/burn-autodiff/src/ops/bool_tensor.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,12 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
108108
B::bool_repeat_dim(tensor, dim, times)
109109
}
110110

111-
fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
111+
fn bool_unfold(
112+
tensor: BoolTensor<Self>,
113+
dim: usize,
114+
size: usize,
115+
step: usize,
116+
) -> BoolTensor<Self> {
112117
B::bool_unfold(tensor, dim, size, step)
113118
}
114119
}

crates/burn-autodiff/src/ops/int_tensor.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
378378
B::int_cast(tensor, dtype)
379379
}
380380

381-
fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
381+
fn int_unfold(
382+
tensor: IntTensor<Self>,
383+
dim: usize,
384+
size: usize,
385+
step: usize,
386+
) -> IntTensor<Self> {
382387
B::int_unfold(tensor, dim, size, step)
383388
}
384389
}

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2563,7 +2563,12 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
25632563
// TODO: Implement float_prod and float_sum
25642564
// https://github.com/tracel-ai/burn/issues/1458
25652565

2566-
fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
2566+
fn float_unfold(
2567+
tensor: FloatTensor<Self>,
2568+
dim: usize,
2569+
size: usize,
2570+
step: usize,
2571+
) -> FloatTensor<Self> {
25672572
AutodiffTensor::new(B::float_unfold(tensor.primitive, dim, size, step))
25682573
}
25692574
}

crates/burn-cubecl/src/ops/base.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
use std::cmp::max;
21
use crate::{CubeRuntime, element::CubeElement, kernel, tensor::CubeTensor};
32
use burn_common::tensor::{ReshapeAction, reshape_action};
4-
use burn_tensor::{Shape, TensorData, quantization::{QTensorPrimitive, QuantLevel}};
3+
use burn_tensor::{
4+
Shape, TensorData,
5+
quantization::{QTensorPrimitive, QuantLevel},
6+
};
57
use cubecl::{server::CopyDescriptor, tensor_vectorization_factor};
8+
use std::cmp::max;
69

710
pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
811
let shape: Shape = (&data.shape).into();

crates/burn-cubecl/src/ops/bool_ops.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,12 @@ where
127127
kernel::flip::<R, BT, BT>(tensor, axes)
128128
}
129129

130-
fn bool_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
130+
fn bool_unfold(
131+
tensor: FloatTensor<Self>,
132+
dim: usize,
133+
size: usize,
134+
step: usize,
135+
) -> FloatTensor<Self> {
131136
unfold(tensor, dim, size, step)
132137
}
133138
}

crates/burn-cubecl/src/ops/float_ops.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,12 @@ where
684684
}
685685
}
686686

687-
fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
687+
fn float_unfold(
688+
tensor: FloatTensor<Self>,
689+
dim: usize,
690+
size: usize,
691+
step: usize,
692+
) -> FloatTensor<Self> {
688693
unfold(tensor, dim, size, step)
689694
}
690695
}

crates/burn-cubecl/src/ops/int_ops.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,12 @@ where
662662
)
663663
}
664664

665-
fn int_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
665+
fn int_unfold(
666+
tensor: FloatTensor<Self>,
667+
dim: usize,
668+
size: usize,
669+
step: usize,
670+
) -> FloatTensor<Self> {
666671
unfold(tensor, dim, size, step)
667672
}
668673
}

crates/burn-fusion/src/ops/boolean.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
use std::cmp::max;
2-
use burn_ir::{BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer, InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr};
1+
use burn_ir::{
2+
BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
3+
InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
4+
SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
5+
};
36
use burn_tensor::{
47
Device, Element, Shape, TensorData, TensorMetadata,
58
ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor, binary_ops_shape},
69
};
10+
use std::cmp::max;
711
use std::marker::PhantomData;
812

913
use crate::{
@@ -747,7 +751,12 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
747751
out
748752
}
749753

750-
fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
754+
fn bool_unfold(
755+
tensor: BoolTensor<Self>,
756+
dim: usize,
757+
size: usize,
758+
step: usize,
759+
) -> BoolTensor<Self> {
751760
#[derive(new, Debug)]
752761
struct UnfoldOps<B: FusionBackend> {
753762
desc: UnfoldOpIr,
@@ -757,11 +766,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
757766
impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
758767
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
759768
let input = handles.get_bool_tensor::<B>(&self.desc.input);
760-
let output = B::bool_unfold(
761-
input,
762-
self.desc.dim,
763-
self.desc.size,
764-
self.desc.step);
769+
let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
765770

766771
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
767772
}

crates/burn-fusion/src/ops/float.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::NoOp;
12
use crate::{
23
Fusion, FusionBackend, binary_float_cmp_ops, binary_float_ops,
34
client::FusionClient,
@@ -12,9 +13,8 @@ use burn_tensor::{
1213
Device, Distribution, Element, FloatDType, Shape, TensorData, TensorMetadata,
1314
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor, binary_ops_shape},
1415
};
15-
use std::{marker::PhantomData, ops::Range};
1616
use std::cmp::max;
17-
use super::NoOp;
17+
use std::{marker::PhantomData, ops::Range};
1818

1919
impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
2020
fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
@@ -2265,7 +2265,12 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
22652265
out
22662266
}
22672267

2268-
fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
2268+
fn float_unfold(
2269+
tensor: FloatTensor<Self>,
2270+
dim: usize,
2271+
size: usize,
2272+
step: usize,
2273+
) -> FloatTensor<Self> {
22692274
#[derive(new, Debug)]
22702275
struct UnfoldOps<B: FusionBackend> {
22712276
desc: UnfoldOpIr,
@@ -2275,11 +2280,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
22752280
impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
22762281
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
22772282
let input = handles.get_float_tensor::<B>(&self.desc.input);
2278-
let output = B::float_unfold(
2279-
input,
2280-
self.desc.dim,
2281-
self.desc.size,
2282-
self.desc.step);
2283+
let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
22832284

22842285
handles.register_float_tensor::<B>(&self.desc.out.id, output);
22852286
}

crates/burn-fusion/src/ops/int.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2178,7 +2178,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
21782178
out
21792179
}
21802180

2181-
fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
2181+
fn int_unfold(
2182+
tensor: IntTensor<Self>,
2183+
dim: usize,
2184+
size: usize,
2185+
step: usize,
2186+
) -> IntTensor<Self> {
21822187
#[derive(new, Debug)]
21832188
struct UnfoldOps<B: FusionBackend> {
21842189
desc: UnfoldOpIr,
@@ -2188,11 +2193,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
21882193
impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
21892194
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
21902195
let input = handles.get_int_tensor::<B>(&self.desc.input);
2191-
let output = B::int_unfold(
2192-
input,
2193-
self.desc.dim,
2194-
self.desc.size,
2195-
self.desc.step);
2196+
let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
21962197

21972198
handles.register_int_tensor::<B>(&self.desc.out.id, output);
21982199
}

0 commit comments

Comments
 (0)