Skip to content

Commit 298fc73

Browse files
Shuowei Li (shuoweil)
and co-authors authored
feat: add Gemini-pro-1.5 to GeminiTextGenerator Tuning and Support score() method in Gemini-pro-1.5 (#1208)
* docs(bigquery): update minor parts in base.py * docs(bigquery): update minor changes for bigframes/ml/base.py * feat: Update GeminiTextGenerator Tuning and Support score() method in Gemini-pro-1.5 \n Bug: b/381936588 and b/344891364 * feat: Update GeminiTextGenerator Tuning and Support score() method in Gemini-pro-1.5 \n Bug: b/381936588 and b/344891364 * update testcase and docs for better clarification * update endpoint to corresponding endpoint for fine tuning. * docs(bigquery): update minor parts in base.py * fix syntax issue * Revert "docs(bigquery): update minor parts in base.py" This reverts commit 9de2c0e. * merge gemini_fine_tune_endpoints and gemini_score_endpoints together, since they are identical * merge genimi_fine_tune_endpoints and genimi_score_endpoints, since they are identical * Revert "merge genimi_fine_tune_endpoints and genimi_score_endpoints, since they are identical" This reverts commit 205e173. --------- Co-authored-by: Shuowei Li <[email protected]>
1 parent 200c9bb commit 298fc73

File tree

3 files changed

+51
-18
lines changed

3 files changed

+51
-18
lines changed

bigframes/ml/llm.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@
7979
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
8080
_GEMINI_2_FLASH_EXP_ENDPOINT,
8181
)
82+
_GEMINI_FINE_TUNE_SCORE_ENDPOINTS = (
83+
_GEMINI_PRO_ENDPOINT,
84+
_GEMINI_1P5_PRO_002_ENDPOINT,
85+
_GEMINI_1P5_FLASH_002_ENDPOINT,
86+
)
8287

8388
_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
8489
_CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku"
@@ -890,7 +895,8 @@ def fit(
890895
X: utils.ArrayType,
891896
y: utils.ArrayType,
892897
) -> GeminiTextGenerator:
893-
"""Fine tune GeminiTextGenerator model. Only support "gemini-pro" model for now.
898+
"""Fine tune GeminiTextGenerator model. Only support "gemini-pro", "gemini-1.5-pro-002",
899+
"gemini-1.5-flash-002" models for now.
894900
895901
.. note::
896902
@@ -908,13 +914,18 @@ def fit(
908914
Returns:
909915
GeminiTextGenerator: Fitted estimator.
910916
"""
911-
if self._bqml_model.model_name.startswith("gemini-1.5"):
912-
raise NotImplementedError("Fit is not supported for gemini-1.5 model.")
917+
if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS:
918+
raise NotImplementedError(
919+
"fit() only supports gemini-pro, \
920+
gemini-1.5-pro-002, or gemini-1.5-flash-002 model."
921+
)
913922

914923
X, y = utils.batch_convert_to_dataframe(X, y)
915924

916925
options = self._bqml_options
917-
options["endpoint"] = "gemini-1.0-pro-002"
926+
options["endpoint"] = (
927+
"gemini-1.0-pro-002" if self.model_name == "gemini-pro" else self.model_name
928+
)
918929
options["prompt_col"] = X.columns.tolist()[0]
919930

920931
self._bqml_model = self._bqml_model_factory.create_llm_remote_model(
@@ -1025,7 +1036,7 @@ def score(
10251036
"text_generation", "classification", "summarization", "question_answering"
10261037
] = "text_generation",
10271038
) -> bpd.DataFrame:
1028-
"""Calculate evaluation metrics of the model. Only "gemini-pro" model is supported for now.
1039+
"""Calculate evaluation metrics of the model. Only support "gemini-pro" and "gemini-1.5-pro-002", and "gemini-1.5-flash-002".
10291040
10301041
.. note::
10311042
@@ -1057,9 +1068,11 @@ def score(
10571068
if not self._bqml_model:
10581069
raise RuntimeError("A model must be fitted before score")
10591070

1060-
# TODO(ashleyxu): Support gemini-1.5 when the rollout is ready. b/344891364.
1061-
if self._bqml_model.model_name.startswith("gemini-1.5"):
1062-
raise NotImplementedError("Score is not supported for gemini-1.5 model.")
1071+
if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS:
1072+
raise NotImplementedError(
1073+
"score() only supports gemini-pro \
1074+
, gemini-1.5-pro-002, and gemini-1.5-flash-002 model."
1075+
)
10631076

10641077
X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session)
10651078

tests/system/load/test_llm.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,19 @@ def llm_remote_text_df(session, llm_remote_text_pandas_df):
3838
return session.read_pandas(llm_remote_text_pandas_df)
3939

4040

41-
@pytest.mark.flaky(retries=2)
41+
@pytest.mark.parametrize(
42+
"model_name",
43+
(
44+
"gemini-pro",
45+
"gemini-1.5-pro-002",
46+
"gemini-1.5-flash-002",
47+
),
48+
)
4249
def test_llm_gemini_configure_fit(
43-
session, llm_fine_tune_df_default_index, llm_remote_text_df
50+
session, model_name, llm_fine_tune_df_default_index, llm_remote_text_df
4451
):
4552
model = llm.GeminiTextGenerator(
46-
session=session, model_name="gemini-pro", max_iterations=1
53+
session=session, model_name=model_name, max_iterations=1
4754
)
4855

4956
X_train = llm_fine_tune_df_default_index[["prompt"]]
@@ -69,7 +76,6 @@ def test_llm_gemini_configure_fit(
6976
],
7077
index=3,
7178
)
72-
# TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept
7379

7480

7581
@pytest.mark.flaky(retries=2)

tests/system/small/ml/test_llm.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,16 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index):
417417
)
418418

419419

420-
@pytest.mark.flaky(retries=2)
421-
def test_llm_gemini_pro_score(llm_fine_tune_df_default_index):
422-
model = llm.GeminiTextGenerator(model_name="gemini-pro")
420+
@pytest.mark.parametrize(
421+
"model_name",
422+
(
423+
"gemini-pro",
424+
"gemini-1.5-pro-002",
425+
"gemini-1.5-flash-002",
426+
),
427+
)
428+
def test_llm_gemini_score(llm_fine_tune_df_default_index, model_name):
429+
model = llm.GeminiTextGenerator(model_name=model_name)
423430

424431
# Check score to ensure the model was fitted
425432
score_result = model.score(
@@ -439,9 +446,16 @@ def test_llm_gemini_pro_score(llm_fine_tune_df_default_index):
439446
)
440447

441448

442-
@pytest.mark.flaky(retries=2)
443-
def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index):
444-
model = llm.GeminiTextGenerator(model_name="gemini-pro")
449+
@pytest.mark.parametrize(
450+
"model_name",
451+
(
452+
"gemini-pro",
453+
"gemini-1.5-pro-002",
454+
"gemini-1.5-flash-002",
455+
),
456+
)
457+
def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index, model_name):
458+
model = llm.GeminiTextGenerator(model_name=model_name)
445459

446460
# Check score to ensure the model was fitted
447461
score_result = model.score(

0 commit comments

Comments
 (0)