diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 36a177959a..b1b875b8f9 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2041,7 +2041,7 @@ impl Tensor { self.flatten_(None::, None::) } - /// Returns the sub-tensor fixing the index at `i` on the first dimension. + /// Returns the sub-tensor fixing the index at `index` on the first dimension. /// /// ```rust /// use candle_core::{Tensor, Device}; @@ -2052,12 +2052,12 @@ impl Tensor { /// assert_eq!(t.to_vec1::()?, &[2., 3.]); /// # Ok::<(), candle_core::Error>(()) /// ``` - pub fn get(&self, i: usize) -> Result { + pub fn get(&self, index: usize) -> Result { let dims = self.dims(); if dims.is_empty() { Ok(self.clone()) } else { - self.narrow(0, i, 1)?.reshape(&dims[1..]) + self.narrow(0, index, 1)?.squeeze(0) } }