
import torch
import torch.nn as nn
-from executorch.examples.models.llama.attention import KVCache
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    CachePositionsManager,
+    KVCache,
+    RingKVCache,
+)

from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

@@ -75,6 +80,7 @@ def __init__(
            self.register_buffer(
                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
            )
+        self.cache_type = cache_type

    def _quantize(self, value):
        (
@@ -181,6 +187,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
        However the storage is [B, S, H, D] so we incur transpose in, transpose out
        This shall be removed by subsequent post-export graph pass
        """
+
        k_val = k_val.transpose(1, 2)
        v_val = v_val.transpose(1, 2)

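
# Illustrative sketch (not part of this diff): the [B, H, S, D] <-> [B, S, H, D]
# round trip described in the docstring above. The shapes are made up for the example.
import torch

k_val = torch.randn(1, 8, 4, 64)   # [B, H, S, D] as produced by attention
k_stored = k_val.transpose(1, 2)   # [B, S, H, D], the cache's storage layout
assert k_stored.shape == (1, 4, 8, 64)
assert k_stored.transpose(1, 2).shape == k_val.shape  # transpose out restores it
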
@@ -346,3 +353,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
        else:
            _replace_kv_cache_with_custom_kv_cache(child)
    return module
+
+
+class QuantizedRingKVCache(QuantizedKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        use_custom_update_cache_op: bool = False,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size,
+            max_context_length * 2,
+            n_heads,
+            head_dim,
+            cache_type,
+            use_custom_update_cache_op,
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must not exceed the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_quantized_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        assert isinstance(
+            kv_cache, QuantizedKVCache
+        ), "QuantizedRingKVCache expects a QuantizedKVCache as the input kv_cache"
+        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            kv_cache.cache_type,
+            kv_cache.use_custom_update_cache_op,
+        )
+
+
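
# Illustrative sketch (not part of this diff): converting an already-built
# QuantizedKVCache into the ring-buffer variant. The constructor arguments are
# inferred from the super().__init__() call above; the sizes here are arbitrary.
plain_cache = QuantizedKVCache(
    1,      # max_batch_size
    2048,   # max_context_length
    8,      # n_heads
    64,     # head_dim
    QuantizedCacheType.AffineSymmetric,
    False,  # use_custom_update_cache_op
)
# A 512-token sliding window; internally the ring cache allocates 2 * 512 slots.
ring_cache = QuantizedRingKVCache.from_quantized_kv_cache(plain_cache, 512)
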
+class CustomRingKVCache(CustomKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        dtype=torch.float32,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must not exceed the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_custom_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            dtype=kv_cache.k_cache.dtype,
+        )
+
+
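
# Illustrative sketch (not part of this diff): a ring cache with a 256-token
# sliding window. Per the "max_context_length * 2" comment above, the backing
# k_cache/v_cache get a sequence dimension of 512 in the [B, S, H, D] storage
# layout described in the update() comments.
ring = CustomRingKVCache(
    max_batch_size=1,
    max_context_length=256,  # the sliding window size
    n_heads=8,
    head_dim=64,
)
assert ring.window_size == 256
assert ring.k_cache.shape[1] == 512  # S == 2 * window in [B, S, H, D] storage
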
+def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
+    sliding_window_size = layer_size
+    assert (
+        getattr(attention, "kv_cache", None) is not None
+    ), "Attention module must have kv_cache module"
+    kv_cache = attention.kv_cache
+    if isinstance(kv_cache, KVCache):
+        attention.kv_cache = RingKVCache(
+            kv_cache.max_batch_size,
+            sliding_window_size,
+            kv_cache.n_heads,
+            kv_cache.head_dim,
+            kv_cache.enable_dynamic_shape,
+            kv_cache.k_cache.dtype,
+        )
+    elif isinstance(kv_cache, CustomKVCache):
+        attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
+            kv_cache, layer_size
+        )
+    elif isinstance(kv_cache, QuantizedKVCache):
+        attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
+            kv_cache, layer_size
+        )
+
+
+def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
+    # This is needed to ensure that custom ops are registered
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+    logging.info(
+        "Replacing kv cache with ring kv cache. This modifies the model in place."
+    )
+    assert len(layer_sizes) == len(
+        module.layers
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    for i, transformer_block in enumerate(module.layers):
+        sliding_window_size = layer_sizes[i]
+        if sliding_window_size == 0:
+            continue
+        assert (
+            getattr(transformer_block, "attention", None) is not None
+        ), f"Transformer block must have attention module. Transformer block {transformer_block}"
+        attention = transformer_block.attention
+        _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+    return module
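
# Illustrative sketch (not part of this diff): enabling a sliding window on
# alternating layers of a Llama-style model. A size of 0 leaves that layer's
# cache untouched (see the `continue` above); `model` and the window size of
# 512 are assumptions made only for this example.
layer_sizes = [512 if i % 2 == 0 else 0 for i in range(len(model.layers))]
model = replace_kv_cache_with_ring_kv_cache(model, layer_sizes)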