Skip to content

Commit 6bb49a4

Browse files
committed
Add sampling profiler
1 parent 3224bb2 commit 6bb49a4

File tree

1 file changed

+169
-0
lines changed

1 file changed

+169
-0
lines changed

Lib/profile/sample.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import collections
2+
import marshal
3+
import pstats
4+
import time
5+
import _remote_debugging
6+
import argparse
7+
8+
9+
class SampleProfile:
    """Sampling profiler that inspects the stacks of another process.

    Stack traces are obtained from ``_remote_debugging.RemoteUnwinder`` and
    folded into per-location counters.  The aggregated result is stored in
    ``self.stats`` as 5-tuples ``(calls, rec_calls, total_time, cumulative_time,
    callers)`` so it can be handed to ``pstats.Stats`` or marshalled to disk.
    """

    def __init__(self, pid, sample_interval_usec, all_threads):
        """Attach an unwinder to *pid*.

        pid: process id passed to ``RemoteUnwinder``.
        sample_interval_usec: delay between two samples, in microseconds.
        all_threads: forwarded to ``RemoteUnwinder(all_threads=...)``.
        """
        self.pid = pid
        self.sample_interval_usec = sample_interval_usec
        self.all_threads = all_threads
        self.unwinder = _remote_debugging.RemoteUnwinder(
            self.pid, all_threads=self.all_threads
        )
        self.stats = {}
        # callee location -> {caller location -> number of samples}
        self.callers = collections.defaultdict(
            lambda: collections.defaultdict(int)
        )

    def sample(self, duration_sec=10):
        """Sample the target for *duration_sec* seconds and fill ``self.stats``.

        The loop polls ``time.perf_counter()`` without sleeping and takes a
        sample every time the next deadline has passed.  Samples whose stack
        could not be read (RuntimeError, UnicodeDecodeError, OSError) are
        counted as errors but still count as samples.  A short summary is
        printed when sampling ends.
        """
        result = collections.defaultdict(
            lambda: dict(total_calls=0, total_rec_calls=0, inline_calls=0)
        )
        sample_interval_sec = self.sample_interval_usec / 1_000_000

        running_time = 0
        num_samples = 0
        errors = 0
        start_time = next_time = time.perf_counter()
        while running_time < duration_sec:
            if next_time < time.perf_counter():
                try:
                    stack_frames = self.unwinder.get_stack_trace()
                    self.aggregate_stack_frames(result, stack_frames)
                except (RuntimeError, UnicodeDecodeError, OSError):
                    # Parenthesised so the clause parses on every Python 3.
                    errors += 1

                num_samples += 1
                next_time += sample_interval_sec

            running_time = time.perf_counter() - start_time

        # Guard the ratios: with a zero/negative duration no sample is taken
        # and no time elapses, which would otherwise divide by zero.
        sample_rate = num_samples / running_time if running_time > 0 else 0.0
        error_rate = (errors / num_samples) * 100 if num_samples else 0.0

        print(f"Captured {num_samples} samples in {running_time:.2f} seconds")
        print(f"Sample rate: {sample_rate:.2f} samples/sec")
        print(f"Error rate: {error_rate:.2f}%")

        # A zero interval has no meaningful "expected" count; skip the warning.
        if sample_interval_sec > 0:
            expected_samples = int(duration_sec / sample_interval_sec)
        else:
            expected_samples = 0
        if num_samples < expected_samples:
            print(
                f"Warning: missed {expected_samples - num_samples} samples "
                f"from the expected total of {expected_samples} "
                f"({(expected_samples - num_samples) / expected_samples * 100:.2f}%)"
            )

        self.stats = self.convert_to_pstats(result)

    def print_stats(self, sort=-1):
        """Print the collected statistics through ``pstats.Stats``.

        sort: a single sort key or a tuple of keys for ``Stats.sort_stats``.
        """
        if not isinstance(sort, tuple):
            sort = (sort,)
        # pstats.Stats(obj) takes obj.stats and then resets obj.stats to {}.
        # create_stats() here cannot rebuild the data, so keep a reference and
        # put it back; otherwise a later print/dump would see no statistics.
        collected = self.stats
        try:
            pstats.Stats(self).strip_dirs().sort_stats(*sort).print_stats()
        finally:
            self.stats = collected

    def dump_stats(self, file):
        """Write ``self.stats`` to the path *file* using ``marshal``."""
        with open(file, "wb") as f:
            marshal.dump(self.stats, f)

    # Needed for compatibility with pstats.Stats
    def create_stats(self):
        """No-op hook; ``pstats.Stats`` calls it before reading ``stats``."""
        pass

    def convert_to_pstats(self, raw_results):
        """Turn the per-location counters into a pstats-style dictionary.

        raw_results: mapping of location -> dict with the keys
            ``total_calls``, ``total_rec_calls`` and ``inline_calls``.

        Returns a dict mapping each location to
        ``(total_calls, rec_calls_or_total_calls, total_time, cumulative_time,
        callers)`` where the times are sample counts multiplied by the
        sampling interval in seconds and ``callers`` is a plain-dict copy of
        ``self.callers`` for that location.
        """
        sample_interval_sec = self.sample_interval_usec / 1_000_000
        converted = {}
        for fname, call_counts in raw_results.items():
            total_calls = call_counts["total_calls"]
            rec_calls = call_counts["total_rec_calls"]
            # Time spent with this location on top of the stack.
            total = call_counts["inline_calls"] * sample_interval_sec
            # Time spent with this location anywhere on the stack.
            cumulative = total_calls * sample_interval_sec
            converted[fname] = (
                total_calls,
                rec_calls if rec_calls else total_calls,
                total,
                cumulative,
                dict(self.callers.get(fname, {})),
            )

        return converted

    def aggregate_stack_frames(self, result, stack_frames):
        """Fold one sample into *result* and ``self.callers``.

        result: mapping location -> counter dict (a missing key must create a
            zeroed entry, as the defaultdict built in ``sample`` does).
        stack_frames: iterable of ``(thread_id, frames)`` pairs; ``frames`` is
            a sequence ordered from the innermost frame outwards (index 0 is
            treated as the currently running location).
            NOTE(review): locations are presumably ``(filename, lineno,
            funcname)`` tuples as pstats expects -- confirm against
            ``RemoteUnwinder.get_stack_trace``.
        """
        for _thread_id, frames in stack_frames:
            if not frames:
                continue
            top_location = frames[0]
            result[top_location]["inline_calls"] += 1
            result[top_location]["total_calls"] += 1

            # Each adjacent pair is (callee, caller).
            for callee, caller in zip(frames, frames[1:]):
                self.callers[callee][caller] += 1

            for location in frames[1:]:
                result[location]["total_calls"] += 1
                # The running location appears again deeper in the stack.
                if top_location == location:
                    result[location]["total_rec_calls"] += 1
