
Commit b1b612e

chore: add experimental multimodal model tests (#1486)
1 parent 958b537 commit b1b612e

File tree: 2 files changed, +80 -0 lines


tests/system/small/ml/test_llm.py

Lines changed: 72 additions & 0 deletions
@@ -17,6 +17,7 @@
 import pandas as pd
 import pytest
 
+import bigframes
 from bigframes import exceptions
 from bigframes.ml import core, llm
 import bigframes.pandas as bpd
@@ -260,6 +261,44 @@ def test_text_embedding_generator_multi_cols_predict_success(
     assert len(pd_df["ml_generate_embedding_result"][0]) == 768
 
 
+def test_create_load_multimodal_embedding_generator_model(
+    dataset_id, session, bq_connection
+):
+    bigframes.options.experiments.blob = True
+
+    mm_embedding_model = llm.MultimodalEmbeddingGenerator(
+        connection_name=bq_connection, session=session
+    )
+    assert mm_embedding_model is not None
+    assert mm_embedding_model._bqml_model is not None
+
+    # Save and load to ensure the configuration was kept.
+    reloaded_model = mm_embedding_model.to_gbq(
+        f"{dataset_id}.temp_mm_model", replace=True
+    )
+    assert f"{dataset_id}.temp_mm_model" == reloaded_model._bqml_model.model_name
+    assert reloaded_model.connection_name == bq_connection
+
+
+@pytest.mark.flaky(retries=2)
+def test_multimodal_embedding_generator_predict_default_params_success(
+    images_mm_df, session, bq_connection
+):
+    bigframes.options.experiments.blob = True
+
+    text_embedding_model = llm.MultimodalEmbeddingGenerator(
+        connection_name=bq_connection, session=session
+    )
+    df = text_embedding_model.predict(images_mm_df).to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        df,
+        columns=utils.ML_MULTIMODAL_GENERATE_EMBEDDING_OUTPUT,
+        index=2,
+        col_exact=False,
+    )
+    assert len(df["ml_generate_embedding_result"][0]) == 1408
+
+
 @pytest.mark.parametrize(
     "model_name",
     (
@@ -273,6 +312,9 @@ def test_text_embedding_generator_multi_cols_predict_success(
         "gemini-2.0-flash-exp",
     ),
 )
+@pytest.mark.flaky(
+    retries=2
+)  # Creating a model usually shouldn't be flaky, but this one is due to the limited quota of gemini-2.0-flash-exp.
 def test_create_load_gemini_text_generator_model(
     dataset_id, model_name, session, bq_connection
 ):
@@ -375,6 +417,36 @@ def test_gemini_text_generator_multi_cols_predict_success(
     )
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-1.5-pro-001",
+        "gemini-1.5-pro-002",
+        "gemini-1.5-flash-001",
+        "gemini-1.5-flash-002",
+        "gemini-2.0-flash-exp",
+    ),
+)
+@pytest.mark.flaky(retries=2)
+def test_gemini_text_generator_multimodal_input(
+    images_mm_df: bpd.DataFrame, model_name, session, bq_connection
+):
+    bigframes.options.experiments.blob = True
+
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        model_name=model_name, connection_name=bq_connection, session=session
+    )
+    pd_df = gemini_text_generator_model.predict(
+        images_mm_df, prompt=["Describe", images_mm_df["blob_col"]]
+    ).to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        pd_df,
+        columns=utils.ML_GENERATE_TEXT_OUTPUT + ["blob_col"],
+        index=2,
+        col_exact=False,
+    )
+
+
 # Overrides __eq__ function for comparing as mock.call parameter
 class EqCmpAllDataFrame(bpd.DataFrame):
     def __eq__(self, other):
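
The sketch below shows, outside of pytest, roughly how the API exercised by these new tests could be used. It is a minimal, hedged example and not part of this commit: the connection string is a placeholder, the DataFrame with a "blob_col" blob column is assumed to already exist (building one depends on the experimental blob support), and a default session is assumed where the tests pass session explicitly.

# Minimal usage sketch (illustration only, not part of this commit).
import bigframes
import bigframes.pandas as bpd
from bigframes.ml import llm

# The multimodal paths are gated behind this experimental flag, as in the tests.
bigframes.options.experiments.blob = True


def embed_images(images_df: bpd.DataFrame, connection: str) -> bpd.DataFrame:
    # Multimodal embeddings; per the test assertion, each result vector has 1408 dimensions.
    model = llm.MultimodalEmbeddingGenerator(connection_name=connection)
    return model.predict(images_df)


def describe_images(images_df: bpd.DataFrame, connection: str) -> bpd.DataFrame:
    # Text generation over image blobs, mirroring test_gemini_text_generator_multimodal_input.
    # Assumes images_df has a blob column named "blob_col", as in the test fixture.
    model = llm.GeminiTextGenerator(
        model_name="gemini-1.5-flash-002", connection_name=connection
    )
    return model.predict(images_df, prompt=["Describe", images_df["blob_col"]])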

tests/system/utils.py

Lines changed: 8 additions & 0 deletions
@@ -56,6 +56,14 @@
     "ml_generate_embedding_status",
     "content",
 ]
+ML_MULTIMODAL_GENERATE_EMBEDDING_OUTPUT = [
+    "ml_generate_embedding_result",
+    "ml_generate_embedding_status",
+    # Start and end sec depend on the input format; image and video inputs contain these two.
+    "ml_generate_embedding_start_sec",
+    "ml_generate_embedding_end_sec",
+    "content",
+]
 
 
 def skip_legacy_pandas(test):
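
For context, check_pandas_df_schema_and_index (defined elsewhere in utils.py and not shown in this diff) is what consumes the new column list. A simplified, hypothetical stand-in for the kind of check it performs might look like this; the helper name and exact semantics below are assumptions for illustration, not the project's implementation:

import pandas as pd

# Hypothetical schema assertion (illustration only): verifies that a prediction
# result contains the expected multimodal-embedding columns and row count.
def assert_mm_embedding_schema(df: pd.DataFrame, expected_rows: int) -> None:
    expected_cols = [
        "ml_generate_embedding_result",
        "ml_generate_embedding_status",
        "ml_generate_embedding_start_sec",
        "ml_generate_embedding_end_sec",
        "content",
    ]
    missing = set(expected_cols) - set(df.columns)
    assert not missing, f"missing columns: {missing}"
    assert len(df) == expected_rows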
