Skip to content

Commit b483932

Browse files
authored
chore: fix wordings of Gemini max_retries (#1244)
1 parent dd4fd2e commit b483932

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

bigframes/ml/llm.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -986,9 +986,8 @@ def predict(
986986
The default is `False`.
987987
988988
max_retries (int, default 0):
989-
Max number of retry rounds if any rows failed in the prediction. Each round need to make progress (has succeeded rows) to continue the next retry round.
990-
Each round will append newly succeeded rows. When the max retry rounds is reached, the remaining failed rows will be appended to the end of the result.
991-
989+
Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
990+
Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
992991
Returns:
993992
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
994993
"""
@@ -1034,11 +1033,15 @@ def predict(
10341033
for _ in range(max_retries + 1):
10351034
df = self._bqml_model.generate_text(df_fail, options)
10361035

1037-
df_succ = df[df[_ML_GENERATE_TEXT_STATUS].str.len() == 0]
1038-
df_fail = df[df[_ML_GENERATE_TEXT_STATUS].str.len() > 0]
1036+
success = df[_ML_GENERATE_TEXT_STATUS].str.len() == 0
1037+
df_succ = df[success]
1038+
df_fail = df[~success]
10391039

10401040
if df_succ.empty:
1041-
warnings.warn("Can't make any progress, stop retrying.", RuntimeWarning)
1041+
if max_retries > 0:
1042+
warnings.warn(
1043+
"Can't make any progress, stop retrying.", RuntimeWarning
1044+
)
10421045
break
10431046

10441047
df_result = (

0 commit comments

Comments
 (0)