Skip to content

Commit 340b93d

Browse files
authored
test: add tests for gemini structured output array and struct types (#1670)
* test: add tests for gemini structured output array and struct types * fix test in python 3.9
1 parent 4c5dee5 commit 340b93d

File tree

4 files changed

+74
-47
lines changed

4 files changed

+74
-47
lines changed

bigframes/ml/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def predict(
617617
It creates a struct column of the items of the iterable, and use the concatenated result as the input prompt. No-op if set to None.
618618
output_schema (Mapping[str, str] or None, default None):
619619
The schema used to generate structured output as a bigframes DataFrame. The schema is a string key-value pair of <column_name>:<type>.
620-
Supported types are int64, float64, bool and string. If None, output text result.
620+
Supported types are int64, float64, bool, string, array<type> and struct<column type>. If None, output text result.
621621
Returns:
622622
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
623623
"""

tests/system/small/ml/conftest.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
globals,
3030
imported,
3131
linear_model,
32-
llm,
3332
remote,
3433
)
3534

@@ -339,20 +338,3 @@ def imported_xgboost_model(
339338
output={"predicted_label": "float64"},
340339
model_path=imported_xgboost_array_model_path,
341340
)
342-
343-
344-
@pytest.fixture(scope="session")
345-
def bqml_gemini_text_generator(bq_connection, session) -> llm.GeminiTextGenerator:
346-
return llm.GeminiTextGenerator(
347-
model_name="gemini-1.5-flash-002",
348-
connection_name=bq_connection,
349-
session=session,
350-
)
351-
352-
353-
@pytest.fixture(scope="session")
354-
def bqml_claude3_text_generator(bq_connection, session) -> llm.Claude3TextGenerator:
355-
return llm.Claude3TextGenerator(
356-
connection_name=bq_connection,
357-
session=session,
358-
)

tests/system/small/ml/test_llm.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest import mock
1717

1818
import pandas as pd
19+
import pyarrow as pa
1920
import pytest
2021

2122
import bigframes
@@ -253,22 +254,27 @@ def test_gemini_text_generator_predict_output_schema_success(
253254
"int_output": "int64",
254255
"float_output": "float64",
255256
"str_output": "string",
257+
"array_output": "array<int64>",
258+
"struct_output": "struct<number int64>",
256259
}
257-
df = gemini_text_generator_model.predict(
258-
llm_text_df, output_schema=output_schema
259-
).to_pandas()
260+
df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
261+
assert df["bool_output"].dtype == pd.BooleanDtype()
262+
assert df["int_output"].dtype == pd.Int64Dtype()
263+
assert df["float_output"].dtype == pd.Float64Dtype()
264+
assert df["str_output"].dtype == pd.StringDtype(storage="pyarrow")
265+
assert df["array_output"].dtype == pd.ArrowDtype(pa.list_(pa.int64()))
266+
assert df["struct_output"].dtype == pd.ArrowDtype(
267+
pa.struct([("number", pa.int64())])
268+
)
269+
270+
pd_df = df.to_pandas()
260271
utils.check_pandas_df_schema_and_index(
261-
df,
272+
pd_df,
262273
columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
263274
index=3,
264275
col_exact=False,
265276
)
266277

267-
assert df["bool_output"].dtype == pd.BooleanDtype()
268-
assert df["int_output"].dtype == pd.Int64Dtype()
269-
assert df["float_output"].dtype == pd.Float64Dtype()
270-
assert df["str_output"].dtype == pd.StringDtype(storage="pyarrow")
271-
272278

273279
# Overrides __eq__ function for comparing as mock.call parameter
274280
class EqCmpAllDataFrame(bpd.DataFrame):
@@ -305,8 +311,7 @@ def test_text_generator_retry_success(
305311
session,
306312
model_class,
307313
options,
308-
bqml_gemini_text_generator: llm.GeminiTextGenerator,
309-
bqml_claude3_text_generator: llm.Claude3TextGenerator,
314+
bq_connection,
310315
):
311316
# Requests.
312317
df0 = EqCmpAllDataFrame(
@@ -387,11 +392,7 @@ def test_text_generator_retry_success(
387392
),
388393
]
389394

390-
text_generator_model = (
391-
bqml_gemini_text_generator
392-
if (model_class == llm.GeminiTextGenerator)
393-
else bqml_claude3_text_generator
394-
)
395+
text_generator_model = model_class(connection_name=bq_connection, session=session)
395396
text_generator_model._bqml_model = mock_bqml_model
396397

397398
with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf):
@@ -448,13 +449,7 @@ def test_text_generator_retry_success(
448449
),
449450
],
450451
)
451-
def test_text_generator_retry_no_progress(
452-
session,
453-
model_class,
454-
options,
455-
bqml_gemini_text_generator: llm.GeminiTextGenerator,
456-
bqml_claude3_text_generator: llm.Claude3TextGenerator,
457-
):
452+
def test_text_generator_retry_no_progress(session, model_class, options, bq_connection):
458453
# Requests.
459454
df0 = EqCmpAllDataFrame(
460455
{
@@ -514,11 +509,7 @@ def test_text_generator_retry_no_progress(
514509
),
515510
]
516511

517-
text_generator_model = (
518-
bqml_gemini_text_generator
519-
if (model_class == llm.GeminiTextGenerator)
520-
else bqml_claude3_text_generator
521-
)
512+
text_generator_model = model_class(connection_name=bq_connection, session=session)
522513
text_generator_model._bqml_model = mock_bqml_model
523514

524515
with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf):

tests/system/small/ml/test_multimodal_llm.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pandas as pd
16+
import pyarrow as pa
1517
import pytest
1618

1719
import bigframes
@@ -68,3 +70,55 @@ def test_gemini_text_generator_multimodal_input(
6870
index=2,
6971
col_exact=False,
7072
)
73+
74+
75+
@pytest.mark.parametrize(
76+
"model_name",
77+
(
78+
"gemini-1.5-pro-001",
79+
# "gemini-1.5-pro-002",
80+
"gemini-1.5-flash-001",
81+
"gemini-1.5-flash-002",
82+
"gemini-2.0-flash-exp",
83+
"gemini-2.0-flash-001",
84+
),
85+
)
86+
@pytest.mark.flaky(retries=2)
87+
def test_gemini_text_generator_multimodal_structured_output(
88+
images_mm_df: bpd.DataFrame, model_name, test_session, bq_connection
89+
):
90+
bigframes.options.experiments.blob = True
91+
92+
gemini_text_generator_model = llm.GeminiTextGenerator(
93+
model_name=model_name, connection_name=bq_connection, session=test_session
94+
)
95+
output_schema = {
96+
"bool_output": "bool",
97+
"int_output": "int64",
98+
"float_output": "float64",
99+
"str_output": "string",
100+
"array_output": "array<int64>",
101+
"struct_output": "struct<number int64>",
102+
}
103+
df = gemini_text_generator_model.predict(
104+
images_mm_df,
105+
prompt=["Describe", images_mm_df["blob_col"]],
106+
output_schema=output_schema,
107+
)
108+
assert df["bool_output"].dtype == pd.BooleanDtype()
109+
assert df["int_output"].dtype == pd.Int64Dtype()
110+
assert df["float_output"].dtype == pd.Float64Dtype()
111+
assert df["str_output"].dtype == pd.StringDtype(storage="pyarrow")
112+
assert df["array_output"].dtype == pd.ArrowDtype(pa.list_(pa.int64()))
113+
assert df["struct_output"].dtype == pd.ArrowDtype(
114+
pa.struct([("number", pa.int64())])
115+
)
116+
117+
pd_df = df.to_pandas()
118+
utils.check_pandas_df_schema_and_index(
119+
pd_df,
120+
columns=list(output_schema.keys())
121+
+ ["blob_col", "prompt", "full_response", "status"],
122+
index=2,
123+
col_exact=False,
124+
)

0 commit comments

Comments
 (0)