112+
113+
114+
def sample(
    pid,
    *,
    sort=-1,
    sample_interval_usec=100,
    duration_sec=10,
    filename=None,
    all_threads=False,
):
    """Profile process *pid* by sampling and report the result.

    sort: sort key (or tuple of keys) used when printing the statistics.
    sample_interval_usec: delay between samples, in microseconds.
    duration_sec: how long to sample, in seconds.
    filename: if truthy, the statistics are dumped to this path instead of
        being printed.
    all_threads: forwarded to ``SampleProfile``.
    """
    # Forward the caller's choice; it used to be hard-coded to False, which
    # silently disabled all-thread sampling.
    profile = SampleProfile(pid, sample_interval_usec, all_threads=all_threads)
    profile.sample(duration_sec)
    if filename:
        profile.dump_stats(filename)
    else:
        profile.print_stats(sort)
129+
130+
131+
def main():
    """Command-line entry point: parse the arguments and run ``sample``."""
    cli = argparse.ArgumentParser(
        description="Sample a process's stack frames.", color=True
    )
    cli.add_argument("pid", type=int, help="Process ID to sample.")

    # (flags, keyword settings) for every optional argument, in help order.
    option_table = (
        (
            ("-i", "--interval"),
            dict(
                type=int,
                default=10,
                help="Sampling interval in microseconds (default: 10 usec)",
            ),
        ),
        (
            ("-d", "--duration"),
            dict(
                type=int,
                default=10,
                help="Sampling duration in seconds (default: 10 seconds)",
            ),
        ),
        (
            ("-a", "--all-threads"),
            dict(
                action="store_true",
                help="Sample all threads in the process",
            ),
        ),
        (
            ("-o", "--outfile"),
            dict(help="Save stats to <outfile>"),
        ),
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    options = cli.parse_args()

    sample(
        options.pid,
        sample_interval_usec=options.interval,
        duration_sec=options.duration,
        filename=options.outfile,
        all_threads=options.all_threads,
    )


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)