33import json
44from typing import Any , Union , List
55import os
6+ import sys
67import time
78import shlex
89import warnings
910import tempfile
1011
12+ # prevent sum from pyskaprk.sql.functions from shadowing the builtin sum
13+ builtinSum = sys .modules ['builtins' ].sum
1114
1215def logError (function_name : str , e : Union [str , Exception ]):
1316 import sys
@@ -403,14 +406,20 @@ def init_ui(self):
403406 def update_ticks (
404407 self ,
405408 stages ,
406- inflight_tasks : int
409+ inflight_tasks : int ,
410+ done : bool
407411 ) -> None :
408- total_tasks = sum (map (lambda x : x .num_tasks , stages ))
409- completed_tasks = sum (map (lambda x : x .num_completed_tasks , stages ))
412+ total_tasks = builtinSum (map (lambda x : x .num_tasks , stages ))
413+ completed_tasks = builtinSum (map (lambda x : x .num_completed_tasks , stages ))
410414 if total_tasks > 0 :
411415 self ._ticks = total_tasks
412416 self ._tick = completed_tasks
413- self ._bytes_read = sum (map (lambda x : x .num_bytes_read , stages ))
417+ self ._bytes_read = builtinSum (map (lambda x : x .num_bytes_read , stages ))
418+
419+ if done :
420+ self ._tick = self ._ticks
421+ self ._running = 0
422+
414423 if self ._tick is not None and self ._tick >= 0 :
415424 self .output ()
416425 self ._running = inflight_tasks
@@ -432,7 +441,6 @@ def _bytes_to_string(size: int) -> str:
432441 i += 1
433442 result = float (size ) / Progress .SI_BYTE_SIZES [i ]
434443 return f"{ result :.1f} { Progress .SI_BYTE_SUFFIXES [i ]} "
435-
436444
437445 class ProgressHandler :
438446 def __init__ (self ):
@@ -454,7 +462,7 @@ def __call__(self,
454462 self .op_id = operation_id
455463 self .reset ()
456464
457- self .p .update_ticks (stages , inflight_tasks )
465+ self .p .update_ticks (stages , inflight_tasks , done )
458466
459467 spark .clearProgressHandlers ()
460468 spark .registerProgressHandler (ProgressHandler ())
0 commit comments