
Commit fcd917c

antimora and wingertge authored
Add step support to tensor slice operations (tracel-ai#3748)
* Add step support to tensor slice operations. Extends tensor slicing to support custom step values, including negative steps for reversed slices, via a new SliceInfo struct. Updates all relevant backend, API, and macro code to use SliceInfo for slicing, and adapts documentation and trait signatures to reflect the new functionality. Backends that do not support steps panic if step != 1. This enables advanced slicing patterns such as strided and reversed selection.
* Add support for slicing tensors with step in the tch backend. Refactors tensor slicing logic to handle arbitrary step values, including negative steps, in the tch backend. Updates bool, int, and float tensor ops to use the new slice_with_steps method, removing the previous step=1 restriction and panic.
* Add support for slicing tensors with steps. Introduces the `slice_with_steps` function to enable slicing tensors with arbitrary step sizes, including negative steps, in `base.rs`. Updates `bool_tensor.rs`, `int_tensor.rs`, and `tensor.rs` to use the new slicing logic via `SliceInfo`, allowing more flexible and efficient tensor slicing operations.
* Fix formatting.
* Add step slicing support in burn-cubecl.
* Fix format.
* Refactor tensor slicing to use SliceInfo. Replaces usage of std::ops::Range with burn_tensor::SliceInfo for tensor slicing operations in the qtensor and ring modules. This improves consistency and prepares for more flexible slicing semantics across backends.
* Fix format.
* Fix formatting.
* Refactor tensor slice ops to use SliceInfo. Updates bool, float, int, and quantized tensor slice operations to accept and process SliceInfo instead of Range<usize>. This enables support for slicing with steps and improves consistency across tensor types.
* Refactor slice kernels to use LinearView. Updates the slice and slice_with_steps kernels to operate on LinearView types instead of raw Tensor references, improving memory access safety and consistency, and updates kernel launch calls to pass linear views for both input and output tensors.
* Refactor tensor slice output shape calculation. Centralizes and simplifies output shape calculation for tensor slicing with steps by introducing `calculate_slice_output_shape` in burn-tensor. Updates all relevant tensor ops to use this utility, improving code maintainability and consistency. Adds comprehensive unit tests for the new function.
* Fix no-std.
* Add support for non-unit steps in the ONNX Slice op. Adds support for arbitrary step values (including negative steps) in the ONNX Slice operator, updating the parser, code generation, and tests. Introduces new test cases and models for slicing with steps, removes previous restrictions on steps, and ensures correct handling of steps in both tensor and shape slicing scenarios.
* Fix lint warning.
* Refactor the cubecl slice kernel.
* Fix no-std.
* Refactor tensor slice validation to use SliceInfo.
* Refactor slice_dim to use a SliceInfo parameter. Updates the slice_dim function to accept a SliceInfo parameter instead of a Range, allowing more flexible slicing with custom step values. Adjusts documentation and internal usage to reflect this change.
* Enforce unit step in slice_assign and slice_fill ops. Updates tensor slicing APIs to require step=1 for slice_assign and slice_fill, panicking if non-unit steps are provided. Refactors internal usage to pass SliceInfo instead of Range, updates related ops (cat, repeat_dim), and adds tests verifying panics on unsupported stepped slicing.
* Add an ignored test for matmul with a stepped slice. Introduces a new test, 'should_diff_matmul_with_slice_stepped', to verify autodiff behavior when using stepped slices in matmul operations. The test is currently ignored because slice assignment with steps is not yet supported.
* Refactor tensor slicing to use Slice instead of SliceInfo. Replaces usage of SliceInfo with the new Slice type across all tensor slicing APIs, implementations, and internal logic. Updates function signatures, backend implementations, and related code to support the new Slice structure, improving consistency and flexibility for tensor slicing operations.
* Add a panic for unsupported slice steps in Autodiff. Adds a check in float_slice to panic if any slice step is not 1, as Autodiff does not support slicing with step != 1, plus a test verifying that slicing with a step triggers the expected panic.
* Refactor slice range checks and update tests. Improves slice range validation in TensorCheck by checking the raw end index before conversion, providing clearer error messages. Updates related tests to use the Slice API directly and replaces Slice::new with Slice::from_range_stepped for consistency.
* Rename slice_infos to slices in tensor APIs. Updates documentation and variable names from 'slice_infos' to 'slices' across tensor API traits and convolution module implementations for consistency and clarity. No functional changes.
* Update tensor slicing docs to use 'slices' terminology.
* Refactor tensor slicing to use Slice::from with ranges. Replaces multiple usages of Slice::new with Slice::from and Rust range syntax for tensor slicing in the convolution weight gradient functions, improving readability and leveraging more idiomatic Rust constructs.

---------

Co-authored-by: Genna Wingert <[email protected]>
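For orientation before the per-file diffs, here is a minimal sketch of the per-dimension slice specifications this commit introduces, using the `Slice::new(start, end, step)` constructor exactly as it appears in the hunks below. The negative-step behavior is inferred from the commit description and the candle implementation, not from crate documentation.

```rust
use burn_tensor::Slice;

// NOTE: constructor signature inferred from this commit's diff:
// Slice::new(start: isize, end: Option<isize>, step: isize).
fn example_slices() -> Vec<Slice> {
    vec![
        // Unit step over [0, 4): equivalent to the old Range-based slicing.
        Slice::new(0, Some(4), 1),
        // Step 2 over [0, 4): strided selection of indices 0 and 2.
        Slice::new(0, Some(4), 2),
        // Negative step over [0, 4): reversed selection, walking 3, 2, 1, 0
        // (semantics per the commit description; treat as inferred).
        Slice::new(0, Some(4), -1),
    ]
}
```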
1 parent 8e0474c commit fcd917c

59 files changed: +1929 −446 lines changed


burn-book/src/building-blocks/tensor.md

Lines changed: 1 addition & 2 deletions
@@ -166,7 +166,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
 | `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
 | `tensor.select_assign(dim, indices, values)`| N/A |
 | `tensor.shape()` | `tensor.shape` |
-| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
+| `tensor.slice(s![range;step])` | `tensor[(*ranges,)]` or `tensor[start:end:step]` |
 | `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
 | `tensor.slice_fill(ranges, value)` | `tensor[(*ranges,)] = value` |
 | `tensor.slice_dim(dim, range)` | N/A |
@@ -181,7 +181,6 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
 | `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
 | `tensor.unsqueeze_dims(dims)` | N/A |
 
-
 ### Numeric Operations
 
 Those operations are available for numeric tensor kinds: `Float` and `Int`.
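To make the new table row concrete: a short sketch of the `s!` macro in use, mirroring the stepped-slice patterns from the autodiff tests added later in this commit. The backend parameter `B` and the tensor values are illustrative assumptions.

```rust
use burn_tensor::{Tensor, backend::Backend, s};

fn stepped_slice_demo<B: Backend>(device: &B::Device) {
    // 3x4 tensor used only for illustration.
    let tensor = Tensor::<B, 2>::from_floats(
        [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0]],
        device,
    );

    // Every second row, all columns (NumPy: t[0::2, :]).
    let _rows = tensor.clone().slice(s![0..;2, ..]);

    // All rows, columns 0 and 2 (NumPy: t[:, 0:4:2]).
    let _cols = tensor.slice(s![.., 0..4;2]);
}
```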

crates/burn-autodiff/src/ops/bool_tensor.rs

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
         B::bool_reshape(tensor, shape)
     }
 
-    fn bool_slice(tensor: BoolTensor<B>, ranges: &[core::ops::Range<usize>]) -> BoolTensor<B> {
-        B::bool_slice(tensor, ranges)
+    fn bool_slice(tensor: BoolTensor<B>, slices: &[burn_tensor::Slice]) -> BoolTensor<B> {
+        B::bool_slice(tensor, slices)
     }
 
     fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B> {

crates/burn-autodiff/src/ops/int_tensor.rs

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
         B::int_reshape(tensor, shape)
     }
 
-    fn int_slice(tensor: IntTensor<B>, ranges: &[core::ops::Range<usize>]) -> IntTensor<B> {
-        B::int_slice(tensor, ranges)
+    fn int_slice(tensor: IntTensor<B>, slices: &[burn_tensor::Slice]) -> IntTensor<B> {
+        B::int_slice(tensor, slices)
     }
 
     fn int_empty(

crates/burn-autodiff/src/ops/qtensor.rs

Lines changed: 4 additions & 3 deletions
@@ -1,5 +1,3 @@
-use core::ops::Range;
-
 use burn_tensor::{
     Device, Shape, TensorData,
     backend::Backend,
@@ -84,7 +82,10 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
         unimplemented!()
     }
 
-    fn q_slice(_tensor: QuantizedTensor<Self>, _ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
+    fn q_slice(
+        _tensor: QuantizedTensor<Self>,
+        _slices: &[burn_tensor::Slice],
+    ) -> QuantizedTensor<Self> {
         unimplemented!()
     }
 

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 42 additions & 12 deletions
@@ -1130,24 +1130,31 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         }
     }
 
-    fn float_slice(
-        tensor: FloatTensor<Self>,
-        ranges: &[core::ops::Range<usize>],
-    ) -> FloatTensor<Self> {
+    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_tensor::Slice]) -> FloatTensor<Self> {
+        // Check if any slice has step != 1
+        for (i, slice) in slices.iter().enumerate() {
+            if slice.step != 1 {
+                panic!(
+                    "Autodiff does not support slice with step != 1. Dimension {} has step {}",
+                    i, slice.step
+                );
+            }
+        }
+
         #[derive(Debug)]
         struct Index;
 
         #[derive(new, Debug)]
         struct RetroSlice<B: Backend> {
             tensor_id: NodeID,
-            ranges: Vec<core::ops::Range<usize>>,
+            slices: Vec<burn_tensor::Slice>,
             _backend: PhantomData<B>,
         }
 
         impl<B: Backend> RetroForward for RetroSlice<B> {
             fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
                 let tensor = states.get_state::<B::FloatTensorPrimitive>(&self.tensor_id);
-                let out = B::float_slice(tensor, &self.ranges);
+                let out = B::float_slice(tensor, &self.slices);
                 states.save(out_node, out)
             }
         }
@@ -1170,22 +1177,30 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
             }
         }
 
+        // Convert slices to ranges for backward compatibility in State
+        let shape = tensor.primitive.shape();
+        let ranges: Vec<core::ops::Range<usize>> = slices
+            .iter()
+            .enumerate()
+            .map(|(i, s)| s.to_range(shape.dims[i]))
+            .collect();
+
         match Index
             .prepare::<C>([tensor.node.clone()])
             .memory_bound()
-            .retro_forward(RetroSlice::<B>::new(tensor.node.id, ranges.to_vec()))
+            .retro_forward(RetroSlice::<B>::new(tensor.node.id, slices.to_vec()))
            .parents([&tensor])
            .stateful()
        {
            OpsKind::Tracked(prep) => prep.finish(
                (
-                    ranges.to_vec(),
+                    ranges,
                     tensor.primitive.shape(),
                     B::float_device(&tensor.primitive),
                 ),
-                B::float_slice(tensor.primitive, ranges),
+                B::float_slice(tensor.primitive, slices),
            ),
-            OpsKind::UnTracked(prep) => prep.finish(B::float_slice(tensor.primitive, ranges)),
+            OpsKind::UnTracked(prep) => prep.finish(B::float_slice(tensor.primitive, slices)),
        }
    }
 
@@ -1234,7 +1249,16 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
                 let zeros = B::float_zeros(shape_rhs, &device, grad.dtype().into());
                 B::float_slice_assign(grad, &ranges_4lhs.unwrap(), zeros)
             },
-            |grad| B::float_slice(grad, &ranges_4rhs.unwrap()),
+            |grad| {
+                let slices: Vec<burn_tensor::Slice> = ranges_4rhs
+                    .unwrap()
+                    .iter()
+                    .map(|r| {
+                        burn_tensor::Slice::new(r.start as isize, Some(r.end as isize), 1)
+                    })
+                    .collect();
+                B::float_slice(grad, &slices)
+            },
         );
     }
 }
@@ -2122,7 +2146,13 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
             let mut ranges = ranges.clone();
             ranges[self.dim] = current_index..dim_size + current_index;
             current_index += dim_size;
-            grads.register::<B>(node.id, B::float_slice(grad.clone(), &ranges));
+            let slices: Vec<burn_tensor::Slice> = ranges
+                .iter()
+                .map(|r| {
+                    burn_tensor::Slice::new(r.start as isize, Some(r.end as isize), 1)
+                })
+                .collect();
+            grads.register::<B>(node.id, B::float_slice(grad.clone(), &slices));
         });
     }
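A note on the conversion pattern above: `Slice::new(r.start as isize, Some(r.end as isize), 1)` rebuilds a unit-step slice from a stored `Range<usize>`, and `s.to_range(dim_size)` resolves a slice back against a dimension size. A small illustrative sketch of that round trip, assuming in-bounds ranges; the signatures are taken from the calls in this diff, not from crate documentation.

```rust
use burn_tensor::Slice;

// Rebuild unit-step slices from stored ranges, then resolve them back.
// NOTE: assumes every range lies within its dimension size.
fn roundtrip(ranges: &[core::ops::Range<usize>], dims: &[usize]) {
    let slices: Vec<Slice> = ranges
        .iter()
        .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1))
        .collect();
    for (slice, (&dim, range)) in slices.iter().zip(dims.iter().zip(ranges)) {
        // With step == 1, resolving against the dimension yields the range back.
        assert_eq!(slice.to_range(dim), *range);
    }
}
```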

crates/burn-autodiff/src/tests/slice.rs

Lines changed: 42 additions & 0 deletions
@@ -130,4 +130,46 @@ mod tests {
             .to_data()
             .assert_approx_eq::<FT>(&cat_grad_2.to_data(), Tolerance::default());
     }
+
+    #[test]
+    #[ignore = "slice_assign with steps are not supported currently"]
+    fn should_diff_matmul_with_slice_stepped() {
+        use burn_tensor::s;
+        let data_1 = TensorData::from([[1.0, 7.0], [100.0, 100.0], [2.0, 3.0], [100.0, 100.0]]);
+        let data_2 = TensorData::from([[4.0, 100.0, 7.0, 100.0], [2.0, 100.0, 3.0, 15.0]]);
+
+        let device = Default::default();
+        let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
+        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
+
+        let tensor_3 = tensor_1.clone().slice(s![0..;2, 0..2]); // [[1., 7.], [2., 3.]]
+        let tensor_4 = tensor_2.clone().slice(s![0..2, 0..;2]); // [[4., 7.], [2., 3.]]
+        let tensor_5 = tensor_3.clone().matmul(tensor_4);
+        let grads = tensor_5.backward();
+
+        let grad_1 = tensor_1.grad(&grads).unwrap();
+        let grad_2 = tensor_2.grad(&grads).unwrap();
+
+        grad_1.to_data().assert_eq(
+            &TensorData::from([[11., 5.], [0., 0.], [11., 5.], [0., 0.]]),
+            false,
+        );
+        grad_2.to_data().assert_eq(
+            &TensorData::from([[3., 0., 3., 0.], [10., 0., 10., 0.]]),
+            false,
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "Autodiff does not support slice with step != 1")]
+    fn should_panic_on_slice_with_step() {
+        use burn_tensor::s;
+
+        let data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
+        let device = Default::default();
+        let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
+
+        // This should panic because step is 2
+        let _sliced = tensor.slice(s![.., 0..4; 2]);
+    }
 }

crates/burn-candle/src/ops/base.rs

Lines changed: 60 additions & 0 deletions
@@ -105,6 +105,66 @@ pub fn slice(tensor: CandleTensor, ranges: &[std::ops::Range<usize>]) -> CandleT
     CandleTensor::new(narrow_tensor)
 }
 
+pub fn slice_with_steps(tensor: CandleTensor, slices: &[burn_tensor::Slice]) -> CandleTensor {
+    let mut result_tensor = tensor.tensor;
+
+    for (dim, slice) in slices.iter().enumerate() {
+        if slice.step == 1 {
+            // Use narrow for step=1 (more efficient)
+            // Convert slice to range using tensor shape
+            let dim_size = result_tensor.dim(dim).unwrap();
+            let range = slice.to_range(dim_size);
+            let start = range.start;
+            let length = range.end - range.start;
+            result_tensor = result_tensor.narrow(dim, start, length).unwrap();
+        } else {
+            // Use index_select for step != 1
+            let dim_size = result_tensor.dim(dim).unwrap();
+            let range = slice.to_range(dim_size);
+            let start = range.start;
+            let end = range.end;
+            let step = slice.step;
+
+            // Generate indices based on step direction
+            let indices_vec = if step > 0 {
+                // Forward stepping
+                let step_usize = step as usize;
+                (start..end).step_by(step_usize).collect::<Vec<_>>()
+            } else {
+                // Backward stepping (negative step)
+                let step_usize = step.unsigned_abs();
+                // Start from end-1 and go backwards
+                let mut indices = Vec::new();
+                let mut idx = end - 1;
+                while idx >= start && idx < end {
+                    // Check for underflow
+                    indices.push(idx);
+                    if idx >= step_usize {
+                        idx -= step_usize;
+                    } else {
+                        break;
+                    }
+                }
+                indices
+            };
+
+            // Convert indices to tensor and use index_select
+            let indices_len = indices_vec.len();
+            let device = result_tensor.device();
+            let indices = candle_core::Tensor::from_vec(
+                indices_vec.iter().map(|&x| x as u32).collect::<Vec<_>>(),
+                indices_len,
+                device,
+            )
+            .unwrap();
+
+            result_tensor = result_tensor.index_select(&indices, dim).unwrap();
+        }
+    }
+
+    CandleTensor::new(result_tensor)
+}
+
 pub fn slice_assign(
     tensor: CandleTensor,
     ranges: &[std::ops::Range<usize>],
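The backward-stepping loop in `slice_with_steps` above is the delicate part: it walks from `end - 1` down toward `start` while guarding against `usize` underflow. Here is the same index generation as a standalone sketch; this is a hypothetical helper written for illustration, not a function in the crate.

```rust
/// Indices selected from a resolved window [start, end) with a nonzero step.
/// Hypothetical helper mirroring the loop in slice_with_steps above.
fn stepped_indices(start: usize, end: usize, step: isize) -> Vec<usize> {
    if start >= end {
        return Vec::new(); // empty window; also avoids `end - 1` underflow
    }
    if step > 0 {
        (start..end).step_by(step as usize).collect()
    } else {
        let step = step.unsigned_abs();
        let mut indices = Vec::new();
        let mut idx = end - 1;
        loop {
            indices.push(idx);
            // Stop before stepping below `start` (or underflowing usize).
            if idx >= start + step {
                idx -= step;
            } else {
                break;
            }
        }
        indices
    }
}

// stepped_indices(0, 5, 2)  == vec![0, 2, 4]
// stepped_indices(0, 5, -2) == vec![4, 2, 0]
```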

crates/burn-candle/src/ops/bool_tensor.rs

Lines changed: 2 additions & 2 deletions
@@ -57,8 +57,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
         super::base::reshape(tensor, shape)
     }
 
-    fn bool_slice(tensor: BoolTensor<Self>, ranges: &[std::ops::Range<usize>]) -> BoolTensor<Self> {
-        super::base::slice(tensor, ranges)
+    fn bool_slice(tensor: BoolTensor<Self>, slices: &[burn_tensor::Slice]) -> BoolTensor<Self> {
+        super::base::slice_with_steps(tensor, slices)
     }
 
     fn bool_slice_assign(

crates/burn-candle/src/ops/int_tensor.rs

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
         super::base::reshape(tensor, shape)
     }
 
-    fn int_slice(tensor: IntTensor<Self>, indices: &[std::ops::Range<usize>]) -> IntTensor<Self> {
-        super::base::slice(tensor, indices)
+    fn int_slice(tensor: IntTensor<Self>, slices: &[burn_tensor::Slice]) -> IntTensor<Self> {
+        super::base::slice_with_steps(tensor, slices)
     }
 
     fn int_slice_assign(

crates/burn-candle/src/ops/qtensor.rs

Lines changed: 4 additions & 3 deletions
@@ -1,5 +1,3 @@
-use std::ops::Range;
-
 use burn_tensor::{
     DType, Device, Shape, TensorData,
     backend::Backend,
@@ -80,7 +78,10 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,
         unimplemented!()
     }
 
-    fn q_slice(_tensor: QuantizedTensor<Self>, _ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
+    fn q_slice(
+        _tensor: QuantizedTensor<Self>,
+        _slices: &[burn_tensor::Slice],
+    ) -> QuantizedTensor<Self> {
         unimplemented!()
     }
 
