Skip to content

Commit a0d2551

Browse files
authored
fix: Use CLI tqdm if ipywidgets not installed (#167)
1 parent 89d77af commit a0d2551

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

google/cloud/dataproc_spark_connect/session.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import atexit
1616
import datetime
17+
import functools
1718
import json
1819
import logging
1920
import os
@@ -25,8 +26,6 @@
2526
import uuid
2627
import tqdm
2728
from packaging import version
28-
from tqdm import tqdm as cli_tqdm
29-
from tqdm.notebook import tqdm as notebook_tqdm
3029
from types import MethodType
3130
from typing import Any, cast, ClassVar, Dict, Iterable, Optional, Union
3231

@@ -991,6 +990,28 @@ def clearProgressHandlers_wrapper_method(_, *args, **kwargs):
991990
clearProgressHandlers_wrapper_method, self
992991
)
993992

993+
@staticmethod
994+
@functools.lru_cache(maxsize=1)
995+
def get_tqdm_bar():
996+
"""
997+
Return a tqdm implementation that works in the current environment.
998+
999+
- Uses CLI tqdm for interactive terminals.
1000+
- Uses the notebook tqdm if available, otherwise falls back to CLI tqdm.
1001+
"""
1002+
from tqdm import tqdm as cli_tqdm
1003+
1004+
if environment.is_interactive_terminal():
1005+
return cli_tqdm
1006+
1007+
try:
1008+
import ipywidgets
1009+
from tqdm.notebook import tqdm as notebook_tqdm
1010+
1011+
return notebook_tqdm
1012+
except ImportError:
1013+
return cli_tqdm
1014+
9941015
def _register_progress_execution_handler(self):
9951016
from pyspark.sql.connect.shell.progress import StageInfo
9961017

@@ -1019,9 +1040,8 @@ def handler(
10191040
if total_tasks == 0:
10201041
return
10211042

1022-
tqdm_pbar = notebook_tqdm
1023-
if environment.is_interactive_terminal():
1024-
tqdm_pbar = cli_tqdm
1043+
# Get correct tqdm (notebook or CLI)
1044+
tqdm_pbar = self.get_tqdm_bar()
10251045

10261046
# Use a lock to ensure only one thread can access and modify
10271047
# the shared dictionaries at a time.

0 commit comments

Comments
 (0)