@@ -3,10 +3,9 @@
 
 import itertools
 from abc import abstractmethod
-from typing import Any, Literal, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
-import torch.nn as nn
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm.distributed import (
@@ -1440,237 +1439,3 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
-
-
-@CustomOp.register("qkv_cross_parallel_linear")
-class QKVCrossParallelLinear(LinearBase):
-    """Linear layers for efficient cross-attention's QKV transformation.
-
-    Args:
-        hidden_size: input hidden state size of the transformer.
-        head_size: size of each attention head.
-        total_num_heads: total number of attention query heads.
-        total_num_kv_heads: total number of attention key/value heads. If
-                            None, assume total_num_kv_heads = total_num_heads.
-        bias: If true, add bias.
-        skip_bias_add: This was added to enable performance optimizations where
-                       bias can be fused with other element-wise operations. we
-                       skip adding bias but instead return it.
-        params_dtype: Data type for the parameters.
-        quant_config: Quantization configure.
-        prefix: The name of the layer in the state dict, including all parents
-                (e.g. model.layers.0.qkv_proj)
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        head_size: int,
-        total_num_heads: int,
-        total_num_kv_heads: Optional[int] = None,
-        bias: bool = True,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        # input_size and output_size are not used, just for alignment
-        input_size = hidden_size
-        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
-        super().__init__(
-            input_size=input_size,
-            output_size=output_size,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            quant_config=quant_config,
-            prefix=prefix,
-        )
-
-        self.quant_config = quant_config
-
-        # Empty placeholders for loading as a single module.
-        placeholder_size = 0
-        assert self.quant_method is not None
-        self.quant_method.create_weights(
-            self,
-            placeholder_size,
-            [placeholder_size],
-            placeholder_size,
-            placeholder_size,
-            self.params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-        # Use a dictionary to avoid submodules parameters auto-registration:
-        # drop-in replacement for a `QKVParallelLinear` module.
-        self.proj = dict()
-        self.proj["q_proj_decoder"] = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=total_num_heads * head_size,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.q_proj_decoder",
-        )
-
-        self.proj["kv_proj_encoder"] = QKVParallelLinear(
-            hidden_size=hidden_size,
-            head_size=head_size,
-            total_num_heads=0,
-            total_num_kv_heads=total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.kv_proj_encoder",
-        )
-
-        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
-        self.q_size = self.q_proj_decoder.output_size_per_partition
-        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
-
-        if bias:
-            self.bias = torch.nn.Parameter()
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": 0,
-                    "weight_loader": self.weight_loader_v1,
-                },
-            )
-        else:
-            self.bias = None
-
-    def process_weights_after_loading(self):
-        for layer in self.proj.values():
-            if self.quant_method is not None:
-                self.quant_method.process_weights_after_loading(layer)
-
-    @property
-    def q_proj_decoder(self) -> ColumnParallelLinear:
-        layer = self.proj["q_proj_decoder"]
-        for name, param in self.named_parameters():
-            target_param = getattr(layer, name, None)
-            if target_param is not None:
-                self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
-        return layer
-
-    @property
-    def kv_proj_encoder(self) -> QKVParallelLinear:
-        layer = self.proj["kv_proj_encoder"]
-        for name, param in self.named_parameters():
-            target_param = getattr(layer, name, None)
-            if target_param is not None:
-                self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
-        return layer
-
-    def sync_weight_attrs(
-        self,
-        src_param: nn.Parameter,
-        tgt_param: nn.Parameter,
-        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
-    ):
-        missing_attrs_dict = {
-            k: getattr(src_param, k)
-            for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys()))
-        }
-        # TODO(Isotr0py): handle bitsandbytes 8bit
-        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False)
-        if missing_attrs_dict and use_bitsandbytes_4bit:
-            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
-                missing_attrs_dict
-            )
-            if mode == "q_proj_decoder":
-                set_weight_attrs(tgt_param, q_proj_attrs)
-            elif mode == "kv_proj_encoder":
-                set_weight_attrs(tgt_param, kv_proj_attrs)
-        else:
-            set_weight_attrs(tgt_param, missing_attrs_dict)
-
-    def _is_same_param(
-        self,
-        src_param: torch.nn.Parameter,
-        map_param: torch.nn.Parameter,
-    ) -> bool:
-        """Check if two parameters are exactly pointing to same things."""
-        # ignore weight_loader because it's always different
-        key_to_ignore = ["weight_loader", "_weight_loader"]
-        has_same_type_name = type(src_param) is type(map_param)
-        src_param_attrs = {
-            k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore
-        }
-        map_param_attrs = {
-            k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore
-        }
-        has_same_attrs = src_param_attrs == map_param_attrs
-        return has_same_type_name and has_same_attrs
-
-    def select_proj_params(
-        self,
-        layer: nn.Module,
-        param: nn.Parameter,
-    ) -> nn.Parameter:
-        """
-        Given the placeholder param,
-        return the corresponding param in the proj layers.
-        """
-        target_param_list = [
-            v for _, v in layer.named_parameters() if self._is_same_param(param, v)
-        ]
-        assert len(target_param_list) == 1
-        target_param = target_param_list[0]
-        return target_param
-
-    def forward(  # type: ignore[override]
-        self,
-        decoder_hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-    ) -> tuple[torch.Tensor, ...]:
-        q, _ = self.q_proj_decoder(decoder_hidden_states)
-        if encoder_hidden_states is None:
-            # Encoder KV already cached.
-            k = None
-            v = None
-        else:
-            # Prefill phase, encoder KV cached here.
-            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
-            # Split kv in half
-            k, v = kv_enc.split(self.kv_size, dim=-1)
-        return q, k, v
-
-    def weight_loader_v1(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        # just like all other parameters, does not yet
-        # support loading bias with weight_loader_v2
-        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
-        target_param = self.select_proj_params(layer, param)
-        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
-        target_param = self.select_proj_params(layer, param)
-        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
-        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
-            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
-        else:
-            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
-
-    def extra_repr(self) -> str:
-        s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_size}"
-        s += f", kv_size={self.kv_size}"
-        s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
-        s += ", gather_output=False"
-        return s
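
For orientation: the class deleted in this hunk projects Q from decoder hidden states and K/V from encoder hidden states, splitting the fused KV projection at kv_size. The sketch below is a minimal, single-process PyTorch analogue of that data flow only, with no tensor parallelism, quantization, or vLLM weight loading; ToyQKVCross and all sizes are illustrative stand-ins, not vLLM API.

from typing import Optional

import torch
import torch.nn as nn


class ToyQKVCross(nn.Module):
    """Toy analogue of the removed QKVCrossParallelLinear:
    Q comes from decoder hidden states, K/V from encoder hidden states."""

    def __init__(self, hidden_size: int, head_size: int,
                 num_heads: int, num_kv_heads: Optional[int] = None):
        super().__init__()
        num_kv_heads = num_kv_heads or num_heads
        self.q_size = num_heads * head_size
        # In vLLM this is the per-TP-rank KV width; here it is global.
        self.kv_size = num_kv_heads * head_size
        self.q_proj_decoder = nn.Linear(hidden_size, self.q_size)
        self.kv_proj_encoder = nn.Linear(hidden_size, 2 * self.kv_size)

    def forward(self, decoder_hidden_states: torch.Tensor,
                encoder_hidden_states: Optional[torch.Tensor]):
        q = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder K/V are already in the KV cache.
            k = v = None
        else:
            # Prefill phase: project encoder states and split into K and V.
            kv = self.kv_proj_encoder(encoder_hidden_states)
            k, v = kv.split(self.kv_size, dim=-1)
        return q, k, v


# Illustrative shapes only.
layer = ToyQKVCross(hidden_size=1024, head_size=64, num_heads=16, num_kv_heads=4)
q, k, v = layer(torch.randn(2, 8, 1024), torch.randn(2, 20, 1024))
print(q.shape, k.shape, v.shape)  # [2, 8, 1024], [2, 20, 256], [2, 20, 256]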