
Commit 2d78aed

Merge pull request #74 from IBM/main
[pull] main from IBM:main
2 parents 102f77d + deb99f6 commit 2d78aed

5 files changed (+29, −16 lines)


server/poetry.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default.

server/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ onnxruntime-gpu = { version = "^1.17.1", optional = true }
 onnx = { version = "^1.16.0", optional = true }
 einops = "^0.7.0"
 ibm-fms = { version = "^0.0", optional = true }
-fms-extras = { git = "https://github.com/foundation-model-stack/fms-extras", rev = "a010516ff2c938c206b9b342b16bd747ef07d43c", optional = true }
+fms-extras = { git = "https://github.com/foundation-model-stack/fms-extras", rev = "d41f8a34c9841aa3c4c59f17b5e7f3cb365f49de", optional = true }
 
 # Explicitly install some transitive dependencies to avoid CVEs
 jinja2 = ">=3.1.3"

server/text_generation_server/inference_engine/tgis_native.py

Lines changed: 5 additions & 0 deletions
@@ -101,6 +101,11 @@ def __init__(
             model_class = FlashRWForCausalLM
 
         elif model_type == "llama":
+            # See: https://github.com/ibm-granite/vllm_granite/blob/main/vllm/model_executor/models/llama.py#L353-L354
+            if self._config.tie_word_embeddings:
+                aliases = {
+                    "lm_head.weight": ["model.embed_tokens.weight"]
+                }
             if PAGED_ATTENTION:
                 from text_generation_server.models.custom_modeling.paged_llama_modeling import PagedLlamaForCausalLM
                 model_class = PagedLlamaForCausalLM
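
The added aliases mapping lets the weight loader fall back to the embedding matrix when a checkpoint with tied word embeddings does not ship a separate lm_head.weight tensor. Below is a minimal sketch of how such an alias map can be resolved at load time; the resolve_tensor helper and the in-memory checkpoint dict are illustrative assumptions, not the actual TGIS weight-loading API.

from typing import Dict, List

import torch


def resolve_tensor(
    checkpoint: Dict[str, torch.Tensor],
    name: str,
    aliases: Dict[str, List[str]],
) -> torch.Tensor:
    # Return the named tensor, falling back to any alias present in the checkpoint.
    if name in checkpoint:
        return checkpoint[name]
    for alias in aliases.get(name, []):
        if alias in checkpoint:
            # Tied embeddings: lm_head.weight reuses model.embed_tokens.weight.
            return checkpoint[alias]
    raise KeyError(f"{name} not found (aliases tried: {aliases.get(name, [])})")


# Usage mirroring the mapping added in this commit:
aliases = {"lm_head.weight": ["model.embed_tokens.weight"]}
checkpoint = {"model.embed_tokens.weight": torch.zeros(32000, 4096)}
lm_head_weight = resolve_tensor(checkpoint, "lm_head.weight", aliases)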

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 9 additions & 5 deletions
@@ -64,6 +64,8 @@ def __init__(
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
+        attention_bias=False,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -85,6 +87,8 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.attention_bias = attention_bias
+        self.mlp_bias = mlp_bias
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -169,7 +173,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
 
 
 class FlashLlamaAttention(torch.nn.Module):
@@ -220,13 +224,13 @@ def __init__(
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
 
     def forward(
@@ -309,13 +313,13 @@ def __init__(self, prefix, config, weights):
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
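
The two new config fields thread through to the projection layers: attention_bias governs the q/k/v and output projections, and mlp_bias governs the gate/up/down projections, both defaulting to False so existing Llama checkpoints behave as before. Below is a minimal sketch of that idea using plain nn.Linear; the simplified config class is an illustrative stand-in, not the TensorParallel layers from this file.

import torch.nn as nn


class LlamaConfigSketch:
    # Simplified stand-in for the config defined in this file.
    def __init__(self, hidden_size=4096, attention_bias=False, mlp_bias=False):
        self.hidden_size = hidden_size
        self.attention_bias = attention_bias  # bias on q/k/v and o_proj
        self.mlp_bias = mlp_bias              # bias on gate/up/down projections


def build_o_proj(config: LlamaConfigSketch) -> nn.Linear:
    # Previously the bias was hard-coded to False; now it follows the config flag,
    # so checkpoints whose projections carry bias terms can be loaded.
    return nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)


o_proj = build_o_proj(LlamaConfigSketch(attention_bias=True))
assert o_proj.bias is not None  # the bias tensor exists when the flag is set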

server/text_generation_server/models/custom_modeling/paged_llama_modeling.py

Lines changed: 9 additions & 5 deletions
@@ -64,6 +64,8 @@ def __init__(
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
+        attention_bias=False,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -85,6 +87,8 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.attention_bias = attention_bias
+        self.mlp_bias = mlp_bias
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -169,7 +173,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
 
 
 class PagedLlamaAttention(torch.nn.Module):
@@ -207,13 +211,13 @@ def __init__(
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
 
     def forward(
@@ -280,13 +284,13 @@ def __init__(self, prefix, config, weights):
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=config.mlp_bias,
        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
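
The same flags are mirrored here for the paged-attention path. Because both fields default to False and arrive through the config constructor, a model's config.json only needs to set them when its projections actually carry bias terms; older configs that omit the keys keep the previous behavior. A small illustrative check of that default handling follows, using a simplified stand-in config class rather than the one in this file.

import json


class PagedLlamaConfigSketch:
    # Simplified stand-in: only the fields relevant to this change.
    def __init__(self, attention_bias=False, mlp_bias=False, **kwargs):
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias


# A config that sets the new keys enables the bias terms...
cfg_new = PagedLlamaConfigSketch(**json.loads('{"attention_bias": true, "mlp_bias": false}'))
assert cfg_new.attention_bias is True and cfg_new.mlp_bias is False

# ...while an older config that omits them falls back to the previous behavior.
cfg_old = PagedLlamaConfigSketch(**json.loads('{"hidden_size": 4096}'))
assert cfg_old.attention_bias is False and cfg_old.mlp_bias is False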
