Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions google/cloud/dataproc_spark_connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import atexit
import datetime
import functools
import json
import logging
import os
Expand All @@ -25,8 +26,6 @@
import uuid
import tqdm
from packaging import version
from tqdm import tqdm as cli_tqdm
from tqdm.notebook import tqdm as notebook_tqdm
from types import MethodType
from typing import Any, cast, ClassVar, Dict, Iterable, Optional, Union

Expand Down Expand Up @@ -991,6 +990,28 @@ def clearProgressHandlers_wrapper_method(_, *args, **kwargs):
clearProgressHandlers_wrapper_method, self
)

@staticmethod
@functools.lru_cache(maxsize=1)
def get_tqdm_bar():
"""
Return a tqdm implementation that works in the current environment.

- Uses CLI tqdm for interactive terminals.
- Uses the notebook tqdm if available, otherwise falls back to CLI tqdm.
"""
from tqdm import tqdm as cli_tqdm

if environment.is_interactive_terminal():
return cli_tqdm

try:
import ipywidgets
from tqdm.notebook import tqdm as notebook_tqdm

return notebook_tqdm
except ImportError:
return cli_tqdm

def _register_progress_execution_handler(self):
from pyspark.sql.connect.shell.progress import StageInfo

Expand Down Expand Up @@ -1019,9 +1040,8 @@ def handler(
if total_tasks == 0:
return

tqdm_pbar = notebook_tqdm
if environment.is_interactive_terminal():
tqdm_pbar = cli_tqdm
# Get correct tqdm (notebook or CLI)
tqdm_pbar = self.get_tqdm_bar()

# Use a lock to ensure only one thread can access and modify
# the shared dictionaries at a time.
Expand Down