1414import tempfile
1515import textwrap
1616import threading
17+ import time
18+
19+ from rich .console import Console
20+ from rich .progress import BarColumn
21+ from rich .progress import Progress
22+ from rich .progress import SpinnerColumn
23+ from rich .progress import TextColumn
24+ from rich .progress import TimeElapsedColumn
25+ from rich .progress import TimeRemainingColumn
1726
1827import memray
1928from memray ._errors import MemrayCommandError
@@ -331,6 +340,50 @@ def recvall(sock: socket.socket) -> str:
331340 return b"" .join (iter (lambda : sock .recv (4096 ), b"" )).decode ("utf-8" )
332341
333342
343+ def show_progress_with_duration (duration : int , pid : int ) -> None :
344+ """Show a progress indicator while waiting for the specified duration.
345+
346+ Args:
347+ duration: Duration in seconds to wait
348+ pid: Process ID being tracked
349+ """
350+ console = Console ()
351+
352+ with Progress (
353+ SpinnerColumn (),
354+ TextColumn ("[bold blue]{task.description}" ),
355+ BarColumn (),
356+ TextColumn ("[progress.percentage]{task.percentage:>3.0f}%" ),
357+ TimeElapsedColumn (),
358+ TimeRemainingColumn (),
359+ console = console ,
360+ ) as progress :
361+ task = progress .add_task (
362+ f"Tracking process { pid } " , total = duration * 10 # 10 updates per second
363+ )
364+
365+ try :
366+ start_time = time .time ()
367+ while not progress .finished :
368+ elapsed = time .time () - start_time
369+ if elapsed >= duration :
370+ progress .update (task , completed = duration * 10 )
371+ break
372+
373+ progress .update (task , completed = int (elapsed * 10 ))
374+ time .sleep (0.1 )
375+
376+ except KeyboardInterrupt :
377+ console .print ()
378+ console .print (
379+ "[yellow]⚠ Interrupted! Tracking is still running in the background.[/yellow]"
380+ )
381+ console .print (
382+ f"[yellow] Use 'memray detach { pid } ' to stop tracking immediately.[/yellow]"
383+ )
384+ raise
385+
386+
334387class ErrorReaderThread (threading .Thread ):
335388 def __init__ (self , sock : socket .socket ) -> None :
336389 self ._sock = sock
@@ -488,6 +541,13 @@ def prepare_parser(self, parser: argparse.ArgumentParser) -> None:
488541 "--duration" , type = int , help = "Duration to track for (in seconds)"
489542 )
490543
544+ parser .add_argument (
545+ "--wait" ,
546+ help = "Wait for tracking to complete before exiting (use with --duration)" ,
547+ action = "store_true" ,
548+ default = False ,
549+ )
550+
491551 super ().prepare_parser (parser )
492552
493553 def run (self , args : argparse .Namespace , parser : argparse .ArgumentParser ) -> None :
@@ -503,9 +563,18 @@ def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None
503563
504564 destination : memray .Destination
505565 if args .output :
566+ # Check if output file exists before doing any work
567+ output_path = pathlib .Path (args .output ).resolve ()
568+ if output_path .exists () and not args .force :
569+ raise MemrayCommandError (
570+ f"Output file already exists: { output_path } \n "
571+ f"Use --force to overwrite it." ,
572+ exit_code = 1 ,
573+ )
574+
506575 live_port = None
507576 destination = memray .FileDestination (
508- path = os . path . abspath ( args . output ),
577+ path = str ( output_path ),
509578 overwrite = args .force ,
510579 compress_on_exit = not args .no_compress ,
511580 )
@@ -530,23 +599,131 @@ def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None
530599 f"{ file_format } )"
531600 )
532601
533- client = self .inject_control_channel (args .method , args .pid , verbose = verbose )
534- client .sendall (
535- PAYLOAD .format (
536- tracker_call = tracker_call ,
537- mode = mode ,
538- duration = duration ,
539- ).encode ("utf-8" )
540- )
541- client .shutdown (socket .SHUT_WR )
602+ console = Console ()
603+
604+ # Show detailed attaching progress with steps
605+ with Progress (
606+ SpinnerColumn (),
607+ TextColumn ("[bold blue]{task.description}" ),
608+ console = console ,
609+ transient = True , # Clear progress when done
610+ ) as progress :
611+ # Step 1: Resolve debugger method
612+ task1 = progress .add_task ("Resolving injection method..." , total = None )
613+ resolved_method = self .resolve_debugger (args .method , verbose = verbose )
614+ progress .update (
615+ task1 ,
616+ description = f"[green]✓[/green] Using { resolved_method } for injection" ,
617+ )
618+ progress .stop_task (task1 )
542619
620+ # Step 2: Inject control channel
621+ task2 = progress .add_task (
622+ f"Injecting into process { args .pid } using { resolved_method } ..." ,
623+ total = None ,
624+ )
625+ client = self .inject_control_channel (
626+ resolved_method , args .pid , verbose = verbose
627+ )
628+ progress .update (
629+ task2 , description = "[green]✓[/green] Control channel established"
630+ )
631+ progress .stop_task (task2 )
632+
633+ # Step 3: Send tracking payload
634+ task3 = progress .add_task ("Sending tracking configuration..." , total = None )
635+ client .sendall (
636+ PAYLOAD .format (
637+ tracker_call = tracker_call ,
638+ mode = mode ,
639+ duration = duration ,
640+ ).encode ("utf-8" )
641+ )
642+ client .shutdown (socket .SHUT_WR )
643+ progress .update (task3 , description = "[green]✓[/green] Configuration sent" )
644+ progress .stop_task (task3 )
645+
646+ # Step 4: Wait for confirmation
647+ if not live_port :
648+ task4 = progress .add_task (
649+ "Waiting for confirmation from process..." , total = None
650+ )
651+ err = recvall (client )
652+ if err :
653+ raise MemrayCommandError (
654+ f"Failed to start tracking in remote process: { err } " ,
655+ exit_code = 1 ,
656+ )
657+ progress .update (
658+ task4 ,
659+ description = "[green]✓[/green] Tracking activated in remote process" ,
660+ )
661+ progress .stop_task (task4 )
662+
663+ # Only show confirmation after attach succeeded
543664 if not live_port :
544- err = recvall (client )
545- if err :
546- raise MemrayCommandError (
547- f"Failed to start tracking in remote process: { err } " ,
548- exit_code = 1 ,
665+ console .print (
666+ f"[green]✓[/green] Successfully attached to process [bold]{ args .pid } [/bold]"
667+ )
668+ console .print (f" Output file: [cyan]{ args .output } [/cyan]" )
669+
670+ # If duration and --wait are specified, wait and show progress
671+ if duration and args .wait :
672+ console .print (f" Tracking for [bold]{ duration } [/bold] seconds..." )
673+ console .print () # Add blank line before progress bar
674+ try :
675+ show_progress_with_duration (duration , args .pid )
676+ console .print () # Add blank line after completion
677+ console .print (
678+ f"[green]✓[/green] Tracking complete. "
679+ f"Results saved to: [cyan]{ args .output } [/cyan]"
680+ )
681+ except KeyboardInterrupt :
682+ console .print (
683+ f"\n [yellow]⚠ Note: Tracking will continue in process "
684+ f"{ args .pid } until the duration expires.[/yellow]"
685+ )
686+ console .print (
687+ f"[yellow] Use 'memray detach { args .pid } ' "
688+ f"to stop tracking immediately.[/yellow]"
689+ )
690+ raise MemrayCommandError ("Interrupted by user" , exit_code = 130 )
691+ elif duration :
692+ # Duration specified but not waiting - show prominent info message
693+ console .print () # Blank line for emphasis
694+ console .print (
695+ "[blue]ℹ[/blue] This command will exit immediately, "
696+ "but tracking continues in the background."
549697 )
698+ console .print (
699+ f" The process will be tracked for [bold]{ duration } [/bold] "
700+ f"seconds and results will be saved to [cyan]{ args .output } [/cyan]."
701+ )
702+ console .print () # Blank line
703+ console .print (
704+ f" To stop tracking early: "
705+ f"[bold]memray detach { args .pid } [/bold]"
706+ )
707+ console .print (
708+ " To wait and see progress: "
709+ "Use the [bold]--wait[/bold] flag next time"
710+ )
711+ else :
712+ # No duration - indefinite tracking
713+ console .print () # Blank line for emphasis
714+ console .print (
715+ "[blue]ℹ[/blue] This command will exit immediately, "
716+ "but tracking continues indefinitely."
717+ )
718+ console .print (
719+ f" Results will be saved to [cyan]{ args .output } [/cyan] "
720+ f"when tracking stops."
721+ )
722+ console .print () # Blank line
723+ console .print (
724+ f" To stop tracking: " f"[bold]memray detach { args .pid } [/bold]"
725+ )
726+
550727 return
551728
552729 # If an error prevents the tracked process from binding a server to
@@ -585,21 +762,67 @@ class DetachCommand(_DebuggerCommand):
585762 def run (self , args : argparse .Namespace , parser : argparse .ArgumentParser ) -> None :
586763 verbose = args .verbose
587764 mode : TrackingMode = "DEACTIVATE"
588- args .method = self .resolve_debugger (args .method , verbose = verbose )
589- client = self .inject_control_channel (args .method , args .pid , verbose = verbose )
590-
591- client .sendall (
592- PAYLOAD .format (
593- tracker_call = None ,
594- mode = mode ,
595- duration = None ,
596- ).encode ("utf-8" )
597- )
598- client .shutdown (socket .SHUT_WR )
765+ console = Console ()
766+
767+ # Show detailed detaching progress with steps
768+ with Progress (
769+ SpinnerColumn (),
770+ TextColumn ("[bold blue]{task.description}" ),
771+ console = console ,
772+ transient = True , # Clear progress when done
773+ ) as progress :
774+ # Step 1: Resolve debugger method
775+ task1 = progress .add_task ("Resolving injection method..." , total = None )
776+ resolved_method = self .resolve_debugger (args .method , verbose = verbose )
777+ progress .update (
778+ task1 ,
779+ description = f"[green]✓[/green] Using { resolved_method } for injection" ,
780+ )
781+ progress .stop_task (task1 )
599782
600- err = recvall (client )
601- if err :
602- raise MemrayCommandError (
603- f"Failed to stop tracking in remote process: { err } " ,
604- exit_code = 1 ,
783+ # Step 2: Inject control channel
784+ task2 = progress .add_task (
785+ f"Connecting to process { args .pid } using { resolved_method } ..." ,
786+ total = None ,
787+ )
788+ client = self .inject_control_channel (
789+ resolved_method , args .pid , verbose = verbose
790+ )
791+ progress .update (
792+ task2 , description = "[green]✓[/green] Control channel established"
605793 )
794+ progress .stop_task (task2 )
795+
796+ # Step 3: Send detach command
797+ task3 = progress .add_task ("Sending stop tracking command..." , total = None )
798+ client .sendall (
799+ PAYLOAD .format (
800+ tracker_call = None ,
801+ mode = mode ,
802+ duration = None ,
803+ ).encode ("utf-8" )
804+ )
805+ client .shutdown (socket .SHUT_WR )
806+ progress .update (task3 , description = "[green]✓[/green] Stop command sent" )
807+ progress .stop_task (task3 )
808+
809+ # Step 4: Wait for confirmation
810+ task4 = progress .add_task (
811+ "Waiting for confirmation from process..." , total = None
812+ )
813+ err = recvall (client )
814+ if err :
815+ raise MemrayCommandError (
816+ f"Failed to stop tracking in remote process: { err } " ,
817+ exit_code = 1 ,
818+ )
819+ progress .update (
820+ task4 , description = "[green]✓[/green] Tracking stopped in remote process"
821+ )
822+ progress .stop_task (task4 )
823+
824+ # Show final confirmation
825+ console .print (
826+ f"[green]✓[/green] Successfully stopped tracking in process "
827+ f"[bold]{ args .pid } [/bold]"
828+ )
0 commit comments