Skip to content

Commit 8313575

Browse files
committed
Improve UX for memray attach and detach commands
Users reported confusion when using memray attach with the --duration flag because the command would exit immediately without any indication that tracking had started successfully or was still running. This made it difficult to know if the attach operation worked, and if interrupted, there was no clear feedback about what would happen to the background tracking. Additionally, users would waste time going through the entire attach process only to discover at the end that their output file already existed, requiring them to restart with the --force flag. This change addresses these issues by checking the output file existence upfront before any injection work begins, and by showing detailed progress through each phase of the attach and detach operations. Users can now see exactly which step is in progress, making it easier to diagnose slow or stuck operations. For users who want the command to wait until tracking completes, a new --wait flag is provided that displays a live progress bar showing time elapsed and remaining. This makes the behavior more predictable while maintaining backward compatibility for users who expect the command to return immediately. Fixes #831 Fixes #701 Signed-off-by: Pablo Galindo <[email protected]>
1 parent 95fcf68 commit 8313575

File tree

1 file changed

+254
-31
lines changed

1 file changed

+254
-31
lines changed

src/memray/commands/attach.py

Lines changed: 254 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
import tempfile
1515
import textwrap
1616
import threading
17+
import time
18+
19+
from rich.console import Console
20+
from rich.progress import BarColumn
21+
from rich.progress import Progress
22+
from rich.progress import SpinnerColumn
23+
from rich.progress import TextColumn
24+
from rich.progress import TimeElapsedColumn
25+
from rich.progress import TimeRemainingColumn
1726

1827
import memray
1928
from memray._errors import MemrayCommandError
@@ -331,6 +340,50 @@ def recvall(sock: socket.socket) -> str:
331340
return b"".join(iter(lambda: sock.recv(4096), b"")).decode("utf-8")
332341

333342

343+
def show_progress_with_duration(duration: int, pid: int) -> None:
344+
"""Show a progress indicator while waiting for the specified duration.
345+
346+
Args:
347+
duration: Duration in seconds to wait
348+
pid: Process ID being tracked
349+
"""
350+
console = Console()
351+
352+
with Progress(
353+
SpinnerColumn(),
354+
TextColumn("[bold blue]{task.description}"),
355+
BarColumn(),
356+
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
357+
TimeElapsedColumn(),
358+
TimeRemainingColumn(),
359+
console=console,
360+
) as progress:
361+
task = progress.add_task(
362+
f"Tracking process {pid}", total=duration * 10 # 10 updates per second
363+
)
364+
365+
try:
366+
start_time = time.time()
367+
while not progress.finished:
368+
elapsed = time.time() - start_time
369+
if elapsed >= duration:
370+
progress.update(task, completed=duration * 10)
371+
break
372+
373+
progress.update(task, completed=int(elapsed * 10))
374+
time.sleep(0.1)
375+
376+
except KeyboardInterrupt:
377+
console.print()
378+
console.print(
379+
"[yellow]⚠ Interrupted! Tracking is still running in the background.[/yellow]"
380+
)
381+
console.print(
382+
f"[yellow] Use 'memray detach {pid}' to stop tracking immediately.[/yellow]"
383+
)
384+
raise
385+
386+
334387
class ErrorReaderThread(threading.Thread):
335388
def __init__(self, sock: socket.socket) -> None:
336389
self._sock = sock
@@ -488,6 +541,13 @@ def prepare_parser(self, parser: argparse.ArgumentParser) -> None:
488541
"--duration", type=int, help="Duration to track for (in seconds)"
489542
)
490543

544+
parser.add_argument(
545+
"--wait",
546+
help="Wait for tracking to complete before exiting (use with --duration)",
547+
action="store_true",
548+
default=False,
549+
)
550+
491551
super().prepare_parser(parser)
492552

493553
def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
@@ -503,9 +563,18 @@ def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None
503563

504564
destination: memray.Destination
505565
if args.output:
566+
# Check if output file exists before doing any work
567+
output_path = pathlib.Path(args.output).resolve()
568+
if output_path.exists() and not args.force:
569+
raise MemrayCommandError(
570+
f"Output file already exists: {output_path}\n"
571+
f"Use --force to overwrite it.",
572+
exit_code=1,
573+
)
574+
506575
live_port = None
507576
destination = memray.FileDestination(
508-
path=os.path.abspath(args.output),
577+
path=str(output_path),
509578
overwrite=args.force,
510579
compress_on_exit=not args.no_compress,
511580
)
@@ -530,23 +599,131 @@ def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None
530599
f"{file_format})"
531600
)
532601

