|
16 | 16 | from unittest import mock |
17 | 17 |
|
18 | 18 | import pandas as pd |
| 19 | +import pyarrow as pa |
19 | 20 | import pytest |
20 | 21 |
|
21 | 22 | import bigframes |
@@ -253,22 +254,27 @@ def test_gemini_text_generator_predict_output_schema_success( |
253 | 254 | "int_output": "int64", |
254 | 255 | "float_output": "float64", |
255 | 256 | "str_output": "string", |
| 257 | + "array_output": "array<int64>", |
| 258 | + "struct_output": "struct<number int64>", |
256 | 259 | } |
257 | | - df = gemini_text_generator_model.predict( |
258 | | - llm_text_df, output_schema=output_schema |
259 | | - ).to_pandas() |
| 260 | + df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema) |
| 261 | + assert df["bool_output"].dtype == pd.BooleanDtype() |
| 262 | + assert df["int_output"].dtype == pd.Int64Dtype() |
| 263 | + assert df["float_output"].dtype == pd.Float64Dtype() |
| 264 | + assert df["str_output"].dtype == pd.StringDtype(storage="pyarrow") |
| 265 | + assert df["array_output"].dtype == pd.ArrowDtype(pa.list_(pa.int64())) |
| 266 | + assert df["struct_output"].dtype == pd.ArrowDtype( |
| 267 | + pa.struct([("number", pa.int64())]) |
| 268 | + ) |
| 269 | + |
| 270 | + pd_df = df.to_pandas() |
260 | 271 | utils.check_pandas_df_schema_and_index( |
261 | | - df, |
| 272 | + pd_df, |
262 | 273 | columns=list(output_schema.keys()) + ["prompt", "full_response", "status"], |
263 | 274 | index=3, |
264 | 275 | col_exact=False, |
265 | 276 | ) |
266 | 277 |
|
267 | | - assert df["bool_output"].dtype == pd.BooleanDtype() |
268 | | - assert df["int_output"].dtype == pd.Int64Dtype() |
269 | | - assert df["float_output"].dtype == pd.Float64Dtype() |
270 | | - assert df["str_output"].dtype == pd.StringDtype(storage="pyarrow") |
271 | | - |
272 | 278 |
|
273 | 279 | # Overrides __eq__ so that instances can be compared when passed as mock.call parameters |
274 | 280 | class EqCmpAllDataFrame(bpd.DataFrame): |
@@ -305,8 +311,7 @@ def test_text_generator_retry_success( |
305 | 311 | session, |
306 | 312 | model_class, |
307 | 313 | options, |
308 | | - bqml_gemini_text_generator: llm.GeminiTextGenerator, |
309 | | - bqml_claude3_text_generator: llm.Claude3TextGenerator, |
| 314 | + bq_connection, |
310 | 315 | ): |
311 | 316 | # Requests. |
312 | 317 | df0 = EqCmpAllDataFrame( |
@@ -387,11 +392,7 @@ def test_text_generator_retry_success( |
387 | 392 | ), |
388 | 393 | ] |
389 | 394 |
|
390 | | - text_generator_model = ( |
391 | | - bqml_gemini_text_generator |
392 | | - if (model_class == llm.GeminiTextGenerator) |
393 | | - else bqml_claude3_text_generator |
394 | | - ) |
| 395 | + text_generator_model = model_class(connection_name=bq_connection, session=session) |
395 | 396 | text_generator_model._bqml_model = mock_bqml_model |
396 | 397 |
|
397 | 398 | with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf): |
@@ -448,13 +449,7 @@ def test_text_generator_retry_success( |
448 | 449 | ), |
449 | 450 | ], |
450 | 451 | ) |
451 | | -def test_text_generator_retry_no_progress( |
452 | | - session, |
453 | | - model_class, |
454 | | - options, |
455 | | - bqml_gemini_text_generator: llm.GeminiTextGenerator, |
456 | | - bqml_claude3_text_generator: llm.Claude3TextGenerator, |
457 | | -): |
| 452 | +def test_text_generator_retry_no_progress(session, model_class, options, bq_connection): |
458 | 453 | # Requests. |
459 | 454 | df0 = EqCmpAllDataFrame( |
460 | 455 | { |
@@ -514,11 +509,7 @@ def test_text_generator_retry_no_progress( |
514 | 509 | ), |
515 | 510 | ] |
516 | 511 |
|
517 | | - text_generator_model = ( |
518 | | - bqml_gemini_text_generator |
519 | | - if (model_class == llm.GeminiTextGenerator) |
520 | | - else bqml_claude3_text_generator |
521 | | - ) |
| 512 | + text_generator_model = model_class(connection_name=bq_connection, session=session) |
522 | 513 | text_generator_model._bqml_model = mock_bqml_model |
523 | 514 |
|
524 | 515 | with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf): |
|
0 commit comments