import pandas as pd
import pytest

+import bigframes
from bigframes import exceptions
from bigframes.ml import core, llm
import bigframes.pandas as bpd
@@ -260,6 +261,44 @@ def test_text_embedding_generator_multi_cols_predict_success(
    assert len(pd_df["ml_generate_embedding_result"][0]) == 768


+def test_create_load_multimodal_embedding_generator_model(
+    dataset_id, session, bq_connection
+):
+    bigframes.options.experiments.blob = True
+
+    mm_embedding_model = llm.MultimodalEmbeddingGenerator(
+        connection_name=bq_connection, session=session
+    )
+    assert mm_embedding_model is not None
+    assert mm_embedding_model._bqml_model is not None
+
+    # save, load to ensure configuration was kept
+    reloaded_model = mm_embedding_model.to_gbq(
+        f"{dataset_id}.temp_mm_model", replace=True
+    )
+    assert f"{dataset_id}.temp_mm_model" == reloaded_model._bqml_model.model_name
+    assert reloaded_model.connection_name == bq_connection
+
+
+@pytest.mark.flaky(retries=2)
+def test_multimodal_embedding_generator_predict_default_params_success(
+    images_mm_df, session, bq_connection
+):
+    bigframes.options.experiments.blob = True
+
+    text_embedding_model = llm.MultimodalEmbeddingGenerator(
+        connection_name=bq_connection, session=session
+    )
+    df = text_embedding_model.predict(images_mm_df).to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        df,
+        columns=utils.ML_MULTIMODAL_GENERATE_EMBEDDING_OUTPUT,
+        index=2,
+        col_exact=False,
+    )
+    assert len(df["ml_generate_embedding_result"][0]) == 1408
+
+
@pytest.mark.parametrize(
    "model_name",
    (
@@ -273,6 +312,9 @@ def test_text_embedding_generator_multi_cols_predict_success(
        "gemini-2.0-flash-exp",
    ),
)
+@pytest.mark.flaky(
+    retries=2
+)  # Creating a model usually isn't flaky, but this one is due to the limited quota of gemini-2.0-flash-exp.
def test_create_load_gemini_text_generator_model(
    dataset_id, model_name, session, bq_connection
):
@@ -375,6 +417,36 @@ def test_gemini_text_generator_multi_cols_predict_success(
    )


+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-1.5-pro-001",
+        "gemini-1.5-pro-002",
+        "gemini-1.5-flash-001",
+        "gemini-1.5-flash-002",
+        "gemini-2.0-flash-exp",
+    ),
+)
+@pytest.mark.flaky(retries=2)
+def test_gemini_text_generator_multimodal_input(
+    images_mm_df: bpd.DataFrame, model_name, session, bq_connection
+):
+    bigframes.options.experiments.blob = True
+
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        model_name=model_name, connection_name=bq_connection, session=session
+    )
+    pd_df = gemini_text_generator_model.predict(
+        images_mm_df, prompt=["Describe", images_mm_df["blob_col"]]
+    ).to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        pd_df,
+        columns=utils.ML_GENERATE_TEXT_OUTPUT + ["blob_col"],
+        index=2,
+        col_exact=False,
+    )
+
+
# Overrides __eq__ function for comparing as mock.call parameter
class EqCmpAllDataFrame(bpd.DataFrame):
    def __eq__(self, other):