Skip to content

Commit 6c1da4c

Browse files
committed
fix: updated ways of getting obtaining logits_scaling
Signed-off-by: omobayode.fagbohungbe <[email protected]>
1 parent bf53238 commit 6c1da4c

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

fms_mo/utils/eval_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): #
9898
logger.info("All blocks are computed for evaluation")
9999

100100
nlls = []
101+
logits_scaling = getattr(model.config, "logits_scaling", 1)
101102
# for i, data_mb in enumerate(dloader): #if using dloader.
102103
for i in tqdm(range(qcfg["n_samples"]), desc="Final Evaluating..."):
103104
hidden_states = qcfg["cached_input"][i].to(dev)
@@ -106,9 +107,7 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): #
106107
hidden_states = ln_f(hidden_states)
107108
lm_head.to(dev)
108109
lm_logits = lm_head(hidden_states)
109-
110-
if model.config.model_type == "granite":
111-
lm_logits /= model.config.logits_scaling
110+
lm_logits /= logits_scaling
112111

113112
# Shift so that tokens < n predict n
114113
shift_logits = lm_logits[:, :-1, :].contiguous().float()

0 commit comments

Comments
 (0)