Commit d1ddf34

[V0 deprecation] Remove QKVCrossParallelLinear implementation (vllm-project#26475)

Signed-off-by: Isotr0py <[email protected]>

1 parent ec10fd0 commit d1ddf34

File tree

4 files changed: +2 -255 lines changed


vllm/lora/layers/qkv_x_parallel_linear.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

vllm/model_executor/layers/linear.py

Lines changed: 1 addition & 236 deletions
@@ -3,10 +3,9 @@
 
 import itertools
 from abc import abstractmethod
-from typing import Any, Literal, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
-import torch.nn as nn
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm.distributed import (
@@ -1440,237 +1439,3 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
-
-
-@CustomOp.register("qkv_cross_parallel_linear")
-class QKVCrossParallelLinear(LinearBase):
-    """Linear layers for efficient cross-attention's QKV transformation.
-
-    Args:
-        hidden_size: input hidden state size of the transformer.
-        head_size: size of each attention head.
-        total_num_heads: total number of attention query heads.
-        total_num_kv_heads: total number of attention key/value heads. If
-                            None, assume total_num_kv_heads = total_num_heads.
-        bias: If true, add bias.
-        skip_bias_add: This was added to enable performance optimizations where
-                       bias can be fused with other element-wise operations. we
-                       skip adding bias but instead return it.
-        params_dtype: Data type for the parameters.
-        quant_config: Quantization configure.
-        prefix: The name of the layer in the state dict, including all parents
-            (e.g. model.layers.0.qkv_proj)
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        head_size: int,
-        total_num_heads: int,
-        total_num_kv_heads: Optional[int] = None,
-        bias: bool = True,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        # input_size and output_size are not used, just for alignment
-        input_size = hidden_size
-        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
-        super().__init__(
-            input_size=input_size,
-            output_size=output_size,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            quant_config=quant_config,
-            prefix=prefix,
-        )
-
-        self.quant_config = quant_config
-
-        # Empty placeholders for loading as a single module.
-        placeholder_size = 0
-        assert self.quant_method is not None
-        self.quant_method.create_weights(
-            self,
-            placeholder_size,
-            [placeholder_size],
-            placeholder_size,
-            placeholder_size,
-            self.params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-        # Use a dictionary to avoid submodules parameters auto-registration:
-        # drop-in replacement for a `QKVParallelLinear` module.
-        self.proj = dict()
-        self.proj["q_proj_decoder"] = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=total_num_heads * head_size,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.q_proj_decoder",
-        )
-
-        self.proj["kv_proj_encoder"] = QKVParallelLinear(
-            hidden_size=hidden_size,
-            head_size=head_size,
-            total_num_heads=0,
-            total_num_kv_heads=total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.kv_proj_encoder",
-        )
-
-        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
-        self.q_size = self.q_proj_decoder.output_size_per_partition
-        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
-
-        if bias:
-            self.bias = torch.nn.Parameter()
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": 0,
-                    "weight_loader": self.weight_loader_v1,
-                },
-            )
-        else:
-            self.bias = None
-
-    def process_weights_after_loading(self):
-        for layer in self.proj.values():
-            if self.quant_method is not None:
-                self.quant_method.process_weights_after_loading(layer)
-
-    @property
-    def q_proj_decoder(self) -> ColumnParallelLinear:
-        layer = self.proj["q_proj_decoder"]
-        for name, param in self.named_parameters():
-            target_param = getattr(layer, name, None)
-            if target_param is not None:
-                self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
-        return layer
-
-    @property
-    def kv_proj_encoder(self) -> QKVParallelLinear:
-        layer = self.proj["kv_proj_encoder"]
-        for name, param in self.named_parameters():
-            target_param = getattr(layer, name, None)
-            if target_param is not None:
-                self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
-        return layer
-
-    def sync_weight_attrs(
-        self,
-        src_param: nn.Parameter,
-        tgt_param: nn.Parameter,
-        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
-    ):
-        missing_attrs_dict = {
-            k: getattr(src_param, k)
-            for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys()))
-        }
-        # TODO(Isotr0py): handle bitsandbytes 8bit
-        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False)
-        if missing_attrs_dict and use_bitsandbytes_4bit:
-            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
-                missing_attrs_dict
-            )
-            if mode == "q_proj_decoder":
-                set_weight_attrs(tgt_param, q_proj_attrs)
-            elif mode == "kv_proj_encoder":
-                set_weight_attrs(tgt_param, kv_proj_attrs)
-        else:
-            set_weight_attrs(tgt_param, missing_attrs_dict)
-
-    def _is_same_param(
-        self,
-        src_param: torch.nn.Parameter,
-        map_param: torch.nn.Parameter,
-    ) -> bool:
-        """Check if two parameters are exactly pointing to same things."""
-        # ignore weight_loader because it's always different
-        key_to_ignore = ["weight_loader", "_weight_loader"]
-        has_same_type_name = type(src_param) is type(map_param)
-        src_param_attrs = {
-            k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore
-        }
-        map_param_attrs = {
-            k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore
-        }
-        has_same_attrs = src_param_attrs == map_param_attrs
-        return has_same_type_name and has_same_attrs
-
-    def select_proj_params(
-        self,
-        layer: nn.Module,
-        param: nn.Parameter,
-    ) -> nn.Parameter:
-        """
-        Given the placeholder param,
-        return the corresponding param in the proj layers.
-        """
-        target_param_list = [
-            v for _, v in layer.named_parameters() if self._is_same_param(param, v)
-        ]
-        assert len(target_param_list) == 1
-        target_param = target_param_list[0]
-        return target_param
-
-    def forward(  # type: ignore[override]
-        self,
-        decoder_hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-    ) -> tuple[torch.Tensor, ...]:
-        q, _ = self.q_proj_decoder(decoder_hidden_states)
-        if encoder_hidden_states is None:
-            # Encoder KV already cached.
-            k = None
-            v = None
-        else:
-            # Prefill phase, encoder KV cached here.
-            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
-            # Split kv in half
-            k, v = kv_enc.split(self.kv_size, dim=-1)
-        return q, k, v
-
-    def weight_loader_v1(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        # just like all other parameters, does not yet
-        # support loading bias with weight_loader_v2
-        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
-        target_param = self.select_proj_params(layer, param)
-        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
-        target_param = self.select_proj_params(layer, param)
-        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
-        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
-            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
-        else:
-            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
-
-    def extra_repr(self) -> str:
-        s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_size}"
-        s += f", kv_size={self.kv_size}"
-        s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
-        s += ", gather_output=False"
-        return s
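
Note: for readers unfamiliar with the removed layer, the short sketch below restates the core idea of its forward pass in plain PyTorch: a separate Q projection over decoder hidden states, a fused KV projection over encoder hidden states, and a split of the fused output in half. It is only a simplified illustration; the class and argument names are invented for the example and are not part of vLLM (no tensor parallelism, quantization, or custom weight loading).

import torch
import torch.nn as nn


class NaiveCrossAttentionQKV(nn.Module):
    # Hypothetical, simplified stand-in for the removed QKVCrossParallelLinear.
    def __init__(self, hidden_size: int, head_size: int,
                 num_heads: int, num_kv_heads: int):
        super().__init__()
        self.kv_size = num_kv_heads * head_size
        self.q_proj = nn.Linear(hidden_size, num_heads * head_size)
        # Fused K/V projection over encoder states, split into K and V below.
        self.kv_proj = nn.Linear(hidden_size, 2 * self.kv_size)

    def forward(self, decoder_hidden_states, encoder_hidden_states=None):
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder K/V assumed to be cached already.
            k = v = None
        else:
            # Prefill phase: compute encoder K/V once and split the fused output.
            kv = self.kv_proj(encoder_hidden_states)
            k, v = kv.split(self.kv_size, dim=-1)
        return q, k, v


# Toy usage: 2 decoder tokens, 3 encoder tokens, hidden size 64.
layer = NaiveCrossAttentionQKV(hidden_size=64, head_size=16, num_heads=4, num_kv_heads=2)
q, k, v = layer(torch.randn(2, 64), torch.randn(3, 64))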

vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py

Lines changed: 1 addition & 5 deletions
@@ -16,7 +16,6 @@
 from vllm.model_executor.layers.linear import (
     WEIGHT_LOADER_V2_SUPPORTED,
     LinearMethodBase,
-    QKVCrossParallelLinear,
 )
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
     CompressedTensorsScheme,
@@ -89,10 +88,7 @@ def create_weights(
         # hack around this by getting weight loader v1 so ULM can load correctly
         quant_method_name = self.quant_method.__class__.__name__
         if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
-            if isinstance(layer, QKVCrossParallelLinear):
-                weight_loader_v1 = layer.weight_loader_v1
-            else:
-                weight_loader_v1 = layer.weight_loader
+            weight_loader_v1 = layer.weight_loader
             extra_weight_attrs["weight_loader"] = weight_loader_v1
 
         self.quant_method.create_weights(
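
For context, the membership check above mirrors the loader-selection pattern also visible in the removed weight_loader earlier in this diff: quant methods whose class names are not in WEIGHT_LOADER_V2_SUPPORTED fall back to the layer's plain weight_loader, while listed ones use the v2 path. A toy sketch of that selection follows; DummyLayer and the registry contents are illustrative stand-ins, not vLLM's real objects.

# Illustrative only: mimics the v1/v2 weight-loader selection pattern.
WEIGHT_LOADER_V2_SUPPORTED = {"CompressedTensorsLinearMethod"}  # stand-in contents


class DummyLayer:
    def weight_loader(self, param, loaded_weight):
        return "v1"

    def weight_loader_v2(self, param, loaded_weight):
        return "v2"


def pick_loader(layer: DummyLayer, quant_method_name: str):
    # Same shape as the branch in the diff: a membership check decides the loader.
    if quant_method_name in WEIGHT_LOADER_V2_SUPPORTED:
        return layer.weight_loader_v2
    return layer.weight_loader


assert pick_loader(DummyLayer(), "SomeOtherMethod")(None, None) == "v1"
assert pick_loader(DummyLayer(), "CompressedTensorsLinearMethod")(None, None) == "v2"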

vllm/model_executor/model_loader/utils.py

Lines changed: 0 additions & 6 deletions
@@ -17,7 +17,6 @@
 from vllm.attention.layer import MLAAttention
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import QKVCrossParallelLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -108,11 +107,6 @@ def process_weights_after_loading(
     maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)
 
     for _, module in model.named_modules():
-        if isinstance(module, QKVCrossParallelLinear):
-            # NOTE(Isotr0py): special case for cross QKV layer because
-            # q and kv proj aren't registered as submodules intentionally
-            module.process_weights_after_loading()
-            continue
         quant_method = getattr(module, "quant_method", None)
         if isinstance(quant_method, QuantizeMethodBase):
            # When quant methods need to process weights after loading
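
With the special case removed, post-load weight processing relies only on the generic hook visible in the surrounding context: iterate over named_modules and let each module's quant_method finalize its weights. The minimal sketch below shows that pattern in isolation; FakeQuantMethod and the toy model are hypothetical stand-ins rather than vLLM classes.

import torch.nn as nn


class FakeQuantMethod:
    # Stand-in for a QuantizeMethodBase implementation.
    def process_weights_after_loading(self, layer: nn.Module) -> None:
        # Real implementations repack or transform the layer's weights in place.
        print(f"finalized weights for {layer.__class__.__name__}")


def process_weights_after_loading(model: nn.Module) -> None:
    # Generic hook: every submodule with a quant_method gets a chance to finalize.
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model[0].quant_method = FakeQuantMethod()  # only the first layer is "quantized"
process_weights_after_loading(model)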
