Skip to content

Commit 2ab68be

Browse files
authored
Keep the input precision of the vocab log-probs for UnigramTokenizer. This fixes the ARM plugin (#600)
1 parent 01c913b commit 2ab68be

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

python/openvino_tokenizers/tokenizer_pipeline.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,10 +820,14 @@ def from_hf_json(cls, tokenizer_json: dict[str, Any]) -> "UnigramModelStep":
820820
)
821821

822822
def get_ov_subgraph(self, input_nodes: list[Output]) -> list[Output]:
823+
# Keep the vocab log-probs constant in f32 so it is not compressed to f16 on ARM devices.
824+
const_vocab_logprobs_node = make_constant_node(np.array(self.vocab_logprobs, dtype=np.float32), Type.f32)
825+
const_vocab_logprobs_node.get_rt_info()["precise_0"] = ""
826+
823827
input_nodes.extend(
824828
(
825829
*create_string_constant_node(self.vocab),
826-
make_constant_node(np.array(self.vocab_logprobs, dtype=np.float32), Type.f32),
830+
const_vocab_logprobs_node,
827831
)
828832
)
829833
return (

0 commit comments

Comments
 (0)