Skip to content

Commit e51e62b

Browse files
authored
DB Connect Progress: Make sure we always end up at 100% (#1363)
## Changes DB Connect Progress: Make sure we always end up at 100% ## Tests <!-- How is this tested? -->
1 parent a9f79d2 commit e51e62b

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

packages/databricks-vscode/resources/python/00-databricks-init.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
import json
44
from typing import Any, Union, List
55
import os
6+
import sys
67
import time
78
import shlex
89
import warnings
910
import tempfile
1011

12+
# prevent sum from pyskaprk.sql.functions from shadowing the builtin sum
13+
builtinSum = sys.modules['builtins'].sum
1114

1215
def logError(function_name: str, e: Union[str, Exception]):
1316
import sys
@@ -403,14 +406,20 @@ def init_ui(self):
403406
def update_ticks(
404407
self,
405408
stages,
406-
inflight_tasks: int
409+
inflight_tasks: int,
410+
done: bool
407411
) -> None:
408-
total_tasks = sum(map(lambda x: x.num_tasks, stages))
409-
completed_tasks = sum(map(lambda x: x.num_completed_tasks, stages))
412+
total_tasks = builtinSum(map(lambda x: x.num_tasks, stages))
413+
completed_tasks = builtinSum(map(lambda x: x.num_completed_tasks, stages))
410414
if total_tasks > 0:
411415
self._ticks = total_tasks
412416
self._tick = completed_tasks
413-
self._bytes_read = sum(map(lambda x: x.num_bytes_read, stages))
417+
self._bytes_read = builtinSum(map(lambda x: x.num_bytes_read, stages))
418+
419+
if done:
420+
self._tick = self._ticks
421+
self._running = 0
422+
414423
if self._tick is not None and self._tick >= 0:
415424
self.output()
416425
self._running = inflight_tasks
@@ -432,7 +441,6 @@ def _bytes_to_string(size: int) -> str:
432441
i += 1
433442
result = float(size) / Progress.SI_BYTE_SIZES[i]
434443
return f"{result:.1f} {Progress.SI_BYTE_SUFFIXES[i]}"
435-
436444

437445
class ProgressHandler:
438446
def __init__(self):
@@ -454,7 +462,7 @@ def __call__(self,
454462
self.op_id = operation_id
455463
self.reset()
456464

457-
self.p.update_ticks(stages, inflight_tasks)
465+
self.p.update_ticks(stages, inflight_tasks, done)
458466

459467
spark.clearProgressHandlers()
460468
spark.registerProgressHandler(ProgressHandler())

0 commit comments

Comments
 (0)