Skip to content

Commit da5498c

Browse files
Added GradStore::insert_id(id, grad)
1 parent 42bd33e commit da5498c

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

candle-core/src/backprop.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,11 @@ impl GradStore {
754754
self.0.insert(tensor.id(), grad)
755755
}
756756

757+
/// Insert a gradient tensor associated with the given tensor id, returning the previous gradient tensor if it existed
758+
pub fn insert_id(&mut self, id: TensorId, grad: Tensor) -> Option<Tensor> {
759+
self.0.insert(id, grad)
760+
}
761+
757762
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
758763
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
759764
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {

0 commit comments

Comments
 (0)