|
22 | 22 | """ |
23 | 23 |
|
24 | 24 | import abc |
25 | | -from typing import Callable, cast, Mapping, Optional, TypeVar |
| 25 | +from typing import Callable, cast, Mapping, Optional, TypeVar, Union |
26 | 26 | import warnings |
27 | 27 |
|
28 | 28 | import bigframes_vendored.sklearn.base |
@@ -259,38 +259,29 @@ def _predict_and_retry( |
259 | 259 | ) -> bpd.DataFrame: |
260 | 260 | assert self._bqml_model is not None |
261 | 261 |
|
262 | | - df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder |
263 | | - df_fail = X |
264 | | - for _ in range(max_retries + 1): |
| 262 | + df_result: Union[bpd.DataFrame, None] = None # placeholder |
| 263 | + df_succ = df_fail = X |
| 264 | + for i in range(max_retries + 1): |
| 265 | + if i > 0 and df_fail.empty: |
| 266 | + break |
| 267 | + if i > 0 and df_succ.empty: |
| 268 | + msg = bfe.format_message("Can't make any progress, stop retrying.") |
| 269 | + warnings.warn(msg, category=RuntimeWarning) |
| 270 | + break |
| 271 | + |
265 | 272 | df = self._predict_func(df_fail, options) |
266 | 273 |
|
267 | 274 | success = df[self._status_col].str.len() == 0 |
268 | 275 | df_succ = df[success] |
269 | 276 | df_fail = df[~success] |
270 | 277 |
|
271 | | - if df_succ.empty: |
272 | | - if max_retries > 0: |
273 | | - msg = bfe.format_message("Can't make any progress, stop retrying.") |
274 | | - warnings.warn(msg, category=RuntimeWarning) |
275 | | - break |
276 | | - |
277 | 278 | df_result = ( |
278 | | - bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ |
279 | | - ) |
280 | | - |
281 | | - if df_fail.empty: |
282 | | - break |
283 | | - |
284 | | - if not df_fail.empty: |
285 | | - msg = bfe.format_message( |
286 | | - f"Some predictions failed. Check column {self._status_col} for detailed " |
287 | | - "status. You may want to filter the failed rows and retry." |
| 279 | + bpd.concat([df_result, df_succ]) if df_result is not None else df_succ |
288 | 280 | ) |
289 | | - warnings.warn(msg, category=RuntimeWarning) |
290 | 281 |
|
291 | 282 | df_result = cast( |
292 | 283 | bpd.DataFrame, |
293 | | - bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail, |
| 284 | + bpd.concat([df_result, df_fail]) if df_result is not None else df_fail, |
294 | 285 | ) |
295 | 286 | return df_result |
296 | 287 |
|
|
0 commit comments