
Commit 2c454e4

Merge remote-tracking branch 'upstream/main'
2 parents: 2d78aed + e87d462

3 files changed: +19 -8 lines


server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 8 additions & 4 deletions
@@ -156,10 +156,9 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0

+    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0
+        prefixes=prefixes, quantize=config.quantize, dim=0
     )

     if config.quantize != "gptq":
@@ -173,7 +172,12 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

-    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
+    if config.attention_bias:
+        bias = torch.cat([weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))


 class FlashLlamaAttention(torch.nn.Module):
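
The interesting part of this hunk is the bias handling. `config.attention_bias` is a boolean flag in the model config, so the old code passed `True`/`False` where `get_linear` expects a bias tensor (or `None`); the new code loads each rank's shard of the `q_proj`/`k_proj`/`v_proj` biases and concatenates them so they stay aligned with the fused, column-sharded QKV weight returned by `get_multi_weights_col`. Below is a minimal sketch of that alignment, using made-up sizes and a plain slice in place of `weights.get_sharded`; it is an illustration, not code from the commit.

import torch

# Hypothetical sizes and rank, for illustration only.
hidden, world_size, rank = 8, 2, 1
shard = hidden // world_size

q_bias = torch.arange(0, hidden).float()        # stand-ins for the q_proj/k_proj/v_proj biases
k_bias = torch.arange(10, 10 + hidden).float()
v_bias = torch.arange(20, 20 + hidden).float()

def sharded(b):
    # analogue of weights.get_sharded(f"{p}.bias", dim=0): keep only this rank's rows
    return b[rank * shard:(rank + 1) * shard]

# Concatenating the per-rank shards keeps the bias rows in the same q|k|v order
# as the fused, column-sharded QKV weight.
bias = torch.cat([sharded(q_bias), sharded(k_bias), sharded(v_bias)], dim=0)
print(bias)  # tensor([ 4.,  5.,  6.,  7., 14., 15., 16., 17., 24., 25., 26., 27.])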

server/text_generation_server/models/custom_modeling/paged_llama_modeling.py

Lines changed: 8 additions & 4 deletions
@@ -156,10 +156,9 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0

+    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0
+        prefixes=prefixes, quantize=config.quantize, dim=0
     )

     if config.quantize != "gptq":
@@ -173,7 +172,12 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

-    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
+    if config.attention_bias:
+        bias = torch.cat([weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))


 class PagedLlamaAttention(torch.nn.Module):
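
This is the same fix applied to the paged variant of `_load_gqa`. For reference, the shape assertion preserved in both hunks is the usual GQA bookkeeping: the fused QKV weight stacks the query heads plus one set of key heads and one set of value heads. A quick numeric check with illustrative Llama-2-70B-style sizes (not taken from the commit), assuming a single-rank run so per-rank and global head counts coincide:

# Illustrative sizes only; world_size == 1 so per-rank head counts equal the config values.
hidden_size = 8192
num_attention_heads = 64       # query heads
num_key_value_heads = 8        # GQA: one k/v head shared by each group of 8 query heads
head_size = hidden_size // num_attention_heads             # 128

qkv_rows = (num_attention_heads + 2 * num_key_value_heads) * head_size
print([qkv_rows, hidden_size])  # [10240, 8192] -- expected shape of the fused QKV weight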

server/text_generation_server/models/paged_causal_lm.py

Lines changed: 3 additions & 0 deletions
@@ -333,6 +333,9 @@ def __init__(
             total_num_gpu_blocks=total_num_gpu_blocks,
         )

+        # log number of free blocks at init
+        print("[PagedKVCacheManager] number of free blocks: %d" % (len(self.kv_cache_manager.free_blocks)))
+
     @property
     def batch_type(self) -> Type[PagedCausalLMBatch]:
         return self._batch_type
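
A small aside, not part of the commit: since this runs once at model initialization, a bare print is workable, but routing it through Python's standard logging module would let deployments control verbosity. A possible equivalent, with a hypothetical logger name and a stand-in value for the free-block count:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("paged_causal_lm")   # hypothetical logger name

num_free_blocks = 1024  # stand-in for len(self.kv_cache_manager.free_blocks)
logger.info("[PagedKVCacheManager] number of free blocks: %d", num_free_blocks)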

0 commit comments
