Skip to content

Commit ba913b8

Browse files
authored
Add support for max batch size to AI embedding models. (#9040)
Adds annotation `ext::ai::embedding_model_max_batch_size` which limits the number of inputs that can be part of a single batched embedding request.
1 parent ec67872 commit ba913b8

File tree

5 files changed

+111
-21
lines changed

5 files changed

+111
-21
lines changed

docs/reference/ai/extai.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ instantiated as a :eql:type:`ext::ai::TextGenerationModel`
409409

410410
* ``embedding_model_max_input_tokens`` - Maximum tokens per input
411411
* ``embedding_model_max_batch_tokens`` - Maximum tokens per batch. Default: ``'8191'``.
412+
* ``embedding_model_max_batch_size`` - Maximum inputs per batch. Optional.
412413
* ``embedding_model_max_output_dimensions`` - Maximum embedding dimensions
413414
* ``embedding_model_supports_shortening`` - Input shortening support flag
414415

edb/buildmeta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
# The merge conflict there is a nice reminder that you probably need
5858
# to write a patch in edb/pgsql/patches.py, and then you should preserve
5959
# the old value.
60-
EDGEDB_CATALOG_VERSION = 2025_09_09_00_00
60+
EDGEDB_CATALOG_VERSION = 2025_09_23_00_00
6161
EDGEDB_MAJOR_VERSION = 8
6262

6363

edb/lib/ext/ai.edgeql

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ CREATE EXTENSION PACKAGE ai VERSION '1.0' {
198198
create abstract inheritable annotation
199199
ext::ai::embedding_model_max_batch_tokens;
200200

201+
create abstract inheritable annotation
202+
ext::ai::embedding_model_max_batch_size;
203+
201204
create abstract inheritable annotation
202205
ext::ai::embedding_model_max_output_dimensions;
203206

@@ -212,6 +215,8 @@ CREATE EXTENSION PACKAGE ai VERSION '1.0' {
212215
# for now, use the openai batch limit as the default.
213216
create annotation
214217
ext::ai::embedding_model_max_batch_tokens := "8191";
218+
create annotation
219+
ext::ai::embedding_model_max_batch_size := "<optional>";
215220
create annotation
216221
ext::ai::embedding_model_max_output_dimensions := "<must override>";
217222
create annotation

edb/server/protocol/ai_ext.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ class BaseModel:
328328
class EmbeddingModel (BaseModel):
329329
max_input_tokens: int
330330
max_batch_tokens: int
331+
max_batch_size: int | None
331332
max_output_dimensions: int
332333
supports_shortening: bool
333334

@@ -339,6 +340,9 @@ class EmbeddingModel (BaseModel):
339340
max_batch_tokens_annotation: ClassVar[str] = (
340341
"ext::ai::embedding_model_max_batch_tokens"
341342
)
343+
max_batch_size_annotation: ClassVar[str] = (
344+
"ext::ai::embedding_model_max_batch_size"
345+
)
342346
max_output_dimensions_annotation: ClassVar[str] = (
343347
"ext::ai::embedding_model_max_output_dimensions"
344348
)
@@ -846,6 +850,8 @@ async def _generate_embeddings_params(
846850
embeddings_params: list[EmbeddingsParams] = []
847851

848852
for model_name, pending_entries in model_pending_entries.items():
853+
embedding_model = embedding_models[model_name]
854+
849855
groups = itertools.groupby(
850856
pending_entries, key=lambda e: e.target_dims_shortening
851857
)
@@ -856,8 +862,9 @@ async def _generate_embeddings_params(
856862
batches, excluded_indexes = batch_texts(
857863
part_texts,
858864
get_model_tokenizer(provider_name, model_name),
859-
embedding_models[model_name].max_input_tokens,
860-
embedding_models[model_name].max_batch_tokens,
865+
max_input_tokens=embedding_model.max_input_tokens,
866+
max_batch_tokens=embedding_model.max_batch_tokens,
867+
max_batch_size=embedding_model.max_batch_size,
861868
)
862869

863870
if excluded_indexes:
@@ -908,8 +915,10 @@ class TextBatch:
908915
def batch_texts(
909916
texts: list[tuple[str, bool]],
910917
tokenizer: Optional[Tokenizer],
918+
*,
911919
max_input_tokens: int,
912920
max_batch_tokens: int,
921+
max_batch_size: int | None,
913922
) -> tuple[list[TextBatch], list[int]]:
914923
"""Given a list of texts and whether each can be truncated, produce a list
915924
of valid texts to batch.
@@ -942,7 +951,7 @@ def batch_texts(
942951

943952
# Group the valid texts into batches based on token count
944953
batched_inputs = _batch_embeddings_inputs(
945-
tokenizer, input_texts, max_batch_tokens
954+
tokenizer, input_texts, max_batch_tokens, max_batch_size
946955
)
947956

948957
# Gather results
@@ -960,6 +969,25 @@ def batch_texts(
960969
for batch_input_indexes, token_count in batched_inputs
961970
]
962971

972+
elif max_batch_size:
973+
batch_count = (len(texts) - 1) // max_batch_size + 1
974+
batches = [
975+
TextBatch(
976+
entries=[
977+
TextBatchEntry(
978+
input_index=index,
979+
input_text=texts[index][0],
980+
)
981+
for index in range(
982+
batch_index * max_batch_size,
983+
min((batch_index + 1) * max_batch_size, len(texts))
984+
)
985+
],
986+
token_count=0,
987+
)
988+
for batch_index in range(batch_count)
989+
]
990+
963991
else:
964992
batches = [
965993
TextBatch(
@@ -1099,6 +1127,7 @@ def _batch_embeddings_inputs(
10991127
tokenizer: Tokenizer,
11001128
inputs: list[str],
11011129
max_batch_tokens: int,
1130+
max_batch_size: int | None,
11021131
) -> list[tuple[list[int], int]]:
11031132
"""Create batches of embeddings inputs.
11041133
@@ -1140,9 +1169,15 @@ def unbatched_token_count(unbatched_index: int) -> int:
11401169

11411170
if batch_token_count < max_batch_tokens:
11421171
# Then add the smallest available input as long as long as the
1143-
# max batch token count isn't exceeded
1172+
# max batch token and input counts aren't exceeded
11441173
unbatched_index = 0
1145-
while unbatched_index < len(unbatched_input_indexes):
1174+
while (
1175+
unbatched_index < len(unbatched_input_indexes)
1176+
and (
1177+
max_batch_size is None
1178+
or len(batch_input_indexes) < max_batch_size
1179+
)
1180+
):
11461181
if (
11471182
batch_token_count + unbatched_token_count(unbatched_index)
11481183
<= max_batch_tokens
@@ -3124,6 +3159,7 @@ async def _get_embedding_models(
31243159
EmbeddingModel.provider_annotation,
31253160
EmbeddingModel.max_model_input_tokens_annotation,
31263161
EmbeddingModel.max_batch_tokens_annotation,
3162+
EmbeddingModel.max_batch_size_annotation,
31273163
EmbeddingModel.max_output_dimensions_annotation,
31283164
EmbeddingModel.supports_shortening_annotation,
31293165
],
@@ -3143,6 +3179,20 @@ def _get_ann(
31433179
)
31443180
return val
31453181

3182+
def _get_bool_ann(
3183+
model: str,
3184+
anns: dict[str, str | None],
3185+
name: str,
3186+
) -> bool:
3187+
val = _get_ann(model, anns, name)
3188+
try:
3189+
return bool(val)
3190+
except ValueError:
3191+
raise InternalError(
3192+
f"Model '{model}' annotation '{name}' "
3193+
f"has non boolean value {val}"
3194+
)
3195+
31463196
def _get_int_ann(
31473197
model: str,
31483198
anns: dict[str, str | None],
@@ -3157,18 +3207,20 @@ def _get_int_ann(
31573207
f"has non integer value {val}"
31583208
)
31593209

3160-
def _get_bool_ann(
3210+
def _get_int_or_none_ann(
31613211
model: str,
31623212
anns: dict[str, str | None],
31633213
name: str,
3164-
) -> bool:
3214+
) -> int | None:
31653215
val = _get_ann(model, anns, name)
3216+
if val == "<optional>":
3217+
return None
31663218
try:
3167-
return bool(val)
3219+
return int(val)
31683220
except ValueError:
31693221
raise InternalError(
31703222
f"Model '{model}' annotation '{name}' "
3171-
f"has non boolean value {val}"
3223+
f"has non integer value {val}"
31723224
)
31733225

31743226
result: dict[str, EmbeddingModel] = {}
@@ -3182,6 +3234,9 @@ def _get_bool_ann(
31823234
max_batch_tokens=_get_int_ann(
31833235
model, anns, EmbeddingModel.max_batch_tokens_annotation
31843236
),
3237+
max_batch_size=_get_int_or_none_ann(
3238+
model, anns, EmbeddingModel.max_batch_size_annotation
3239+
),
31853240
max_output_dimensions=_get_int_ann(
31863241
model, anns, EmbeddingModel.max_output_dimensions_annotation
31873242
),
@@ -3427,8 +3482,6 @@ async def generate_embeddings_for_texts(
34273482
embedding_model = embedding_models[model_name]
34283483

34293484
tokenizer = get_model_tokenizer(provider, model_name)
3430-
max_input_tokens = embedding_model.max_input_tokens
3431-
max_batch_tokens = embedding_model.max_batch_tokens
34323485

34333486
texts = [
34343487
(
@@ -3441,8 +3494,9 @@ async def generate_embeddings_for_texts(
34413494
text_batches, excluded_indexes = batch_texts(
34423495
texts,
34433496
tokenizer,
3444-
max_input_tokens,
3445-
max_batch_tokens,
3497+
max_input_tokens=embedding_model.max_input_tokens,
3498+
max_batch_tokens=embedding_model.max_batch_tokens,
3499+
max_batch_size=embedding_model.max_batch_size,
34463500
)
34473501

34483502
if excluded_indexes or too_long:

tests/test_ext_ai.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,23 +1355,26 @@ def test_batch_embeddings_inputs_01(self):
13551355
ai_ext._batch_embeddings_inputs(
13561356
CharacterTokenizer(),
13571357
[],
1358-
10
1358+
10,
1359+
None,
13591360
),
13601361
[],
13611362
)
13621363
self.assertEqual(
13631364
ai_ext._batch_embeddings_inputs(
13641365
CharacterTokenizer(),
13651366
['1', '22', '333', '4444'],
1366-
10
1367+
10,
1368+
None,
13671369
),
13681370
[([3, 0, 1, 2], 10)],
13691371
)
13701372
self.assertEqual(
13711373
ai_ext._batch_embeddings_inputs(
13721374
CharacterTokenizer(),
13731375
['1', '22', '333', '4444', '55555'],
1374-
10
1376+
10,
1377+
None,
13751378
),
13761379
[
13771380
([4, 0, 1], 8),
@@ -1382,7 +1385,8 @@ def test_batch_embeddings_inputs_01(self):
13821385
ai_ext._batch_embeddings_inputs(
13831386
CharacterTokenizer(),
13841387
['1', '22', '333', '4444', '55555', '666666'],
1385-
10
1388+
10,
1389+
None,
13861390
),
13871391
[
13881392
([5, 0, 1], 9),
@@ -1394,7 +1398,8 @@ def test_batch_embeddings_inputs_01(self):
13941398
ai_ext._batch_embeddings_inputs(
13951399
CharacterTokenizer(),
13961400
['1', '22', '333', '4444', '55555', '666666'],
1397-
10
1401+
10,
1402+
None,
13981403
),
13991404
[
14001405
([5, 0, 1], 9),
@@ -1406,19 +1411,44 @@ def test_batch_embeddings_inputs_01(self):
14061411
ai_ext._batch_embeddings_inputs(
14071412
CharacterTokenizer(),
14081413
['1', '22', '333', '4444', '55555', '121212121212'],
1409-
10
1414+
10,
1415+
None,
14101416
),
14111417
[
14121418
([4, 0, 1], 8),
14131419
([3, 2], 7),
14141420
],
14151421
)
1422+
self.assertEqual(
1423+
ai_ext._batch_embeddings_inputs(
1424+
CharacterTokenizer(),
1425+
[
1426+
'1',
1427+
'22',
1428+
'333',
1429+
'4444',
1430+
'55555',
1431+
'666666',
1432+
'7777777',
1433+
'88888888',
1434+
],
1435+
12,
1436+
3,
1437+
),
1438+
[
1439+
([7, 0, 1], 11),
1440+
([6, 2], 10),
1441+
([5, 3], 10),
1442+
([4], 5),
1443+
],
1444+
)
14161445
# Text is alphabetically ordered to ensure consistent batching
14171446
self.assertEqual(
14181447
ai_ext._batch_embeddings_inputs(
14191448
CharacterTokenizer(),
14201449
['AAA', 'CCC', 'EEE', 'BBB', 'DDD'],
1421-
10
1450+
10,
1451+
None,
14221452
),
14231453
[
14241454
([2, 0, 3], 9),

0 commit comments

Comments
 (0)