Skip to content

Commit c1cde19

Browse files
authored
feat: add Gemini 1.5 stable models support (#945)
* feat: add Gemini 1.5 stable models support * add to loader
1 parent 8e8279d commit c1cde19

File tree

3 files changed

+34
-6
lines changed

3 files changed

+34
-6
lines changed

bigframes/ml/llm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,14 @@
5555
_GEMINI_PRO_ENDPOINT = "gemini-pro"
5656
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
5757
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
58+
_GEMINI_1P5_PRO_001_ENDPOINT = "gemini-1.5-pro-001"
59+
_GEMINI_1P5_FLASH_001_ENDPOINT = "gemini-1.5-flash-001"
5860
_GEMINI_ENDPOINTS = (
5961
_GEMINI_PRO_ENDPOINT,
6062
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT,
6163
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
64+
_GEMINI_1P5_PRO_001_ENDPOINT,
65+
_GEMINI_1P5_FLASH_001_ENDPOINT,
6266
)
6367

6468
_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
@@ -728,7 +732,7 @@ class GeminiTextGenerator(base.BaseEstimator):
728732
729733
Args:
730734
model_name (str, Default to "gemini-pro"):
731-
The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514". Default to "gemini-pro".
735+
The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514", "gemini-1.5-pro-001" and "gemini-1.5-flash-001". Default to "gemini-pro".
732736
733737
.. note::
734738
"gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514" is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
@@ -750,7 +754,11 @@ def __init__(
750754
self,
751755
*,
752756
model_name: Literal[
753-
"gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"
757+
"gemini-pro",
758+
"gemini-1.5-pro-preview-0514",
759+
"gemini-1.5-flash-preview-0514",
760+
"gemini-1.5-pro-001",
761+
"gemini-1.5-flash-001",
754762
] = "gemini-pro",
755763
session: Optional[bigframes.Session] = None,
756764
connection_name: Optional[str] = None,

bigframes/ml/loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
6464
llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
6565
llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
66+
llm._GEMINI_1P5_PRO_001_ENDPOINT: llm.GeminiTextGenerator,
67+
llm._GEMINI_1P5_FLASH_001_ENDPOINT: llm.GeminiTextGenerator,
6668
llm._CLAUDE_3_HAIKU_ENDPOINT: llm.Claude3TextGenerator,
6769
llm._CLAUDE_3_SONNET_ENDPOINT: llm.Claude3TextGenerator,
6870
llm._CLAUDE_3_5_SONNET_ENDPOINT: llm.Claude3TextGenerator,

tests/system/small/ml/test_llm.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def test_create_load_text_embedding_generator_model(
324324
("text-embedding-004", "text-multilingual-embedding-002"),
325325
)
326326
@pytest.mark.flaky(retries=2)
327-
def test_gemini_text_embedding_generator_predict_default_params_success(
327+
def test_text_embedding_generator_predict_default_params_success(
328328
llm_text_df, model_name, session, bq_connection
329329
):
330330
text_embedding_model = llm.TextEmbeddingGenerator(
@@ -340,7 +340,13 @@ def test_gemini_text_embedding_generator_predict_default_params_success(
340340

341341
@pytest.mark.parametrize(
342342
"model_name",
343-
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
343+
(
344+
"gemini-pro",
345+
"gemini-1.5-pro-preview-0514",
346+
"gemini-1.5-flash-preview-0514",
347+
"gemini-1.5-pro-001",
348+
"gemini-1.5-flash-001",
349+
),
344350
)
345351
def test_create_load_gemini_text_generator_model(
346352
dataset_id, model_name, session, bq_connection
@@ -362,7 +368,13 @@ def test_create_load_gemini_text_generator_model(
362368

363369
@pytest.mark.parametrize(
364370
"model_name",
365-
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
371+
(
372+
"gemini-pro",
373+
"gemini-1.5-pro-preview-0514",
374+
"gemini-1.5-flash-preview-0514",
375+
"gemini-1.5-pro-001",
376+
"gemini-1.5-flash-001",
377+
),
366378
)
367379
@pytest.mark.flaky(retries=2)
368380
def test_gemini_text_generator_predict_default_params_success(
@@ -379,7 +391,13 @@ def test_gemini_text_generator_predict_default_params_success(
379391

380392
@pytest.mark.parametrize(
381393
"model_name",
382-
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
394+
(
395+
"gemini-pro",
396+
"gemini-1.5-pro-preview-0514",
397+
"gemini-1.5-flash-preview-0514",
398+
"gemini-1.5-pro-001",
399+
"gemini-1.5-flash-001",
400+
),
383401
)
384402
@pytest.mark.flaky(retries=2)
385403
def test_gemini_text_generator_predict_with_params_success(

0 commit comments

Comments
 (0)