diff --git a/examples/llama_disco_example.py b/examples/llama_disco_example.py
new file mode 100644
index 000000000000..997bcaaec47a
--- /dev/null
+++ b/examples/llama_disco_example.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Example script demonstrating DISCO algorithm usage with the LLaMA model in PaddleNLP.
+
+DISCO (Dynamic Score-based Cache Optimization) provides adaptive KV cache management
+for efficient LLM inference.
+"""
+
+import paddle
+
+from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM
+
+
+def run_disco_example():
+    """Run a simple example with a DISCO-enabled LLaMA model."""
+
+    # Configuration with DISCO enabled
+    config = LlamaConfig(
+        # Model parameters
+        vocab_size=32000,
+        hidden_size=1024,  # Smaller size for the example
+        num_hidden_layers=12,
+        num_attention_heads=16,
+        intermediate_size=2816,
+        # DISCO parameters
+        use_disco=True,
+        disco_cache_size=512,  # Total cache budget across all layers
+        disco_window_size=32,  # Recent-token window used for scoring
+        disco_gamma=0.1,  # Weight of the attention-variance term
+        disco_score_func_path=None,  # Optional: path to a learned scoring polynomial
+        disco_layer_budget=None,  # Optional: explicit per-layer budget allocation
+    )
+
+    # Initialize the model with DISCO
+    model = LlamaForCausalLM(config)
+    model.eval()
+
+    # Mock tokenizer so the example runs without downloading vocab files
+    class MockTokenizer:
+        def __init__(self):
+            self.pad_token_id = 0
+            self.eos_token_id = 1
+
+        def __call__(self, text, return_tensors="pd"):
+            # Simple mock tokenization: 100 fixed token ids
+            tokens = list(range(2, 102))
+            return {"input_ids": paddle.to_tensor([tokens])}
+
+        def decode(self, token_ids):
+            return f"Generated {len(token_ids)} tokens"
+
+    tokenizer = MockTokenizer()
+
+    # Example input
+    input_text = "This is a long document that will test the DISCO cache management..."
+    inputs = tokenizer(input_text, return_tensors="pd")
+
+    print("Running LLaMA with DISCO cache management...")
+    print(f"- Cache size: {config.disco_cache_size}")
+    print(f"- Window size: {config.disco_window_size}")
+    print(f"- Gamma: {config.disco_gamma}")
+
+    # Generate with the DISCO-managed cache
+    with paddle.no_grad():
+        outputs = model.generate(
+            input_ids=inputs["input_ids"],
+            max_length=200,
+            num_beams=1,
+            do_sample=False,
+            use_cache=True,  # DISCO will manage the cache
+        )
+
+    # Decode the first generated sequence (generate returns (ids, scores))
+    generated_text = tokenizer.decode(outputs[0][0])
+    print(f"\nGenerated: {generated_text}")
+
+    # Check cache usage
+    if hasattr(model.llama, "disco_cache") and model.llama.disco_cache is not None:
+        disco_cache = model.llama.disco_cache
+        print("\nDISCO Cache Statistics:")
+        for layer_idx in range(config.num_hidden_layers):
+            seq_len = disco_cache.get_seq_length(layer_idx)
+            budget = disco_cache.layer_budget[layer_idx]
+            print(f"  Layer {layer_idx}: {seq_len}/{budget} tokens cached")
+
+
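+# The eviction score DISCO uses is sketched below on dummy data. This helper is
+# illustrative only and is not part of the DISCO implementation: it assumes the
+# combined score is the attention mean plus gamma times the attention variance
+# over the recent window, which is what LayerwiseEvictionManager.compute_scores
+# in disco_cache.py computes.
+def _sketch_disco_scores(gamma=0.1, window_size=4, seq_len=8):
+    # Random attention from `window_size` recent queries to `seq_len` cached tokens
+    attn = paddle.nn.functional.softmax(paddle.randn([window_size, seq_len]), axis=-1)
+    mean = attn.mean(axis=0)  # how much recent queries attend to each cached token
+    var = attn.var(axis=0)  # how consistently they attend to it
+    return mean + gamma * var  # higher score => token is kept
+
+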
+def compare_with_standard_cache():
+    """Compare DISCO with the standard KV cache."""
+
+    print("\n" + "=" * 60)
+    print("Comparing DISCO vs Standard Cache")
+    print("=" * 60)
+
+    # Standard configuration
+    standard_config = LlamaConfig(
+        vocab_size=32000,
+        hidden_size=1024,
+        num_hidden_layers=12,
+        num_attention_heads=16,
+        intermediate_size=2816,
+        use_disco=False,  # Standard cache
+    )
+
+    # DISCO configuration
+    disco_config = LlamaConfig(
+        vocab_size=32000,
+        hidden_size=1024,
+        num_hidden_layers=12,
+        num_attention_heads=16,
+        intermediate_size=2816,
+        use_disco=True,
+        disco_cache_size=256,  # Smaller cache for the comparison
+        disco_window_size=16,
+        disco_gamma=0.15,
+    )
+
+    # Initialize both models (here only to show that both configs construct cleanly)
+    standard_model = LlamaForCausalLM(standard_config)
+    disco_model = LlamaForCausalLM(disco_config)
+
+    print("\nStandard Cache: full KV storage")
+    print(f"DISCO Cache: adaptive storage with a {disco_config.disco_cache_size}-token budget")
+
+    # Actual performance comparisons can be added here
+    print("\nDISCO advantages:")
+    print("- Reduced memory usage (down to 3.2% of the full cache in the CAKE paper)")
+    print("- Adaptive layer-wise allocation")
+    print("- Maintains model quality through intelligent eviction")
+
+
+if __name__ == "__main__":
+    print("DISCO (Dynamic Score-based Cache Optimization) Example")
+    print("======================================================\n")
+
+    # Run the basic example
+    run_disco_example()
+
+    # Compare with the standard cache
+    compare_with_standard_cache()
+
+    print("\n✅ DISCO example completed successfully!")
\ No newline at end of file
diff --git a/paddlenlp/transformers/llama/configuration.py b/paddlenlp/transformers/llama/configuration.py
index d2095049c064..4e7010e2c624 100644
--- a/paddlenlp/transformers/llama/configuration.py
+++ b/paddlenlp/transformers/llama/configuration.py
@@ -159,6 +159,13 @@ def __init__(
         use_last_token_for_generation=False,
         immediate_clear_past_key_value=False,
         dpo_config=None,
+        # DISCO algorithm configuration
+        use_disco=False,
+        disco_window_size=32,
+        disco_gamma=0.1,
+        disco_cache_size=1024,
+        disco_score_func_path=None,
+        disco_layer_budget=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -198,6 +205,14 @@ def __init__(
         self.immediate_clear_past_key_value = immediate_clear_past_key_value
         self.dpo_config = dpo_config
 
+        # DISCO algorithm parameters
+        self.use_disco = use_disco
+        self.disco_window_size = disco_window_size
+        self.disco_gamma = disco_gamma
+        self.disco_cache_size = disco_cache_size
+        self.disco_score_func_path = disco_score_func_path
+        self.disco_layer_budget = disco_layer_budget
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
diff --git a/paddlenlp/transformers/llama/disco_cache.py b/paddlenlp/transformers/llama/disco_cache.py
new file mode 100644
index 000000000000..626f77ac842b
--- /dev/null
+++ b/paddlenlp/transformers/llama/disco_cache.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+DISCO (Dynamic Score-based Cache Optimization) implementation for PaddleNLP.
+Based on the CAKE algorithm, it provides adaptive KV cache eviction strategies.
+"""
+
+# Use the stdlib logger directly to avoid circular imports with paddlenlp
+import logging
+from typing import List, Optional, Tuple
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddle import Tensor
+
+logger = logging.getLogger(__name__)
+
+
+class DISCOCache:
+    """
+    DISCO cache for efficient KV cache management. The cache adaptively
+    allocates memory across layers based on attention patterns.
+
+    Key/value tensors are stored as [batch, num_heads, seq_len, head_dim];
+    the sequence dimension is always axis 2.
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        cache_size: int,
+        window_size: int = 32,
+        gamma: float = 0.1,
+        score_func_path: Optional[str] = None,
+        layer_budget: Optional[List[int]] = None,
+    ):
+        self.num_layers = num_layers
+        self.cache_size = cache_size
+        self.window_size = window_size
+        self.gamma = gamma
+
+        # Layer-wise cache allocation
+        if layer_budget is None:
+            # Default: equal allocation across layers
+            self.layer_budget = [cache_size // num_layers] * num_layers
+        else:
+            self.layer_budget = layer_budget
+
+        # Optionally load a learned scoring polynomial (numpy coefficient file)
+        self.score_func = None
+        if score_func_path:
+            try:
+                coeffs = np.load(score_func_path)
+                from numpy.polynomial import Polynomial
+
+                self.score_func = Polynomial(coeffs)
+            except Exception as e:
+                logger.warning(f"Failed to load score function from {score_func_path}: {e}")
+
+        # Cache storage, keyed by layer index
+        self.key_cache = {}
+        self.value_cache = {}
+        self.attention_scores = {}
+        self.dispersion_scores = {}
+
+        # Layer-wise eviction managers
+        self.layer_eviction_managers = {}
+        for i in range(num_layers):
+            self.layer_eviction_managers[i] = LayerwiseEvictionManager(
+                cache_size=self.layer_budget[i], window_size=window_size, gamma=gamma
+            )
+
+    def update(
+        self,
+        key_states: Tensor,
+        value_states: Tensor,
+        layer_idx: int,
+        attention_weights: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """Update the cache with new key-value pairs and evict if over budget."""
+
+        if layer_idx not in self.key_cache:
+            # First update for this layer: store the prefill states directly
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+        else:
+            # Concatenate the new states along the sequence axis
+            self.key_cache[layer_idx] = paddle.concat([self.key_cache[layer_idx], key_states], axis=2)
+            self.value_cache[layer_idx] = paddle.concat([self.value_cache[layer_idx], value_states], axis=2)
+
+        # Evict whenever the cache exceeds its budget, including on the first
+        # (prefill) update: a long prompt can already overflow the layer budget.
+        if self.key_cache[layer_idx].shape[2] > self.layer_budget[layer_idx]:
+            self._evict_tokens(layer_idx, attention_weights)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
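+
+    # A small worked example of the eviction step below (illustrative numbers):
+    # with layer budget k=3 and per-token scores [0.9, 0.1, 0.8, 0.2, 0.7],
+    # topk keeps indices [0, 2, 4]; sorting them preserves the original token
+    # order, so the surviving cache holds tokens 0, 2 and 4 in sequence order.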
+
+    def _evict_tokens(self, layer_idx: int, attention_weights: Optional[Tensor] = None):
+        """Evict tokens based on the DISCO scoring strategy."""
+        eviction_manager = self.layer_eviction_managers[layer_idx]
+
+        # Get the current cache
+        key_cache = self.key_cache[layer_idx]
+        value_cache = self.value_cache[layer_idx]
+
+        # Compute eviction scores
+        if attention_weights is not None:
+            scores = eviction_manager.compute_scores(key_cache, value_cache, attention_weights)
+        else:
+            # Fall back to distance-based scoring
+            scores = eviction_manager.compute_distance_scores(key_cache, value_cache)
+
+        # Keep the top-k tokens
+        k = self.layer_budget[layer_idx]
+        seq_len = key_cache.shape[2]
+
+        # Average scores across batch (and heads) so one token ranking is shared
+        if scores.ndim == 3:
+            # scores shape: [batch, heads, seq]
+            scores = scores.mean(axis=[0, 1])
+        elif scores.ndim == 2:
+            # scores shape: [batch, seq]
+            scores = scores.mean(axis=0)
+
+        # scores is now 1D with shape [seq_len]; select the top-k indices
+        k = min(k, seq_len)
+        _, keep_indices = paddle.topk(scores, k=k)
+
+        # Sort the indices to preserve token order
+        keep_indices = paddle.sort(keep_indices)
+
+        # Update the cache
+        self.key_cache[layer_idx] = paddle.index_select(key_cache, keep_indices, axis=2)
+        self.value_cache[layer_idx] = paddle.index_select(value_cache, keep_indices, axis=2)
+
+    def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
+        """Get the cached sequence length for one layer, or the maximum across layers."""
+        if layer_idx is not None:
+            if layer_idx not in self.key_cache:
+                return 0
+            return self.key_cache[layer_idx].shape[2]
+        if not self.key_cache:
+            return 0
+        return max(cache.shape[2] for cache in self.key_cache.values())
+
+    def get_cache(self, layer_idx: int) -> Tuple[Optional[Tensor], Optional[Tensor]]:
+        """Get the key-value cache for a specific layer, or (None, None) if absent."""
+        if layer_idx not in self.key_cache:
+            return None, None
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
+class LayerwiseEvictionManager:
+    """Manages token eviction for a single layer."""
+
+    def __init__(self, cache_size: int, window_size: int, gamma: float):
+        self.cache_size = cache_size
+        self.window_size = window_size
+        self.gamma = gamma
+
+    def compute_scores(self, key_states: Tensor, value_states: Tensor, attention_weights: Tensor) -> Tensor:
+        """Compute eviction scores based on attention patterns."""
+
+        # Attention scores from the recent query window only
+        window_attn = attention_weights[..., -self.window_size :, :]
+
+        # Mean and variance of attention over the window queries
+        attn_mean = window_attn.mean(axis=-2)
+        attn_var = window_attn.var(axis=-2)
+
+        # Combined score: mean + gamma * variance
+        scores = attn_mean + self.gamma * attn_var
+
+        # Smooth the scores with a simple moving average
+        if scores.shape[-1] > 5:
+            kernel_size = 5
+            # Reshape for avg_pool1d: [batch * heads, 1, seq_len]
+            original_shape = scores.shape
+            scores_reshaped = scores.reshape([-1, 1, scores.shape[-1]])
+            scores = F.avg_pool1d(scores_reshaped, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
+            # Restore the original shape
+            scores = scores.squeeze(1).reshape(original_shape)
+
+        return scores
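+
+    # Illustrative numbers for the score above: if over the window a cached
+    # token receives attention [0.30, 0.10, 0.50] from three recent queries,
+    # its mean is 0.30 and its unbiased variance (paddle's default) is 0.04,
+    # so with gamma = 0.1 the combined score is 0.304. The variance term
+    # rewards tokens that are intermittently but strongly attended to.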
+
+    def compute_distance_scores(self, key_states: Tensor, value_states: Tensor) -> Tensor:
+        """Compute scores from key/value norms (fallback when no attention weights exist)."""
+
+        # L2 norm of keys and values: [batch, heads, seq]
+        key_norm = paddle.norm(key_states, p=2, axis=-1)
+        value_norm = paddle.norm(value_states, p=2, axis=-1)
+
+        # Combined score
+        scores = key_norm + value_norm
+
+        # Average across heads: [batch, seq]
+        scores = scores.mean(axis=1)
+
+        return scores
+
+
+def get_scores_with_kv_fusion(key_states: Tensor, value_states: Tensor, window_size: int = 32) -> Tensor:
+    """
+    Compute fusion scores based on key-value similarity patterns.
+    Used for dispersion-based scoring in DISCO.
+    """
+
+    # Extract the recent window
+    if key_states.shape[2] > window_size:
+        recent_keys = key_states[:, :, -window_size:]
+        recent_values = value_states[:, :, -window_size:]
+    else:
+        recent_keys = key_states
+        recent_values = value_states
+
+    # Normalize keys and values
+    key_norm = F.normalize(recent_keys, p=2, axis=-1)
+    value_norm = F.normalize(recent_values, p=2, axis=-1)
+
+    # Cosine similarity between each key and its value
+    similarity = paddle.sum(key_norm * value_norm, axis=-1)
+
+    # Aggregate across the window: [batch, heads]
+    scores = similarity.mean(axis=-1)
+
+    return scores
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 84c9bcc7ff4e..a19b2fae4a1a 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -97,6 +97,7 @@ def swiglu(x, y=None):
 except:
     flash_attention = None
 from . import fusion_ops
+from .disco_cache import DISCOCache, get_scores_with_kv_fusion
 
 rms_norm_fused = fusion_ops.rms_norm_fused
 
@@ -714,6 +715,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, skip_
 
         self.fuse_attention_qkv = config.fuse_attention_qkv
 
+        # DISCO algorithm attributes
+        self.use_disco = getattr(config, "use_disco", False)
+        self.disco_cache = None
+        self.layer_idx = None  # Set by LlamaModel once the layers are built
+        self.kv_indices = None
 
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
         # Enable_recompute defaults to False and is controlled by Trainer
@@ -1093,15 +1099,30 @@ def forward(
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         # [bs, seq_len, num_head, head_dim]
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = paddle.concat([past_key_value[0], key_states], axis=1)
-            value_states = paddle.concat([past_key_value[1], value_states], axis=1)
-            if self.config.immediate_clear_past_key_value:
-                past_key_value[0]._clear_data()
-                past_key_value[1]._clear_data()
-
-        past_key_value = (key_states, value_states) if use_cache else None
+        if self.use_disco and use_cache and self.disco_cache is not None and self.layer_idx is not None:
+            # DISCO cache management. DISCOCache stores tensors as
+            # [bs, num_heads, seq_len, head_dim], while the states here are
+            # [bs, seq_len, num_heads, head_dim], so transpose around the update.
+            attention_weights = getattr(self, "_last_attention_weights", None)
+            key_states, value_states = self.disco_cache.update(
+                key_states.transpose([0, 2, 1, 3]),
+                value_states.transpose([0, 2, 1, 3]),
+                self.layer_idx,
+                attention_weights,
+            )
+            key_states = key_states.transpose([0, 2, 1, 3])
+            value_states = value_states.transpose([0, 2, 1, 3])
+            past_key_value = (key_states, value_states)
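+            # NOTE: nothing in this diff populates `_last_attention_weights`;
+            # until the attention computation stashes its softmax output (for
+            # example `self._last_attention_weights = attn_weights` right after
+            # the softmax -- an assumed hook, not an existing one), DISCO falls
+            # back to distance-based scoring inside DISCOCache.update().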
+        else:
+            # Standard KV cache handling (also the fallback when DISCO is
+            # enabled but no shared cache has been attached to this layer)
+            if past_key_value is not None:
+                # reuse k, v, self_attention
+                key_states = paddle.concat([past_key_value[0], key_states], axis=1)
+                value_states = paddle.concat([past_key_value[1], value_states], axis=1)
+                if self.config.immediate_clear_past_key_value:
+                    past_key_value[0]._clear_data()
+                    past_key_value[1]._clear_data()
+
+            past_key_value = (key_states, value_states) if use_cache else None
 
         if self.kv_indices is not None:
             key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
             value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
@@ -1572,6 +1593,23 @@ def __init__(self, config: LlamaConfig):
         self.norm = LlamaRMSNorm(config)
 
         self.gradient_checkpointing = False
+
+        # Initialize the DISCO cache if enabled
+        self.disco_cache = None
+        if getattr(config, "use_disco", False):
+            self.disco_cache = DISCOCache(
+                num_layers=config.num_hidden_layers,
+                cache_size=getattr(config, "disco_cache_size", 1024),
+                window_size=getattr(config, "disco_window_size", 32),
+                gamma=getattr(config, "disco_gamma", 0.1),
+                score_func_path=getattr(config, "disco_score_func_path", None),
+                layer_budget=getattr(config, "disco_layer_budget", None),
+            )
+
+        # Record layer indices and share the DISCO cache with every attention layer
+        for idx, layer in enumerate(self.layers):
+            layer.self_attn.layer_idx = idx
+            layer.self_attn.disco_cache = self.disco_cache
 
     def get_input_embeddings(self):
         return self.embed_tokens
diff --git a/test_disco_fix.py b/test_disco_fix.py
new file mode 100644
index 000000000000..d0808ac470b0
--- /dev/null
+++ b/test_disco_fix.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""Test the DISCO cache eviction fix."""
+
+import os
+import sys
+
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import numpy as np
+import paddle
+
+# Execute the module source directly to avoid importing paddlenlp (and its
+# circular dependencies) just for this standalone check.
+exec(open("paddlenlp/transformers/llama/disco_cache.py").read())
+
+
+def test_eviction():
+    """Test the eviction mechanism fix."""
+    print("Testing DISCO eviction fix...")
+
+    # Create a cache with a small budget to trigger eviction
+    cache = DISCOCache(
+        num_layers=4,
+        cache_size=64,  # Total budget
+        window_size=8,
+        gamma=0.1,
+    )
+
+    # Pin the per-layer budget explicitly to force eviction
+    cache.layer_budget = [16, 16, 16, 16]  # 16 tokens per layer
+
+    batch_size = 2
+    num_heads = 8
+    head_dim = 64
+    layer_idx = 0
+
+    # First update with 32 tokens (exceeds the budget of 16)
+    seq_length = 32
+    key_states = paddle.randn([batch_size, num_heads, seq_length, head_dim])
+    value_states = paddle.randn([batch_size, num_heads, seq_length, head_dim])
+
+    print(f"  Initial update with {seq_length} tokens (budget: {cache.layer_budget[layer_idx]})")
+    updated_keys, updated_values = cache.update(key_states, value_states, layer_idx)
+
+    # Check that eviction happened
+    actual_seq_len = updated_keys.shape[2]
+    expected_max = cache.layer_budget[layer_idx]
+
+    print(f"  After eviction: {actual_seq_len} tokens (expected <= {expected_max})")
+
+    if actual_seq_len <= expected_max:
+        print("  ✅ Eviction working correctly!")
+    else:
+        print(f"  ❌ Eviction failed: {actual_seq_len} > {expected_max}")
+        return False
+
+    # Add more tokens to test continuous eviction
+    new_tokens = 10
+    new_keys = paddle.randn([batch_size, num_heads, new_tokens, head_dim])
+    new_values = paddle.randn([batch_size, num_heads, new_tokens, head_dim])
+
+    print(f"  Adding {new_tokens} more tokens...")
+    updated_keys, updated_values = cache.update(new_keys, new_values, layer_idx)
+
+    actual_seq_len = updated_keys.shape[2]
+    print(f"  After second update: {actual_seq_len} tokens (expected <= {expected_max})")
+
+    if actual_seq_len <= expected_max:
+        print("  ✅ Continuous eviction working correctly!")
+        return True
+    else:
+        print(f"  ❌ Continuous eviction failed: {actual_seq_len} > {expected_max}")
+        return False
+
+
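+# Why first-update eviction matters: a prompt longer than the layer budget
+# arrives in a single prefill call, so DISCOCache.update() must already trim
+# it down; waiting for the second update would leave the cache over budget.
+
+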
print("DISCO Eviction Fix Test") + print("=" * 60) + + try: + success = test_eviction() + print("\n" + "=" * 60) + if success: + print("✅ All tests passed!") + else: + print("❌ Tests failed!") + except Exception as e: + print(f"\n❌ Error during test: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/tests/transformers/llama/test_disco.py b/tests/transformers/llama/test_disco.py new file mode 100644 index 000000000000..6d127775d0c8 --- /dev/null +++ b/tests/transformers/llama/test_disco.py @@ -0,0 +1,260 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for DISCO (Dynamic Score-based Cache Optimization) algorithm.""" + +import tempfile +import unittest + +import numpy as np +import paddle + +from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM, LlamaModel +from paddlenlp.transformers.llama.disco_cache import DISCOCache, LayerwiseEvictionManager, get_scores_with_kv_fusion +import paddle.nn.functional as F + +class DISCOCacheTest(unittest.TestCase): + """Test cases for DISCO cache implementation.""" + + def setUp(self): + """Set up test fixtures.""" + self.num_layers = 4 + self.cache_size = 128 + self.window_size = 16 + self.gamma = 0.1 + self.batch_size = 2 + self.num_heads = 8 + self.head_dim = 64 + self.seq_length = 32 + + def test_disco_cache_initialization(self): + """Test DISCO cache initialization with different configurations.""" + # Test default initialization + cache = DISCOCache( + num_layers=self.num_layers, + cache_size=self.cache_size, + window_size=self.window_size, + gamma=self.gamma, + ) + + self.assertEqual(len(cache.layer_budget), self.num_layers) + self.assertEqual(sum(cache.layer_budget), self.cache_size) + self.assertEqual(cache.window_size, self.window_size) + self.assertEqual(cache.gamma, self.gamma) + + # Test with custom layer budget + custom_budget = [40, 30, 30, 28] + cache_custom = DISCOCache( + num_layers=self.num_layers, + cache_size=self.cache_size, + window_size=self.window_size, + gamma=self.gamma, + layer_budget=custom_budget, + ) + self.assertEqual(cache_custom.layer_budget, custom_budget) + + def test_cache_update_and_eviction(self): + """Test cache update and eviction mechanism.""" + cache = DISCOCache( + num_layers=self.num_layers, + cache_size=64, # Small cache for testing eviction + window_size=8, + gamma=self.gamma, + ) + + # Create dummy key and value states + key_states = paddle.randn([self.batch_size, self.num_heads, self.seq_length, self.head_dim]) + value_states = paddle.randn([self.batch_size, self.num_heads, self.seq_length, self.head_dim]) + + # Test update for layer 0 + updated_keys, updated_values = cache.update(key_states, value_states, layer_idx=0) + + # Check shapes + self.assertEqual(updated_keys.shape, key_states.shape) + self.assertEqual(updated_values.shape, value_states.shape) + + # Test multiple updates to trigger eviction + for i in range(3): + new_keys = paddle.randn([self.batch_size, self.num_heads, 
+
+    def test_disco_cache_initialization(self):
+        """Test DISCO cache initialization with different configurations."""
+        # Default initialization
+        cache = DISCOCache(
+            num_layers=self.num_layers,
+            cache_size=self.cache_size,
+            window_size=self.window_size,
+            gamma=self.gamma,
+        )
+
+        self.assertEqual(len(cache.layer_budget), self.num_layers)
+        self.assertEqual(sum(cache.layer_budget), self.cache_size)
+        self.assertEqual(cache.window_size, self.window_size)
+        self.assertEqual(cache.gamma, self.gamma)
+
+        # Custom layer budget
+        custom_budget = [40, 30, 30, 28]
+        cache_custom = DISCOCache(
+            num_layers=self.num_layers,
+            cache_size=self.cache_size,
+            window_size=self.window_size,
+            gamma=self.gamma,
+            layer_budget=custom_budget,
+        )
+        self.assertEqual(cache_custom.layer_budget, custom_budget)
+
+    def test_cache_update_and_eviction(self):
+        """Test the cache update and eviction mechanism."""
+        cache = DISCOCache(
+            num_layers=self.num_layers,
+            cache_size=64,  # Small cache so eviction triggers
+            window_size=8,
+            gamma=self.gamma,
+        )
+
+        # Dummy key and value states
+        key_states = paddle.randn([self.batch_size, self.num_heads, self.seq_length, self.head_dim])
+        value_states = paddle.randn([self.batch_size, self.num_heads, self.seq_length, self.head_dim])
+
+        # Update for layer 0; the 32-token prefill exceeds the 16-token layer
+        # budget, so the first update already evicts
+        updated_keys, updated_values = cache.update(key_states, value_states, layer_idx=0)
+
+        # Batch/head/feature dims are preserved; the sequence dim is trimmed
+        self.assertEqual(updated_keys.shape[0], self.batch_size)
+        self.assertEqual(updated_keys.shape[1], self.num_heads)
+        self.assertLessEqual(updated_keys.shape[2], cache.layer_budget[0])
+        self.assertEqual(updated_keys.shape, updated_values.shape)
+
+        # Multiple updates keep triggering eviction
+        for _ in range(3):
+            new_keys = paddle.randn([self.batch_size, self.num_heads, 16, self.head_dim])
+            new_values = paddle.randn([self.batch_size, self.num_heads, 16, self.head_dim])
+            updated_keys, updated_values = cache.update(new_keys, new_values, layer_idx=0)
+
+        # The cache size stays within the layer budget
+        actual_size = updated_keys.shape[2]
+        budget = cache.layer_budget[0]
+        self.assertLessEqual(actual_size, budget)
+
+    def test_layerwise_eviction_manager(self):
+        """Test the LayerwiseEvictionManager scoring functions."""
+        manager = LayerwiseEvictionManager(cache_size=32, window_size=8, gamma=0.1)
+
+        # Dummy inputs
+        key_states = paddle.randn([self.batch_size, self.num_heads, self.seq_length, self.head_dim])
+        value_states = paddle.randn([self.batch_size, self.num_heads, self.seq_length, self.head_dim])
+        attention_weights = F.softmax(
+            paddle.randn([self.batch_size, self.num_heads, self.window_size, self.seq_length]),
+            axis=-1,
+        )
+
+        # Attention-based scoring
+        scores = manager.compute_scores(key_states, value_states, attention_weights)
+        self.assertEqual(scores.shape[0], self.batch_size)
+        self.assertEqual(scores.shape[1], self.num_heads)
+
+        # Distance-based scoring
+        distance_scores = manager.compute_distance_scores(key_states, value_states)
+        self.assertEqual(distance_scores.shape[0], self.batch_size)
+        self.assertEqual(distance_scores.shape[-1], self.seq_length)
+
+    def test_kv_fusion_scoring(self):
+        """Test the key-value fusion scoring function."""
+        key_states = paddle.randn([self.batch_size, self.num_heads, self.seq_length, self.head_dim])
+        value_states = paddle.randn([self.batch_size, self.num_heads, self.seq_length, self.head_dim])
+
+        scores = get_scores_with_kv_fusion(key_states, value_states, window_size=16)
+
+        self.assertEqual(scores.shape[0], self.batch_size)
+        self.assertEqual(scores.shape[1], self.num_heads)
+
+    def test_llama_with_disco_config(self):
+        """Test LLaMA model initialization with a DISCO configuration."""
+        config = LlamaConfig(
+            vocab_size=1000,  # Small vocab for testing
+            hidden_size=256,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            intermediate_size=512,
+            # DISCO parameters
+            use_disco=True,
+            disco_cache_size=64,
+            disco_window_size=8,
+            disco_gamma=0.15,
+        )
+
+        # Model creation should not fail
+        model = LlamaModel(config)
+
+        # The DISCO cache is initialized
+        self.assertIsNotNone(model.disco_cache)
+        self.assertEqual(model.disco_cache.num_layers, config.num_hidden_layers)
+        self.assertEqual(model.disco_cache.cache_size, config.disco_cache_size)
+
+        # Attention layers carry the DISCO attributes
+        for idx, layer in enumerate(model.layers):
+            self.assertTrue(layer.self_attn.use_disco)
+            self.assertEqual(layer.self_attn.layer_idx, idx)
+            self.assertEqual(layer.self_attn.disco_cache, model.disco_cache)
+
+    def test_llama_forward_with_disco(self):
+        """Test a LLaMA forward pass with DISCO enabled."""
+        config = LlamaConfig(
+            vocab_size=1000,
+            hidden_size=128,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            intermediate_size=256,
+            use_disco=True,
+            disco_cache_size=32,
+            disco_window_size=8,
+        )
+
+        model = LlamaModel(config)
+        model.eval()
+
+        # Create input
+        input_ids = paddle.randint(0, config.vocab_size, shape=[self.batch_size, 16])
+
+        # Forward pass with DISCO
+        with paddle.no_grad():
+            outputs = model(input_ids, use_cache=True)
+
+        # Check outputs
+        self.assertEqual(outputs[0].shape[0], self.batch_size)
+        self.assertEqual(outputs[0].shape[1], 16)
+        self.assertEqual(outputs[0].shape[2], config.hidden_size)
+
+    def test_disco_disabled_compatibility(self):
+        """Test that DISCO can be disabled without affecting normal operation."""
+        # Config with DISCO disabled
+        config = LlamaConfig(
+            vocab_size=1000,
+            hidden_size=128,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            intermediate_size=256,
+            use_disco=False,  # DISCO disabled
+        )
+
+        model = LlamaModel(config)
+        model.eval()
+
+        # The DISCO cache is not initialized
+        self.assertIsNone(model.disco_cache)
+
+        # The model still works normally
+        input_ids = paddle.randint(0, config.vocab_size, shape=[self.batch_size, 16])
+        with paddle.no_grad():
+            outputs = model(input_ids, use_cache=True)
+
+        self.assertEqual(outputs[0].shape[0], self.batch_size)
+        self.assertEqual(outputs[0].shape[1], 16)
+        self.assertEqual(outputs[0].shape[2], config.hidden_size)
+
+    def test_score_function_loading(self):
+        """Test loading a polynomial score function from file."""
+        # Create a temporary coefficient file for the score function
+        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as f:
+            coeffs = np.array([0.1, 0.2, 0.3, 0.4])
+            np.save(f.name, coeffs)
+
+        cache = DISCOCache(
+            num_layers=self.num_layers,
+            cache_size=self.cache_size,
+            score_func_path=f.name,
+        )
+
+        self.assertIsNotNone(cache.score_func)
+        # The score function evaluates to a scalar
+        score = cache.score_func(1)
+        self.assertIsInstance(score, (float, np.floating))
+
+    def test_causal_lm_with_disco(self):
+        """Test LlamaForCausalLM with DISCO enabled."""
+        config = LlamaConfig(
+            vocab_size=1000,
+            hidden_size=128,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            intermediate_size=256,
+            use_disco=True,
+            disco_cache_size=32,
+        )
+
+        model = LlamaForCausalLM(config)
+        model.eval()
+
+        # A forward pass with caching should not fail
+        input_ids = paddle.randint(0, config.vocab_size, shape=[1, 8])
+        with paddle.no_grad():
+            outputs = model(input_ids, use_cache=True)
+
+        self.assertEqual(outputs[0].shape[-1], config.vocab_size)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file