Skip to content

Commit 7bd8ac4

Browse files
committed
fix: corrected error associated with eval_llm_1GPU failing with granite-3-models
Signed-off-by: omobayode.fagbohungbe <[email protected]>
1 parent 7777b49 commit 7bd8ac4

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

fms_mo/quant/ptq.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2140,14 +2140,22 @@ def get_blocks(model, model_type=None):
21402140
None,
21412141
"lm_head",
21422142
),
2143-
"granite": (
2143+
"granite_old": (
21442144
"transformer.h",
21452145
"transformer.wte",
21462146
"transformer.wpe",
21472147
None,
21482148
"transformer.ln_f",
21492149
"lm_head",
21502150
),
2151+
"granite": (
2152+
"model.layers",
2153+
"model.embed_tokens",
2154+
"model.rotary_emb",
2155+
None,
2156+
"model.norm",
2157+
"lm_head",
2158+
),
21512159
"llama": (
21522160
"model.layers",
21532161
"model.embed_tokens",

fms_mo/utils/eval_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): #
107107
lm_head.to(dev)
108108
lm_logits = lm_head(hidden_states)
109109

110+
if model.config.model_type == "granite":
111+
lm_logits /= model.config.logits_scaling
112+
110113
# Shift so that tokens < n predict n
111114
shift_logits = lm_logits[:, :-1, :].contiguous().float()
112115
shift_labels = test_dataset.input_ids[:, (i * seq_len) : ((i + 1) * seq_len)][

0 commit comments

Comments
 (0)