Skip to content

Commit b1dbce0

Browse files
authored
Merge pull request #3062 from davenpi/fix/core-basics-example
Fix broken slice_scatter example in basics.rs
2 parents 65055f6 + 5d6407f commit b1dbce0

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

candle-core/examples/basics.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ use candle_core::{Device, Tensor};
99

1010
fn main() -> Result<()> {
1111
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
12-
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
12+
let b = Tensor::new(&[[88.0f32], [99.0]], &Device::Cpu)?;
1313
let new_a = a.slice_scatter(&b, 1, 2)?;
1414
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
15-
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
15+
assert_eq!(
16+
new_a.to_vec2::<f32>()?,
17+
[[0.0, 1.0, 88.0], [3.0, 4.0, 99.0]]
18+
);
1619
Ok(())
1720
}

0 commit comments

Comments
 (0)