@@ -631,6 +631,174 @@ impl ScatteredCacheBuilder {
631631 }
632632}
633633
634+ /// KV-Cache using concatenation for append operations
635+ ///
636+ /// This implementation uses `Tensor::cat` instead of `slice_set` for updates,
637+ /// providing significant GPU performance improvements for autoregressive generation.
638+ ///
639+ /// # When to Use
640+ ///
641+ /// **Recommended for:**
642+ /// - GPU inference (CUDA, Metal)
643+ /// - Autoregressive generation (token-by-token decoding)
644+ ///
645+ /// **Use `KvCache` instead for:**
646+ /// - CPU-only inference
647+ /// - When you need fixed memory allocation upfront
648+ ///
649+ /// # Example
650+ ///
651+ /// ```ignore
652+ /// use candle_nn::kv_cache::ConcatKvCache;
653+ ///
654+ /// let mut cache = ConcatKvCache::new(2); // dim=2 for sequence dimension
655+ ///
656+ /// // First token (prefill)
657+ /// let k1 = Tensor::randn(0f32, 1., (1, 8, 10, 64), &device)?;
658+ /// let v1 = Tensor::randn(0f32, 1., (1, 8, 10, 64), &device)?;
659+ /// let (k, v) = cache.append(&k1, &v1)?;
660+ ///
661+ /// // Subsequent tokens (decode)
662+ /// let k_new = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
663+ /// let v_new = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
664+ /// let (k, v) = cache.append(&k_new, &v_new)?;
665+ /// ```
666+ #[ derive( Debug , Clone ) ]
667+ pub struct ConcatKvCache {
668+ k : Option < Tensor > ,
669+ v : Option < Tensor > ,
670+ dim : usize ,
671+ }
672+
673+ impl ConcatKvCache {
674+ /// Create a new empty concatenation-based KV-cache
675+ ///
676+ /// # Arguments
677+ /// * `dim` - The dimension along which to concatenate
678+ /// - For attention with shape `[batch, heads, seq, head_dim]`, use `dim=2`
679+ /// - For attention with shape `[batch, seq, heads, head_dim]`, use `dim=1`
680+ ///
681+ /// # Example
682+ /// ```ignore
683+ /// // For standard transformer attention: [B, H, S, D]
684+ /// let cache = ConcatKvCache::new(2);
685+ /// ```
686+ pub fn new ( dim : usize ) -> Self {
687+ Self {
688+ k : None ,
689+ v : None ,
690+ dim,
691+ }
692+ }
693+
694+ /// Get current sequence length in the cache
695+ ///
696+ /// Returns 0 if the cache is empty.
697+ pub fn current_seq_len ( & self ) -> usize {
698+ self . k
699+ . as_ref ( )
700+ . and_then ( |k| k. dims ( ) . get ( self . dim ) . copied ( ) )
701+ . unwrap_or ( 0 )
702+ }
703+
704+ /// Check if cache is empty
705+ pub fn is_empty ( & self ) -> bool {
706+ self . k . is_none ( )
707+ }
708+
709+ /// Get the concatenation dimension
710+ pub fn dim ( & self ) -> usize {
711+ self . dim
712+ }
713+
714+ /// Append key and value tensors to the cache
715+ ///
716+ /// This is the core operation that uses optimized concatenation kernels.
717+ ///
718+ /// # Arguments
719+ /// * `k` - Key tensor to append (shape: [..., seq_len, ...])
720+ /// * `v` - Value tensor to append (shape: [..., seq_len, ...])
721+ ///
722+ /// # Returns
723+ /// Tuple of `(full_k, full_v)` containing all cached keys and values,
724+ /// including the newly appended data.
725+ pub fn append ( & mut self , k : & Tensor , v : & Tensor ) -> Result < ( Tensor , Tensor ) > {
726+ // Ensure inputs are contiguous for optimal concatenation performance
727+ let k = k. contiguous ( ) ?;
728+ let v = v. contiguous ( ) ?;
729+ // Update K cache using concatenation
730+ self . k = Some ( match & self . k {
731+ None => k. clone ( ) ,
732+ Some ( k_cache) => {
733+ // Concatenate along the sequence dimension
734+ // GPU kernel for cat is highly optimized:
735+ // - Fused allocation + copy
736+ // - Coalesced memory access
737+ // - Single kernel launch
738+ Tensor :: cat ( & [ k_cache, & k] , self . dim ) ?
739+ }
740+ } ) ;
741+
742+ // Update V cache using concatenation
743+ self . v = Some ( match & self . v {
744+ None => v. clone ( ) ,
745+ Some ( v_cache) => Tensor :: cat ( & [ v_cache, & v] , self . dim ) ?,
746+ } ) ;
747+
748+ Ok ( (
749+ self . k . as_ref ( ) . unwrap ( ) . clone ( ) ,
750+ self . v . as_ref ( ) . unwrap ( ) . clone ( ) ,
751+ ) )
752+ }
753+
754+ /// Reset the cache (clear all stored keys and values)
755+ ///
756+ /// After calling this, `is_empty()` will return `true` and
757+ /// `current_seq_len()` will return 0.
758+ pub fn reset ( & mut self ) {
759+ self . k = None ;
760+ self . v = None ;
761+ }
762+
763+ /// Get reference to current K cache data
764+ ///
765+ /// Returns `None` if the cache is empty.
766+ pub fn k ( & self ) -> Option < & Tensor > {
767+ self . k . as_ref ( )
768+ }
769+
770+ /// Get reference to current V cache data
771+ ///
772+ /// Returns `None` if the cache is empty.
773+ pub fn v ( & self ) -> Option < & Tensor > {
774+ self . v . as_ref ( )
775+ }
776+
777+ /// Get mutable reference to K cache data
778+ ///
779+ /// Returns `None` if the cache is empty.
780+ pub fn k_mut ( & mut self ) -> Option < & mut Tensor > {
781+ self . k . as_mut ( )
782+ }
783+
784+ /// Get mutable reference to V cache data
785+ ///
786+ /// Returns `None` if the cache is empty.
787+ pub fn v_mut ( & mut self ) -> Option < & mut Tensor > {
788+ self . v . as_mut ( )
789+ }
790+
791+ /// Get owned K and V tensors, consuming the cache
792+ ///
793+ /// Returns `None` if the cache is empty.
794+ pub fn into_inner ( self ) -> Option < ( Tensor , Tensor ) > {
795+ match ( self . k , self . v ) {
796+ ( Some ( k) , Some ( v) ) => Some ( ( k, v) ) ,
797+ _ => None ,
798+ }
799+ }
800+ }
801+
634802#[ cfg( test) ]
635803mod tests {
636804 use super :: * ;
@@ -717,4 +885,102 @@ mod tests {
717885
718886 Ok ( ( ) )
719887 }
888+
/// Two appends along dim 2: the returned tensors and the reported sequence
/// length must both grow by the appended amount.
#[test]
fn test_concat_cache_basic() -> Result<()> {
    let device = Device::Cpu;
    // Zero-filled [1, 8, seq, 64] tensor on the test device.
    let zeros = |seq: usize| Tensor::zeros((1, 8, seq, 64), DType::F32, &device);
    let mut cache = ConcatKvCache::new(2);

    // A freshly built cache holds nothing.
    assert!(cache.is_empty());
    assert_eq!(cache.current_seq_len(), 0);

    // First append: three positions.
    let first_k = zeros(3)?;
    let first_v = zeros(3)?;
    let (k_out, v_out) = cache.append(&first_k, &first_v)?;
    assert_eq!(k_out.dims(), &[1, 8, 3, 64]);
    assert_eq!(v_out.dims(), &[1, 8, 3, 64]);
    assert_eq!(cache.current_seq_len(), 3);
    assert!(!cache.is_empty());

    // Second append: two more positions, five in total.
    let second_k = zeros(2)?;
    let second_v = zeros(2)?;
    let (k_out, v_out) = cache.append(&second_k, &second_v)?;
    assert_eq!(k_out.dims(), &[1, 8, 5, 64]);
    assert_eq!(v_out.dims(), &[1, 8, 5, 64]);
    assert_eq!(cache.current_seq_len(), 5);

    Ok(())
}
918+
/// `reset` must return a populated cache to its empty state.
#[test]
fn test_concat_cache_reset() -> Result<()> {
    let device = Device::Cpu;
    let shape = (1, 8, 10, 64);
    let mut cache = ConcatKvCache::new(2);

    // Populate the cache with ten positions.
    let keys = Tensor::zeros(shape, DType::F32, &device)?;
    let values = Tensor::zeros(shape, DType::F32, &device)?;
    cache.append(&keys, &values)?;
    assert_eq!(cache.current_seq_len(), 10);

    cache.reset();

    // Every accessor must now report an empty cache.
    assert!(cache.is_empty());
    assert_eq!(cache.current_seq_len(), 0);
    assert!(cache.k().is_none());
    assert!(cache.v().is_none());

    Ok(())
}
939+
/// Mimics generation: a ten-position prefill followed by five single-token
/// decode steps, checking the sequence axis after each step.
#[test]
fn test_concat_cache_multiple_appends() -> Result<()> {
    let device = Device::Cpu;
    // Zero-filled [1, 8, seq, 64] tensor on the test device.
    let zeros = |seq: usize| Tensor::zeros((1, 8, seq, 64), DType::F32, &device);
    let mut cache = ConcatKvCache::new(2);

    // Prefill phase.
    let prefill_k = zeros(10)?;
    let prefill_v = zeros(10)?;
    cache.append(&prefill_k, &prefill_v)?;
    assert_eq!(cache.current_seq_len(), 10);

    // Decode phase: each step adds exactly one position.
    for expected_len in 11..=15 {
        let step_k = zeros(1)?;
        let step_v = zeros(1)?;
        let (k_out, v_out) = cache.append(&step_k, &step_v)?;
        assert_eq!(k_out.dims()[2], expected_len);
        assert_eq!(v_out.dims()[2], expected_len);
    }

    assert_eq!(cache.current_seq_len(), 15);

    Ok(())
}
965+
/// The cache must honour a concatenation dimension other than 2; here the
/// tensors grow along dim 1.
#[test]
fn test_concat_cache_different_dim() -> Result<()> {
    let device = Device::Cpu;
    // Zero-filled [1, seq, 8, 64] tensor on the test device.
    let zeros = |seq: usize| Tensor::zeros((1, seq, 8, 64), DType::F32, &device);
    let mut cache = ConcatKvCache::new(1);

    // First append: three positions along dim 1.
    let first_k = zeros(3)?;
    let first_v = zeros(3)?;
    let (k_out, _) = cache.append(&first_k, &first_v)?;
    assert_eq!(k_out.dims(), &[1, 3, 8, 64]);

    // Second append: two more positions, so dim 1 becomes 3 + 2.
    let second_k = zeros(2)?;
    let second_v = zeros(2)?;
    let (k_out, _) = cache.append(&second_k, &second_v)?;
    assert_eq!(k_out.dims(), &[1, 5, 8, 64]);
    assert_eq!(cache.current_seq_len(), 5);

    Ok(())
}
720986}
0 commit comments