Commit 372f5c5

fix jamba accuracy test failure (#3466)
1 parent 77a98ab commit 372f5c5

File tree: 2 files changed (+4, -0 lines)

examples/cpu/llm/inference/distributed/run_accuracy_with_deepspeed.py

Lines changed: 2 additions & 0 deletions
@@ -730,6 +730,8 @@ def _model_call(
             example_dict["output_router_logits"] = torch.tensor(
                 model_inputs["output_router_logits"]
             )
+            if self.config.architectures[0] == "JambaForCausalLM":
+                example_dict["num_logits_to_keep"] = torch.tensor(0)
 
         with torch.inference_mode(), torch.no_grad(), torch.cpu.amp.autocast(
             enabled=True if args.quant_with_amp or self._dtype == "bfloat16" else False,

examples/cpu/llm/inference/single_instance/run_accuracy.py

Lines changed: 2 additions & 0 deletions
@@ -431,6 +431,8 @@ def _model_call(
             example_dict["output_router_logits"] = torch.tensor(
                 model_inputs["output_router_logits"]
             )
+            if self.config.architectures[0] == "JambaForCausalLM":
+                example_dict["num_logits_to_keep"] = torch.tensor(0)
 
         with torch.inference_mode(), torch.no_grad(), torch.cpu.amp.autocast(
             enabled=True if args.quant_with_amp or self._dtype == "bfloat16" else False,
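
The two added lines are identical in both scripts: for Jamba, the accuracy harness now passes num_logits_to_keep=0 alongside the other example inputs. As a minimal sketch (an assumption about how HF-style causal-LM heads typically handle this argument, not a quote of the Jamba source), num_logits_to_keep usually drives a negative slice over the hidden states before the lm_head projection, so 0 keeps logits for every position, which per-token loglikelihood accuracy tests require, whereas a generation-time default of 1 would return logits only for the last token and break the evaluation. The helper lm_head_logits below is hypothetical.

import torch

def lm_head_logits(hidden_states: torch.Tensor,
                   lm_head: torch.nn.Linear,
                   num_logits_to_keep: int = 0) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden_size]
    # Since -0 == 0 in Python, a slice starting at -0 keeps the whole
    # sequence: num_logits_to_keep=0 means "logits for all positions",
    # while k > 0 projects only the last k positions.
    return lm_head(hidden_states[:, -num_logits_to_keep:, :])

hidden = torch.randn(2, 16, 64)   # toy shapes, not Jamba's real sizes
head = torch.nn.Linear(64, 32000)
assert lm_head_logits(hidden, head, 0).shape == (2, 16, 32000)
assert lm_head_logits(hidden, head, 1).shape == (2, 1, 32000)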
