Commit cce2960

[WIP] towards pytorch.unfold()
1 parent 47b5fe8 commit cce2960

23 files changed: +492 -14 lines changed
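
For orientation before the per-file diffs: the semantics being targeted are those of PyTorch's `Tensor.unfold(dim, size, step)`, which extracts every complete window of length `size` along `dim`, advancing window starts by `step`. A minimal reference sketch of that rule for the 1-D case (the helper name `unfold_1d` is illustrative only, not part of this commit):

    fn unfold_1d<T: Copy>(data: &[T], size: usize, step: usize) -> Vec<Vec<T>> {
        // Number of complete windows: (len - size) / step + 1 when len >= size.
        let windows = if data.len() >= size {
            (data.len() - size) / step + 1
        } else {
            0
        };
        (0..windows)
            .map(|w| data[w * step..w * step + size].to_vec())
            .collect()
    }

    fn main() {
        // Seven elements, size 3, step 2 -> three windows.
        let x: Vec<i32> = (0..7).collect();
        assert_eq!(
            unfold_1d(&x, 3, 2),
            vec![vec![0, 1, 2], vec![2, 3, 4], vec![4, 5, 6]]
        );
    }

The backend implementations below avoid this copy entirely by returning a strided view.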

crates/burn-autodiff/src/ops/bool_tensor.rs

Lines changed: 4 additions & 0 deletions
@@ -107,4 +107,8 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
     fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
         B::bool_repeat_dim(tensor, dim, times)
     }
+
+    fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
+        B::bool_unfold(tensor, dim, size, step)
+    }
 }

crates/burn-autodiff/src/ops/int_tensor.rs

Lines changed: 4 additions & 0 deletions
@@ -377,4 +377,8 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
     fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
         B::int_cast(tensor, dtype)
     }
+
+    fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
+        B::int_unfold(tensor, dim, size, step)
+    }
 }

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 4 additions & 0 deletions
@@ -2562,6 +2562,10 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
 
     // TODO: Implement float_prod and float_sum
     // https://github.com/tracel-ai/burn/issues/1458
+
+    fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
+        AutodiffTensor::new(B::float_unfold(tensor.primitive, dim, size, step))
+    }
 }
 
 #[derive(Debug, Clone)]
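
The autodiff implementations above only forward to the inner backend; no backward step is registered yet, consistent with the [WIP] marker. For reference, a sketch of what the gradient rule would have to do, assuming PyTorch's unfold semantics (not code from this commit): each input element accumulates the gradients of every window slot it appears in, so overlapping windows sum.

    // Hypothetical 1-D gradient accumulation for unfold(0, size, step).
    fn unfold_backward_1d(grad_windows: &[Vec<f32>], len: usize, size: usize, step: usize) -> Vec<f32> {
        let mut grad_input = vec![0.0f32; len];
        for (w, g) in grad_windows.iter().enumerate() {
            for i in 0..size {
                // Window w covers input positions w * step ..= w * step + size - 1;
                // overlapping windows accumulate into the same slot.
                grad_input[w * step + i] += g[i];
            }
        }
        grad_input
    }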

crates/burn-cubecl/src/ops/base.rs

Lines changed: 45 additions & 4 deletions
@@ -1,9 +1,6 @@
 use crate::{CubeRuntime, element::CubeElement, kernel, tensor::CubeTensor};
 use burn_common::tensor::{ReshapeAction, reshape_action};
-use burn_tensor::{
-    Shape, TensorData,
-    quantization::{QTensorPrimitive, QuantLevel},
-};
+use burn_tensor::{Shape, TensorData, quantization::{QTensorPrimitive, QuantLevel}};
 use cubecl::{server::CopyDescriptor, tensor_vectorization_factor};
 
 pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
@@ -213,3 +210,52 @@ pub(crate) fn max_line_size_many<R: CubeRuntime>(tensors: &[&CubeTensor<R>], dim
 
     vec.unwrap_or(0)
 }
+
+/// Unfold windows along a dimension.
+///
+/// Returns a view of the tensor containing all complete windows of size `size` in
+/// dimension `dim`, where consecutive windows advance by `step` elements.
+///
+/// The number of windows is `(shape[dim] - size) / step + 1` when `shape[dim] >= size`,
+/// and `0` otherwise.
+///
+/// # Arguments
+///
+/// * `tensor` - The input tensor to unfold, of shape `[pre..., shape[dim], post...]`.
+/// * `dim` - The dimension to unfold.
+/// * `size` - The size of each unfolded window.
+/// * `step` - The step between the starts of consecutive windows.
+///
+/// # Returns
+///
+/// A tensor view with shape `[pre..., windows, size, post...]`.
+pub fn unfold<R: CubeRuntime>(
+    tensor: CubeTensor<R>,
+    dim: usize,
+    size: usize,
+    step: usize,
+) -> CubeTensor<R> {
+    let d_shape = tensor.shape.dims[dim];
+    let d_stride = tensor.strides[dim];
+
+    // Only complete windows are kept; a window larger than the dimension yields none.
+    let windows = if d_shape >= size {
+        (d_shape - size) / step + 1
+    } else {
+        0
+    };
+
+    let mut shape = tensor.shape.clone();
+    shape.dims[dim] = windows;
+    shape.dims.insert(dim + 1, size);
+
+    let mut strides = tensor.strides.clone();
+    strides[dim] = step * d_stride;
+    strides.insert(dim + 1, d_stride);
+
+    CubeTensor {
+        shape,
+        strides,
+        ..tensor
+    }
+}
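
To sanity-check the view arithmetic in `unfold` above, a standalone sketch (plain Rust, no burn types) that walks a 1-D buffer through the same `[windows, size]` shape/stride pair the function constructs:

    fn main() {
        let data: Vec<i32> = (0..8).collect(); // shape [8], stride [1]
        let (size, step) = (3, 2);

        let d_stride = 1usize;
        let windows = if data.len() >= size {
            (data.len() - size) / step + 1
        } else {
            0
        };

        // The view: shape [windows, size], strides [step * d_stride, d_stride].
        let strides = [step * d_stride, d_stride];
        for w in 0..windows {
            let window: Vec<i32> = (0..size)
                .map(|i| data[w * strides[0] + i * strides[1]])
                .collect();
            println!("window {w}: {window:?}"); // [0,1,2], [2,3,4], [4,5,6]
        }
    }

No data is moved: both axes read through the original buffer, which is exactly why `unfold` can return `CubeTensor { shape, strides, ..tensor }` with everything else unchanged.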

crates/burn-cubecl/src/ops/bool_ops.rs

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,7 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, Device, FloatTensor, IntTensor
 use burn_tensor::{Shape, TensorData};
 use std::ops::Range;
 
-use super::{expand, numeric, permute};
+use super::{expand, numeric, permute, unfold};
 
 impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
 where
@@ -126,4 +126,8 @@ where
     fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
         kernel::flip::<R, BT, BT>(tensor, axes)
     }
+
+    fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
 }

crates/burn-cubecl/src/ops/float_ops.rs

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,4 @@
-use super::{expand, numeric, permute};
+use super::{expand, numeric, permute, unfold};
 use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
 use crate::kernel::unary_basic::BasicFloatUnaryKind;
 use crate::kernel::{
@@ -683,4 +683,8 @@ where
             _ => unimplemented!("Unsupported floating point type cast"),
         }
     }
+
+    fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
 }

crates/burn-cubecl/src/ops/int_ops.rs

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,6 @@
 use self::unary_basic_int::BasicIntUnaryKind;
 
-use super::{expand, numeric, permute};
+use super::{expand, numeric, permute, unfold};
 use crate::{
     CubeBackend, CubeRuntime, FloatElement, IntElement,
     kernel::{
@@ -661,4 +661,8 @@ where
         }
         )
     }
+
+    fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
 }

crates/burn-fusion/src/ops/boolean.rs

Lines changed: 52 additions & 5 deletions
@@ -1,8 +1,8 @@
-use burn_ir::{
-    BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
-    InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
-    SwapDimsOpIr, TensorIr, UnaryOpIr,
-};
+use burn_ir::{
+    BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
+    InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
+    SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
+};
 use burn_tensor::{
     Device, Element, Shape, TensorData, TensorMetadata,
     ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor, binary_ops_shape},
@@ -749,4 +749,55 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
+        #[derive(new, Debug)]
+        struct UnfoldOps<B: FusionBackend> {
+            desc: UnfoldOpIr,
+            _b: PhantomData<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
+            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
+                let input = handles.get_bool_tensor::<B>(&self.desc.input);
+                let output =
+                    B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
+
+                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
+        let mut streams = OperationStreams::default();
+        streams.tensor(&tensor);
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        let windows = if d_shape >= size {
+            (d_shape - size) / step + 1
+        } else {
+            0
+        };
+        shape[dim] = windows;
+        shape.insert(dim + 1, size);
+
+        let out = tensor
+            .client
+            .tensor_uninitialized(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        out.client.register(
+            streams,
+            OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())),
+            UnfoldOps::<B>::new(desc),
+        );
+
+        out
+    }
 }
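
The same output-shape computation recurs in each of the three fusion implementations (here and in float.rs and int.rs below). Factored out as a hypothetical free function (not in the commit) with a couple of worked cases:

    fn unfold_shape(mut shape: Vec<usize>, dim: usize, size: usize, step: usize) -> Vec<usize> {
        let d_shape = shape[dim];
        // The window count replaces the unfolded dimension; the in-window axis
        // of length `size` is inserted right after it.
        shape[dim] = if d_shape >= size {
            (d_shape - size) / step + 1
        } else {
            0
        };
        shape.insert(dim + 1, size);
        shape
    }

    fn main() {
        // Unfolding dim 1 of [2, 10, 4] with size 4, step 3 gives 3 windows: [2, 3, 4, 4].
        assert_eq!(unfold_shape(vec![2, 10, 4], 1, 4, 3), vec![2, 3, 4, 4]);
        // A window larger than the dimension yields zero windows.
        assert_eq!(unfold_shape(vec![5], 0, 6, 1), vec![0, 6]);
    }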

crates/burn-fusion/src/ops/float.rs

Lines changed: 51 additions & 1 deletion
@@ -2264,4 +2264,55 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
+        #[derive(new, Debug)]
+        struct UnfoldOps<B: FusionBackend> {
+            desc: UnfoldOpIr,
+            _b: PhantomData<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
+            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
+                let input = handles.get_float_tensor::<B>(&self.desc.input);
+                let output =
+                    B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
+
+                handles.register_float_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
+        let mut streams = OperationStreams::default();
+        streams.tensor(&tensor);
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        let windows = if d_shape >= size {
+            (d_shape - size) / step + 1
+        } else {
+            0
+        };
+        shape[dim] = windows;
+        shape.insert(dim + 1, size);
+
+        let out = tensor
+            .client
+            .tensor_uninitialized(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        out.client.register(
+            streams,
+            OperationIr::BaseFloat(BaseOperationIr::Unfold(desc.clone())),
+            UnfoldOps::<B>::new(desc),
+        );
+
+        out
+    }
 }

crates/burn-fusion/src/ops/int.rs

Lines changed: 51 additions & 0 deletions
@@ -2176,4 +2176,55 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
+        #[derive(new, Debug)]
+        struct UnfoldOps<B: FusionBackend> {
+            desc: UnfoldOpIr,
+            _b: PhantomData<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
+            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
+                let input = handles.get_int_tensor::<B>(&self.desc.input);
+                let output =
+                    B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
+
+                handles.register_int_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
+        let mut streams = OperationStreams::default();
+        streams.tensor(&tensor);
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        let windows = if d_shape >= size {
+            (d_shape - size) / step + 1
+        } else {
+            0
+        };
+        shape[dim] = windows;
+        shape.insert(dim + 1, size);
+
+        let out = tensor
+            .client
+            .tensor_uninitialized(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        out.client.register(
+            streams,
+            OperationIr::BaseInt(BaseOperationIr::Unfold(desc.clone())),
+            UnfoldOps::<B>::new(desc),
+        );
+
+        out
+    }
 }
