Skip to content

Commit 9426cd6

Browse files
committed
add some more tokenizer max length checks
Signed-off-by: HenryL27 <[email protected]>
1 parent 077e7e8 commit 9426cd6

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

opensearch_py_ml/ml_models/crossencodermodel.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,17 @@ def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path
120120
# save tokenizer file
121121
tk_path = Path(f"/tmp/{mname}-tokenizer")
122122
tk.save_pretrained(tk_path)
123+
if tk.model_max_length > model.get_max_length():
124+
model_config = AutoConfig.from_pretrained(self._hf_model_id)
125+
if hasattr(model_config, "max_position_embeddings"):
126+
tk.model_max_length = model_config.max_position_embeddings
127+
elif hasattr(model_config, "n_positions"):
128+
tk.model_max_length = model_config.n_positions
129+
else:
130+
tk.model_max_length = 2**15 # =32768. Set to something big I guess
131+
print(
132+
f"The model_max_length is not properly defined in tokenizer_config.json. Setting it to be {tk.model_max_length}"
133+
)
123134
_fix_tokenizer(tk.model_max_length, tk_path)
124135

125136
# get apache license

0 commit comments

Comments (0)