candle-core/src/layout.rs (4 changes: 2 additions & 2 deletions)
@@ -187,11 +187,11 @@ impl Layout {
         })
     }

-    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
+    pub(crate) fn strided_index(&self) -> crate::StridedIndex<'_> {
         crate::StridedIndex::from_layout(self)
     }

-    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
+    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
         let mut block_len = 1;
         let mut contiguous_dims = 0; // These are counted from the right.
         for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
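
The `<'_>` additions throughout this PR silence the `elided_lifetimes_in_paths` lint (part of `rust_2018_idioms`): the returned iterators borrow from the layout, and the anonymous lifetime makes that borrow visible in the signature. A minimal sketch of the idea, using simplified hypothetical types rather than candle's real definitions:

struct Layout {
    dims: Vec<usize>,
}

struct StridedIndex<'a> {
    dims: &'a [usize],
    pos: usize,
}

impl Layout {
    // Returning plain `StridedIndex` also compiles, but hides the borrow of
    // `self`; `StridedIndex<'_>` states it explicitly, which is what the
    // `elided_lifetimes_in_paths` lint asks for.
    fn strided_index(&self) -> StridedIndex<'_> {
        StridedIndex { dims: &self.dims, pos: 0 }
    }
}
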
candle-core/src/quantized/mod.rs (2 changes: 1 addition & 1 deletion)
@@ -115,7 +115,7 @@ impl QStorage {
         }
     }

-    fn data(&self) -> Result<Cow<[u8]>> {
+    fn data(&self) -> Result<Cow<'_, [u8]>> {
         match self {
             QStorage::Cpu(storage) => {
                 let data_ptr = storage.as_ptr();
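
The same pattern with `Cow` is easy to misread without the lifetime. A sketch with hypothetical types: CPU-resident bytes come back as `Cow::Borrowed`, tied to `&self` through `'_`, while device-resident data has to be copied out and returned as `Cow::Owned`.

use std::borrow::Cow;

enum Storage {
    Cpu(Vec<u8>),
    Device(u64), // stand-in for an accelerator buffer handle
}

impl Storage {
    fn data(&self) -> Cow<'_, [u8]> {
        match self {
            Storage::Cpu(bytes) => Cow::Borrowed(bytes),
            // A real backend would perform a device-to-host copy here.
            Storage::Device(_) => Cow::Owned(vec![0u8; 16]),
        }
    }
}
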
candle-core/src/safetensors.rs (4 changes: 2 additions & 2 deletions)
@@ -57,7 +57,7 @@ impl st::View for Tensor {
         self.shape().dims()
     }

-    fn data(&self) -> Cow<[u8]> {
+    fn data(&self) -> Cow<'_, [u8]> {
         // This copies data from GPU to CPU.
         // TODO: Avoid the unwrap here.
         Cow::Owned(convert_back(self).unwrap())
@@ -78,7 +78,7 @@ impl st::View for &Tensor {
         self.dims()
     }

-    fn data(&self) -> Cow<[u8]> {
+    fn data(&self) -> Cow<'_, [u8]> {
         // This copies data from GPU to CPU.
         // TODO: Avoid the unwrap here.
         Cow::Owned(convert_back(self).unwrap())
candle-core/src/tensor.rs (4 changes: 2 additions & 2 deletions)
@@ -1742,15 +1742,15 @@ impl Tensor {

     /// Returns an iterator over position of the elements in the storage when ranging over the
     /// index tuples in lexicographic order.
-    pub fn strided_index(&self) -> crate::StridedIndex {
+    pub fn strided_index(&self) -> crate::StridedIndex<'_> {
         self.layout.strided_index()
     }

     /// Similar to `strided_index` but returns the position of the start of each contiguous block
     /// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator
     /// will only return the start offset and the size would be the number of elements in the
     /// tensor.
-    pub fn strided_blocks(&self) -> crate::StridedBlocks {
+    pub fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
         self.layout.strided_blocks()
     }

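
The doc comment above describes the contract; the following is a simplified sketch (a hypothetical free function, not candle's exact API) of how trailing contiguous dims collapse into blocks:

fn strided_blocks(dims: &[usize], stride: &[usize]) -> Vec<(usize, usize)> {
    // Count, from the right, how many trailing dims are laid out
    // contiguously; their product is the length of each contiguous run.
    let mut block_len = 1;
    let mut contiguous_dims = 0;
    for (&s, &d) in stride.iter().zip(dims.iter()).rev() {
        if s != block_len {
            break;
        }
        block_len *= d;
        contiguous_dims += 1;
    }
    // Enumerate, in lexicographic index order, the start offset of each run
    // over the remaining (outer) dims.
    let n_outer = dims.len() - contiguous_dims;
    let mut starts = vec![0usize];
    for (&d, &s) in dims[..n_outer].iter().zip(stride[..n_outer].iter()) {
        starts = starts
            .iter()
            .flat_map(|&base| (0..d).map(move |i| base + i * s))
            .collect();
    }
    starts.into_iter().map(|start| (start, block_len)).collect()
}

For a contiguous 2x3 tensor (stride [3, 1]) this yields the single block (0, 6); for its transpose (dims [3, 2], stride [1, 3]) it degenerates to one single-element block per position, which the real iterator avoids materializing eagerly.
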
candle-examples/examples/chinese_clip/main.rs (2 changes: 1 addition & 1 deletion)
@@ -77,7 +77,7 @@ fn main() -> anyhow::Result<()> {
     Ok(())
 }

-pub fn load_weights(model: Option<String>, device: &Device) -> anyhow::Result<nn::VarBuilder> {
+pub fn load_weights(model: Option<String>, device: &Device) -> anyhow::Result<nn::VarBuilder<'_>> {
     let model_file = match model {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
candle-examples/examples/clip/main.rs (5 changes: 3 additions & 2 deletions)
@@ -88,8 +88,9 @@ pub fn main() -> anyhow::Result<()> {
         ],
     };
     let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
-    let vb =
-        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
+    let vb = unsafe {
+        VarBuilder::from_mmaped_safetensors(std::slice::from_ref(&model_file), DType::F32, &device)?
+    };
     let model = clip::ClipModel::new(vb, &config)?;
     let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
     let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
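
`std::slice::from_ref` converts `&T` into a one-element `&[T]`, so the `PathBuf` no longer has to be cloned just to build a temporary slice. A self-contained illustration:

use std::path::PathBuf;

fn takes_slice(paths: &[PathBuf]) -> usize {
    paths.len()
}

fn main() {
    let model_file = PathBuf::from("model.safetensors");
    // Before: takes_slice(&[model_file.clone()]) allocates a copy of the path.
    // After: borrow the single value in place, no clone needed.
    assert_eq!(takes_slice(std::slice::from_ref(&model_file)), 1);
}
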
candle-examples/examples/distilbert/main.rs (2 changes: 1 addition & 1 deletion)
@@ -134,7 +134,7 @@ impl Args {
         Ok((config, tokenizer, weights))
     }

-    fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
+    fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder<'_>> {
         if self.use_pth {
             Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
         } else {
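
`VarBuilder` carries a lifetime because some loaders borrow their backing weight data (e.g. a memory map) while others own it. A simplified sketch of why the parameter surfaces in these signatures, with hypothetical types rather than candle's real definitions:

use std::borrow::Cow;

struct VarBuilder<'a> {
    data: Cow<'a, [u8]>,
}

// An owning loader can hand back a `'static` builder...
fn load_owned(bytes: Vec<u8>) -> VarBuilder<'static> {
    VarBuilder { data: Cow::Owned(bytes) }
}

// ...while a borrowing loader ties the builder to its input, which is why
// `Result<VarBuilder<'_>>` appears in the signatures above.
fn load_borrowed(mapped: &[u8]) -> VarBuilder<'_> {
    VarBuilder { data: Cow::Borrowed(mapped) }
}
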
candle-examples/examples/mobileclip/main.rs (8 changes: 7 additions & 1 deletion)
@@ -99,7 +99,13 @@ pub fn main() -> anyhow::Result<()> {
     let vb = if args.use_pth {
         VarBuilder::from_pth(&model_file, DType::F32, &device)?
     } else {
-        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
+        unsafe {
+            VarBuilder::from_mmaped_safetensors(
+                std::slice::from_ref(&model_file),
+                DType::F32,
+                &device,
+            )?
+        }
     };

     let model = mobileclip::MobileClipModel::new(vb, config)?;
candle-examples/examples/segformer/main.rs (5 changes: 4 additions & 1 deletion)
@@ -56,7 +56,10 @@ enum Commands {
     Classify(ClassificationArgs),
 }

-fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
+fn get_vb_and_config(
+    model_name: String,
+    device: &Device,
+) -> anyhow::Result<(VarBuilder<'_>, Config)> {
     println!("loading model {model_name} via huggingface hub");
     let api = hf_hub::api::sync::Api::new()?;
     let api = api.model(model_name.clone());
candle-examples/examples/siglip/main.rs (5 changes: 3 additions & 2 deletions)
@@ -139,8 +139,9 @@ pub fn main() -> anyhow::Result<()> {
         args.image_size.unwrap_or(config.vision_config.image_size),
     )?
     .to_device(&device)?;
-    let vb =
-        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
+    let vb = unsafe {
+        VarBuilder::from_mmaped_safetensors(std::slice::from_ref(&model_file), DType::F32, &device)?
+    };
     let model = siglip::Model::new(&config, vb)?;
     let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
     let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
candle-examples/examples/yolo-v3/darknet.rs (2 changes: 1 addition & 1 deletion)
@@ -268,7 +268,7 @@ impl Darknet {
         Ok(image_width)
     }

-    pub fn build_model(&self, vb: VarBuilder) -> Result<Func> {
+    pub fn build_model(&self, vb: VarBuilder) -> Result<Func<'_>> {
         let mut blocks: Vec<(usize, Bl)> = vec![];
         let mut prev_channels: usize = 3;
         for (index, block) in self.blocks.iter().enumerate() {
candle-pyo3/src/lib.rs (2 changes: 1 addition & 1 deletion)
@@ -747,7 +747,7 @@ impl PyTensor {

             compare(&self.0, &scalar_tensor)
         } else {
-            return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
+            Err(PyTypeError::new_err("unsupported rhs for __richcmp__"))
         }
     }

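
In tail position the `return` keyword is redundant: the `if`/`else` is the function's final expression, so each arm already is the return value, and clippy's `needless_return` lint flags the explicit form. The same shape without PyO3:

fn compare(supported: bool) -> Result<u32, String> {
    if supported {
        Ok(1)
    } else {
        // `return Err(...);` also works, but the tail expression is idiomatic.
        Err("unsupported rhs for __richcmp__".to_string())
    }
}
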
candle-transformers/src/models/encodec.rs (2 changes: 1 addition & 1 deletion)
@@ -591,7 +591,7 @@ impl<'a> Layer<'a> {
         self.cnt += 1;
     }

-    fn next(&mut self) -> VarBuilder {
+    fn next(&mut self) -> VarBuilder<'_> {
         let vb = self.vb.pp(self.cnt.to_string());
         self.cnt += 1;
         vb
candle-transformers/src/models/xlm_roberta.rs (43 changes: 17 additions & 26 deletions)
@@ -128,34 +128,25 @@ impl XLMRobertaSelfAttention {
     ) -> Result<Tensor> {
         let mixed_query_layer = self.query.forward(hidden_states)?;
         let is_cross_attention = encoder_hidden_states.is_some();
-        let (key_layer, value_layer, attention_mask) = if is_cross_attention
-            && past_key_value.is_some()
-        {
-            let key_layer = past_key_value.unwrap().0.clone();
-            let value_layer = past_key_value.unwrap().1.clone();
-            let attention_mask = encoder_attention_mask.unwrap().clone();
-            (key_layer, value_layer, Some(attention_mask))
-        } else if is_cross_attention {
-            let key_layer =
-                self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
-            let value_layer =
-                self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
-            let attention_mask = encoder_attention_mask.unwrap();
-            (key_layer, value_layer, Some(attention_mask.clone()))
-        } else if past_key_value.is_some() {
+        let (key_layer, value_layer, attention_mask) = if is_cross_attention {
+            if let Some((past_key, past_value)) = past_key_value {
+                let key_layer = past_key.clone();
+                let value_layer = past_value.clone();
+                let attention_mask = encoder_attention_mask.unwrap().clone();
+                (key_layer, value_layer, Some(attention_mask))
+            } else {
+                let key_layer =
+                    self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
+                let value_layer = self
+                    .transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
+                let attention_mask = encoder_attention_mask.unwrap();
+                (key_layer, value_layer, Some(attention_mask.clone()))
+            }
+        } else if let Some((past_key, past_value)) = past_key_value {
             let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
             let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
-            key_layer = Tensor::cat(
-                &[
-                    past_key_value.clone().as_ref().unwrap().0.clone(),
-                    key_layer,
-                ],
-                2,
-            )?;
-            value_layer = Tensor::cat(
-                &[past_key_value.as_ref().unwrap().1.clone(), value_layer],
-                2,
-            )?;
+            key_layer = Tensor::cat(&[past_key.clone(), key_layer], 2)?;
+            value_layer = Tensor::cat(&[past_value.clone(), value_layer], 2)?;
             (key_layer, value_layer, Some(attention_mask.clone()))
         } else {
             let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
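
Binding the pair once with `if let Some((past_key, past_value))` replaces the repeated `.unwrap()` calls, each of which was a separate potential panic site and forced extra cloning of the `Option`'s contents. The shape of the refactor, on plain `Vec`s rather than tensors:

fn extend_cache(
    past: Option<(Vec<f32>, Vec<f32>)>,
    key: Vec<f32>,
    value: Vec<f32>,
) -> (Vec<f32>, Vec<f32>) {
    if let Some((mut past_key, mut past_value)) = past {
        // Analogue of Tensor::cat(&[past_key, key_layer], 2) above.
        past_key.extend(key);
        past_value.extend(value);
        (past_key, past_value)
    } else {
        (key, value)
    }
}
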
candle-wasm-examples/llama2-c/src/worker.rs (2 changes: 1 addition & 1 deletion)
@@ -190,7 +190,7 @@ impl TransformerWeights {
         })
     }

-    fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+    fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'_>> {
         let mut ws = std::collections::HashMap::new();
         let mut insert = |name: &str, t: Tensor| {
             ws.insert(name.to_string(), t);