diff --git a/google/cloud/dataproc_spark_connect/session.py b/google/cloud/dataproc_spark_connect/session.py index 8d8d9f2..7e2701c 100644 --- a/google/cloud/dataproc_spark_connect/session.py +++ b/google/cloud/dataproc_spark_connect/session.py @@ -14,6 +14,7 @@ import atexit import datetime +import functools import json import logging import os @@ -25,8 +26,6 @@ import uuid import tqdm from packaging import version -from tqdm import tqdm as cli_tqdm -from tqdm.notebook import tqdm as notebook_tqdm from types import MethodType from typing import Any, cast, ClassVar, Dict, Iterable, Optional, Union @@ -991,6 +990,28 @@ def clearProgressHandlers_wrapper_method(_, *args, **kwargs): clearProgressHandlers_wrapper_method, self ) + @staticmethod + @functools.lru_cache(maxsize=1) + def get_tqdm_bar(): + """ + Return a tqdm implementation that works in the current environment. + + - Uses CLI tqdm for interactive terminals. + - Uses the notebook tqdm if available, otherwise falls back to CLI tqdm. + """ + from tqdm import tqdm as cli_tqdm + + if environment.is_interactive_terminal(): + return cli_tqdm + + try: + import ipywidgets + from tqdm.notebook import tqdm as notebook_tqdm + + return notebook_tqdm + except ImportError: + return cli_tqdm + def _register_progress_execution_handler(self): from pyspark.sql.connect.shell.progress import StageInfo @@ -1019,9 +1040,8 @@ def handler( if total_tasks == 0: return - tqdm_pbar = notebook_tqdm - if environment.is_interactive_terminal(): - tqdm_pbar = cli_tqdm + # Get correct tqdm (notebook or CLI) + tqdm_pbar = self.get_tqdm_bar() # Use a lock to ensure only one thread can access and modify # the shared dictionaries at a time.