 
 import pytest
 
-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.attention.layer import Attention
+from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig,
+                         set_current_vllm_config)
 from vllm.sampling_params import SamplingParams
+from vllm.utils import GiB_bytes
+from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
+                                         get_kv_cache_config)
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
 from vllm.v1.worker.tpu_model_runner import (
@@ -363,3 +368,223 @@ def test_get_req_paddings():
     assert _get_req_paddings(1, 32) == [8, 16, 32]
     assert _get_req_paddings(8, 32) == [8, 16, 32]
     assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
+
+
+def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    error_msg = f"{layer_1} must come before the current layer"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            # initialization below will fail because the target layer
+            # (layer 1) must come before the current layer (layer 0)
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+                kv_sharing_target_layer_name=layer_1,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    invalid_layer = "model.layers.0.cross_attn.attn"
+    error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                # invalid layer: cross_attn.attn doesn't exist in the model!
+                kv_sharing_target_layer_name=invalid_layer,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    error_msg = f"{layer_1} cannot be the same as the current layer"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            # initialization below will fail because the target layer
+            # cannot be the same as the current layer
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                kv_sharing_target_layer_name=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_without_kv_sharing(model_runner):
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    vllm_config = model_runner.vllm_config
+    with set_current_vllm_config(vllm_config):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+    # Set high context length to test max context length estimation
+    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_ctx = vllm_config.compilation_config.static_forward_context
+    kv_cache_spec = model_runner.get_kv_cache_spec()
+    assert len(kv_cache_spec) == 2
+    assert len(model_runner.shared_kv_cache_layers) == 0
+
+    available_memory = 20 * GiB_bytes
+    # page size for layer 0's kv_cache_spec is 32KB
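+    # (the 32KB figure assumes the fixture's defaults of 16-token blocks and
+    #  a 2-byte KV dtype: 2 (K and V) * 16 * 8 heads * 64 head_size * 2 bytes
+    #  = 32768 bytes per block)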
+    num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == num_expected_blocks
+    assert len(kv_cache_config.tensors) == 2
+    assert kv_cache_config.tensors[layer_0].size == available_memory // 2
+    assert kv_cache_config.tensors[layer_1].size == available_memory // 2
+
+    max_context_len = \
+        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    # baseline without KV sharing; the KV sharing test below expects 2x this
+    assert max_context_len == 1310720
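+    # (sanity check of the expected value, assuming a 2-byte KV dtype: each
+    #  token needs 2 layers * 2 (K and V) * 8 heads * 64 head_size * 2 bytes
+    #  = 4096 bytes, and 5 GiB / 4096 bytes = 1310720 tokens)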
+
+    # important: override tensor size to prevent large mem alloc during test
+    # this will only allocate 2 blocks' worth of memory (2 * 32KB)
+    kv_cache_config.num_blocks = 1
+    for layer in kv_cache_config.tensors:
+        kv_cache_config.tensors[layer].size = \
+            kv_cache_spec[layer].page_size_bytes
+
+    model_runner.initialize_kv_cache(kv_cache_config)
+
+    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
+    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    # check layer 1 kv cache does NOT share memory with layer 0
+    assert id(layer_1_kv) != id(layer_0_kv)
+
+    # check that both layers are in the kv cache group's layer names
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_init_kv_cache_with_kv_sharing_valid(model_runner):
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    vllm_config = model_runner.vllm_config
+    with set_current_vllm_config(vllm_config):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+    # Set high context length to test max context length estimation
+    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_ctx = vllm_config.compilation_config.static_forward_context
+    kv_cache_spec = model_runner.get_kv_cache_spec()
+    assert len(kv_cache_spec) == 1
+    assert layer_0 in kv_cache_spec
+    assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
+
+    available_memory = 20 * GiB_bytes
+    # page size for layer 0's kv_cache_spec is 32KB
+    # with KV sharing, we can allocate (available_mem//page_size//1) blocks
+    # which is twice as many as without KV sharing
+    num_expected_blocks = 655360  # 20GB / 32KB
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == num_expected_blocks
+    assert len(kv_cache_config.tensors) == 1
+    # the single KV cache tensor now gets all of the available memory,
+    # twice as much per layer as without KV sharing
+    assert kv_cache_config.tensors[layer_0].size == available_memory
+
+    max_context_len = \
+        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    # max context len with KV sharing should be 2x as large as without
+    assert max_context_len == 2 * 1310720
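+    # (with sharing, only layer_0's KV is stored, so each token needs half the
+    #  bytes of the previous test and the same 5 GiB covers
+    #  2 * 1310720 = 2621440 tokens)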
+
+    # important: override tensor size to prevent large mem alloc during test
+    # this will only allocate 1 block's worth of memory (32KB)
+    kv_cache_config.num_blocks = 1
+    kv_cache_config.tensors[layer_0].size = \
+        kv_cache_spec[layer_0].page_size_bytes
+
+    model_runner.initialize_kv_cache(kv_cache_config)
+
+    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
+    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    # check layer 1 kv cache shares memory with layer 0
+    assert id(layer_1_kv) == id(layer_0_kv)
+
+    # check layer 1 added to kv cache group's layer names
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1