533-
client = self.inject_control_channel(args.method, args.pid, verbose=verbose)
534-
client.sendall(
535-
PAYLOAD.format(
536-
tracker_call=tracker_call,
537-
mode=mode,
538-
duration=duration,
539-
).encode("utf-8")
540-
)
541-
client.shutdown(socket.SHUT_WR)
602+
console = Console()
603+
604+
# Show detailed attaching progress with steps
605+
with Progress(
606+
SpinnerColumn(),
607+
TextColumn("[bold blue]{task.description}"),
608+
console=console,
609+
transient=True, # Clear progress when done
610+
) as progress:
611+
# Step 1: Resolve debugger method
612+
task1 = progress.add_task("Resolving injection method...", total=None)
613+
resolved_method = self.resolve_debugger(args.method, verbose=verbose)
614+
progress.update(
615+
task1,
616+
description=f"[green]✓[/green] Using {resolved_method} for injection",
617+
)
618+
progress.stop_task(task1)
542619

620+
# Step 2: Inject control channel
621+
task2 = progress.add_task(
622+
f"Injecting into process {args.pid} using {resolved_method}...",
623+
total=None,
624+
)
625+
client = self.inject_control_channel(
626+
resolved_method, args.pid, verbose=verbose
627+
)
628+
progress.update(
629+
task2, description="[green]✓[/green] Control channel established"
630+
)
631+
progress.stop_task(task2)
632+
633+
# Step 3: Send tracking payload
634+
task3 = progress.add_task("Sending tracking configuration...", total=None)
635+
client.sendall(
636+
PAYLOAD.format(
637+
tracker_call=tracker_call,
638+
mode=mode,
639+
duration=duration,
640+
).encode("utf-8")
641+
)
642+
client.shutdown(socket.SHUT_WR)
643+
progress.update(task3, description="[green]✓[/green] Configuration sent")
644+
progress.stop_task(task3)
645+
646+
# Step 4: Wait for confirmation
647+
if not live_port:
648+
task4 = progress.add_task(
649+
"Waiting for confirmation from process...", total=None
650+
)
651+
err = recvall(client)
652+
if err:
653+
raise MemrayCommandError(
654+
f"Failed to start tracking in remote process: {err}",
655+
exit_code=1,
656+
)
657+
progress.update(
658+
task4,
659+
description="[green]✓[/green] Tracking activated in remote process",
660+
)
661+
progress.stop_task(task4)
662+
663+
# Only show confirmation after attach succeeded
543664
if not live_port:
544-
err = recvall(client)
545-
if err:
546-
raise MemrayCommandError(
547-
f"Failed to start tracking in remote process: {err}",
548-
exit_code=1,
665+
console.print(
666+
f"[green]✓[/green] Successfully attached to process [bold]{args.pid}[/bold]"
667+
)
668+
console.print(f" Output file: [cyan]{args.output}[/cyan]")
669+
670+
# If duration and --wait are specified, wait and show progress
671+
if duration and args.wait:
672+
console.print(f" Tracking for [bold]{duration}[/bold] seconds...")
673+
console.print() # Add blank line before progress bar
674+
try:
675+
show_progress_with_duration(duration, args.pid)
676+
console.print() # Add blank line after completion
677+
console.print(
678+
f"[green]✓[/green] Tracking complete. "
679+
f"Results saved to: [cyan]{args.output}[/cyan]"
680+
)
681+
except KeyboardInterrupt:
682+
console.print(
683+
f"\n[yellow]⚠ Note: Tracking will continue in process "
684+
f"{args.pid} until the duration expires.[/yellow]"
685+
)
686+
console.print(
687+
f"[yellow] Use 'memray detach {args.pid}' "
688+
f"to stop tracking immediately.[/yellow]"
689+
)
690+
raise MemrayCommandError("Interrupted by user", exit_code=130)
691+
elif duration:
692+
# Duration specified but not waiting - show prominent info message
693+
console.print() # Blank line for emphasis
694+
console.print(
695+
"[blue]ℹ[/blue] This command will exit immediately, "
696+
"but tracking continues in the background."
549697
)
698+
console.print(
699+
f" The process will be tracked for [bold]{duration}[/bold] "
700+
f"seconds and results will be saved to [cyan]{args.output}[/cyan]."
701+
)
702+
console.print() # Blank line
703+
console.print(
704+
f" To stop tracking early: "
705+
f"[bold]memray detach {args.pid}[/bold]"
706+
)
707+
console.print(
708+
" To wait and see progress: "
709+
"Use the [bold]--wait[/bold] flag next time"
710+
)
711+
else:
712+
# No duration - indefinite tracking
713+
console.print() # Blank line for emphasis
714+
console.print(
715+
"[blue]ℹ[/blue] This command will exit immediately, "
716+
"but tracking continues indefinitely."
717+
)
718+
console.print(
719+
f" Results will be saved to [cyan]{args.output}[/cyan] "
720+
f"when tracking stops."
721+
)
722+
console.print() # Blank line
723+
console.print(
724+
f" To stop tracking: " f"[bold]memray detach {args.pid}[/bold]"
725+
)
726+
550727
return
551728

552729
# If an error prevents the tracked process from binding a server to
@@ -585,21 +762,67 @@ class DetachCommand(_DebuggerCommand):
585762
def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
586763
verbose = args.verbose
587764
mode: TrackingMode = "DEACTIVATE"
588-
args.method = self.resolve_debugger(args.method, verbose=verbose)
589-
client = self.inject_control_channel(args.method, args.pid, verbose=verbose)
590-
591-
client.sendall(
592-
PAYLOAD.format(
593-
tracker_call=None,
594-
mode=mode,
595-
duration=None,
596-
).encode("utf-8")
597-
)
598-
client.shutdown(socket.SHUT_WR)
765+
console = Console()
766+
767+
# Show detailed detaching progress with steps
768+
with Progress(
769+
SpinnerColumn(),
770+
TextColumn("[bold blue]{task.description}"),
771+
console=console,
772+
transient=True, # Clear progress when done
773+
) as progress:
774+
# Step 1: Resolve debugger method
775+
task1 = progress.add_task("Resolving injection method...", total=None)
776+
resolved_method = self.resolve_debugger(args.method, verbose=verbose)
777+
progress.update(
778+
task1,
779+
description=f"[green]✓[/green] Using {resolved_method} for injection",
780+
)
781+
progress.stop_task(task1)
599782

600-
err = recvall(client)
601-
if err:
602-
raise MemrayCommandError(
603-
f"Failed to stop tracking in remote process: {err}",
604-
exit_code=1,
783+
# Step 2: Inject control channel
784+
task2 = progress.add_task(
785+
f"Connecting to process {args.pid} using {resolved_method}...",
786+
total=None,
787+
)
788+
client = self.inject_control_channel(
789+
resolved_method, args.pid, verbose=verbose
790+
)
791+
progress.update(
792+
task2, description="[green]✓[/green] Control channel established"
605793
)
794+
progress.stop_task(task2)
795+
796+
# Step 3: Send detach command
797+
task3 = progress.add_task("Sending stop tracking command...", total=None)
798+
client.sendall(
799+
PAYLOAD.format(
800+
tracker_call=None,
801+
mode=mode,
802+
duration=None,
803+
).encode("utf-8")
804+
)
805+
client.shutdown(socket.SHUT_WR)
806+
progress.update(task3, description="[green]✓[/green] Stop command sent")
807+
progress.stop_task(task3)
808+
809+
# Step 4: Wait for confirmation
810+
task4 = progress.add_task(
811+
"Waiting for confirmation from process...", total=None
812+
)
813+
err = recvall(client)
814+
if err:
815+
raise MemrayCommandError(
816+
f"Failed to stop tracking in remote process: {err}",
817+
exit_code=1,
818+
)
819+
progress.update(
820+
task4, description="[green]✓[/green] Tracking stopped in remote process"
821+
)
822+
progress.stop_task(task4)
823+
824+
# Show final confirmation
825+
console.print(
826+
f"[green]✓[/green] Successfully stopped tracking in process "
827+
f"[bold]{args.pid}[/bold]"
828+
)

0 commit comments

Comments
 (0)