@@ -328,6 +328,7 @@ class BaseModel:
328328class EmbeddingModel (BaseModel ):
329329 max_input_tokens : int
330330 max_batch_tokens : int
331+ max_batch_size : int | None
331332 max_output_dimensions : int
332333 supports_shortening : bool
333334
@@ -339,6 +340,9 @@ class EmbeddingModel (BaseModel):
339340 max_batch_tokens_annotation : ClassVar [str ] = (
340341 "ext::ai::embedding_model_max_batch_tokens"
341342 )
343+ max_batch_size_annotation : ClassVar [str ] = (
344+ "ext::ai::embedding_model_max_batch_size"
345+ )
342346 max_output_dimensions_annotation : ClassVar [str ] = (
343347 "ext::ai::embedding_model_max_output_dimensions"
344348 )
@@ -846,6 +850,8 @@ async def _generate_embeddings_params(
846850 embeddings_params : list [EmbeddingsParams ] = []
847851
848852 for model_name , pending_entries in model_pending_entries .items ():
853+ embedding_model = embedding_models [model_name ]
854+
849855 groups = itertools .groupby (
850856 pending_entries , key = lambda e : e .target_dims_shortening
851857 )
@@ -856,8 +862,9 @@ async def _generate_embeddings_params(
856862 batches , excluded_indexes = batch_texts (
857863 part_texts ,
858864 get_model_tokenizer (provider_name , model_name ),
859- embedding_models [model_name ].max_input_tokens ,
860- embedding_models [model_name ].max_batch_tokens ,
865+ max_input_tokens = embedding_model .max_input_tokens ,
866+ max_batch_tokens = embedding_model .max_batch_tokens ,
867+ max_batch_size = embedding_model .max_batch_size ,
861868 )
862869
863870 if excluded_indexes :
@@ -908,8 +915,10 @@ class TextBatch:
908915def batch_texts (
909916 texts : list [tuple [str , bool ]],
910917 tokenizer : Optional [Tokenizer ],
918+ * ,
911919 max_input_tokens : int ,
912920 max_batch_tokens : int ,
921+ max_batch_size : int | None ,
913922) -> tuple [list [TextBatch ], list [int ]]:
914923 """Given a list of texts and whether each can be truncated, produce a list
915924 of valid texts to batch.
@@ -942,7 +951,7 @@ def batch_texts(
942951
943952 # Group the valid texts into batches based on token count
944953 batched_inputs = _batch_embeddings_inputs (
945- tokenizer , input_texts , max_batch_tokens
954+ tokenizer , input_texts , max_batch_tokens , max_batch_size
946955 )
947956
948957 # Gather results
@@ -960,6 +969,25 @@ def batch_texts(
960969 for batch_input_indexes , token_count in batched_inputs
961970 ]
962971
972+ elif max_batch_size :
973+ batch_count = (len (texts ) - 1 ) // max_batch_size + 1
974+ batches = [
975+ TextBatch (
976+ entries = [
977+ TextBatchEntry (
978+ input_index = index ,
979+ input_text = texts [index ][0 ],
980+ )
981+ for index in range (
982+ batch_index * max_batch_size ,
983+ min ((batch_index + 1 ) * max_batch_size , len (texts ))
984+ )
985+ ],
986+ token_count = 0 ,
987+ )
988+ for batch_index in range (batch_count )
989+ ]
990+
963991 else :
964992 batches = [
965993 TextBatch (
@@ -1099,6 +1127,7 @@ def _batch_embeddings_inputs(
10991127 tokenizer : Tokenizer ,
11001128 inputs : list [str ],
11011129 max_batch_tokens : int ,
1130+ max_batch_size : int | None ,
11021131) -> list [tuple [list [int ], int ]]:
11031132 """Create batches of embeddings inputs.
11041133
@@ -1140,9 +1169,15 @@ def unbatched_token_count(unbatched_index: int) -> int:
11401169
11411170 if batch_token_count < max_batch_tokens :
11411170 # Then add the smallest available input as long as the
1143- # max batch token count isn 't exceeded
1172+ # max batch token and input counts aren 't exceeded
11441173 unbatched_index = 0
1145- while unbatched_index < len (unbatched_input_indexes ):
1174+ while (
1175+ unbatched_index < len (unbatched_input_indexes )
1176+ and (
1177+ max_batch_size is None
1178+ or len (batch_input_indexes ) < max_batch_size
1179+ )
1180+ ):
11461181 if (
11471182 batch_token_count + unbatched_token_count (unbatched_index )
11481183 <= max_batch_tokens
@@ -3124,6 +3159,7 @@ async def _get_embedding_models(
31243159 EmbeddingModel .provider_annotation ,
31253160 EmbeddingModel .max_model_input_tokens_annotation ,
31263161 EmbeddingModel .max_batch_tokens_annotation ,
3162+ EmbeddingModel .max_batch_size_annotation ,
31273163 EmbeddingModel .max_output_dimensions_annotation ,
31283164 EmbeddingModel .supports_shortening_annotation ,
31293165 ],
@@ -3143,6 +3179,20 @@ def _get_ann(
31433179 )
31443180 return val
31453181
def _get_bool_ann(
    model: str,
    anns: dict[str, str | None],
    name: str,
) -> bool:
    """Return the boolean value of annotation *name* for *model*.

    Raises:
        InternalError: if the annotation value is not a recognized
            boolean literal.
    """
    val = _get_ann(model, anns, name)
    # NOTE(review): the previous implementation returned bool(val), which
    # is wrong for string annotation values: bool("false") is True, and
    # bool() never raises ValueError, so its except clause was dead code
    # (copy-pasted from _get_int_ann). Parse the literal explicitly.
    normalized = val.strip().lower() if val is not None else None
    if normalized == "true":
        return True
    if normalized == "false":
        return False
    raise InternalError(
        f"Model '{model}' annotation '{name}' "
        f"has non boolean value {val}"
    )
31463196 def _get_int_ann (
31473197 model : str ,
31483198 anns : dict [str , str | None ],
@@ -3157,18 +3207,20 @@ def _get_int_ann(
31573207 f"has non integer value { val } "
31583208 )
31593209
def _get_int_or_none_ann(
    model: str,
    anns: dict[str, str | None],
    name: str,
) -> int | None:
    """Fetch annotation *name* for *model* and parse it as an optional int.

    The sentinel string "<optional>" means the annotation is unset and
    yields None; any other value must parse as an integer.

    Raises:
        InternalError: if the value is neither the sentinel nor an integer.
    """
    raw = _get_ann(model, anns, name)
    if raw == "<optional>":
        return None
    try:
        parsed = int(raw)
    except ValueError:
        raise InternalError(
            f"Model '{model}' annotation '{name}' "
            f"has non integer value {raw}"
        )
    return parsed
31733225
31743226 result : dict [str , EmbeddingModel ] = {}
@@ -3182,6 +3234,9 @@ def _get_bool_ann(
31823234 max_batch_tokens = _get_int_ann (
31833235 model , anns , EmbeddingModel .max_batch_tokens_annotation
31843236 ),
3237+ max_batch_size = _get_int_or_none_ann (
3238+ model , anns , EmbeddingModel .max_batch_size_annotation
3239+ ),
31853240 max_output_dimensions = _get_int_ann (
31863241 model , anns , EmbeddingModel .max_output_dimensions_annotation
31873242 ),
@@ -3427,8 +3482,6 @@ async def generate_embeddings_for_texts(
34273482 embedding_model = embedding_models [model_name ]
34283483
34293484 tokenizer = get_model_tokenizer (provider , model_name )
3430- max_input_tokens = embedding_model .max_input_tokens
3431- max_batch_tokens = embedding_model .max_batch_tokens
34323485
34333486 texts = [
34343487 (
@@ -3441,8 +3494,9 @@ async def generate_embeddings_for_texts(
34413494 text_batches , excluded_indexes = batch_texts (
34423495 texts ,
34433496 tokenizer ,
3444- max_input_tokens ,
3445- max_batch_tokens ,
3497+ max_input_tokens = embedding_model .max_input_tokens ,
3498+ max_batch_tokens = embedding_model .max_batch_tokens ,
3499+ max_batch_size = embedding_model .max_batch_size ,
34463500 )
34473501
34483502 if excluded_indexes or too_long :
0 commit comments