Skip to content
This repository was archived by the owner on May 24, 2021. It is now read-only.

Commit b45f937

Browse files
authored
Merge pull request #8 from awwong1/issue-7
Sum over all occurrences per layer, display occurrences per layer
2 parents ebca8a1 + 64d3d63 commit b45f937

File tree

2 files changed

+93
-88
lines changed

2 files changed

+93
-88
lines changed

README.md

Lines changed: 78 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -24,53 +24,52 @@ with torchprof.Profile(model, use_cuda=True) as prof:
2424
print(prof.display(show_events=False)) # equivalent to `print(prof)` and `print(prof.display())`
2525
```
2626
```text
27-
Module | Self CPU total | CPU total | CUDA total
28-
---------------|----------------|-----------|-----------
29-
AlexNet | | |
30-
├── features | | |
31-
│├── 0 | 1.956ms | 7.714ms | 7.787ms
32-
│├── 1 | 68.880us | 68.880us | 69.632us
33-
│├── 2 | 85.639us | 155.948us | 155.648us
34-
│├── 3 | 253.419us | 970.386us | 1.747ms
35-
│├── 4 | 18.919us | 18.919us | 19.584us
36-
│├── 5 | 30.910us | 54.900us | 55.296us
37-
│├── 6 | 132.839us | 492.367us | 652.192us
38-
│├── 7 | 17.990us | 17.990us | 18.432us
39-
│├── 8 | 87.219us | 310.776us | 552.544us
40-
│├── 9 | 17.620us | 17.620us | 17.536us
41-
│├── 10 | 85.690us | 303.120us | 437.248us
42-
│├── 11 | 17.910us | 17.910us | 18.400us
43-
│└── 12 | 29.239us | 51.488us | 52.288us
44-
├── avgpool | 49.230us | 85.740us | 88.960us
45-
└── classifier | | |
46-
├── 0 | 626.236us | 1.239ms | 1.362ms
47-
├── 1 | 235.669us | 235.669us | 635.008us
48-
├── 2 | 17.990us | 17.990us | 18.432us
49-
├── 3 | 31.890us | 56.770us | 57.344us
50-
├── 4 | 39.280us | 39.280us | 212.128us
51-
├── 5 | 16.800us | 16.800us | 17.600us
52-
└── 6 | 38.459us | 38.459us | 79.872us
27+
Module | Self CPU total | CPU total | CUDA total | Occurrences
28+
---------------|----------------|-----------|------------|------------
29+
AlexNet | | | |
30+
├── features | | | |
31+
│├── 0 | 1.671ms | 6.589ms | 6.701ms | 1
32+
│├── 1 | 62.430us | 62.430us | 63.264us | 1
33+
│├── 2 | 62.909us | 109.948us | 112.640us | 1
34+
│├── 3 | 225.389us | 858.376us | 1.814ms | 1
35+
│├── 4 | 18.999us | 18.999us | 19.456us | 1
36+
│├── 5 | 29.560us | 52.720us | 54.272us | 1
37+
│├── 6 | 136.959us | 511.216us | 707.360us | 1
38+
│├── 7 | 18.480us | 18.480us | 18.624us | 1
39+
│├── 8 | 84.380us | 300.700us | 590.688us | 1
40+
│├── 9 | 18.249us | 18.249us | 17.632us | 1
41+
│├── 10 | 81.289us | 289.946us | 470.016us | 1
42+
│├── 11 | 17.850us | 17.850us | 18.432us | 1
43+
│└── 12 | 29.350us | 52.260us | 52.288us | 1
44+
├── avgpool | 41.840us | 70.840us | 76.832us | 1
45+
└── classifier | | | |
46+
├── 0 | 66.400us | 122.110us | 125.920us | 1
47+
├── 1 | 293.658us | 293.658us | 664.704us | 1
48+
├── 2 | 17.600us | 17.600us | 18.432us | 1
49+
├── 3 | 27.920us | 49.030us | 51.168us | 1
50+
├── 4 | 40.590us | 40.590us | 208.672us | 1
51+
├── 5 | 17.570us | 17.570us | 18.432us | 1
52+
└── 6 | 40.489us | 40.489us | 81.920us | 1
5353
```
5454

5555
To see the low level operations that occur within each layer, print the contents of `prof.display(show_events=True)`.
5656

5757
```text
58-
Module | Self CPU total | CPU total | CUDA total
59-
------------------------------|----------------|-----------|-----------
60-
AlexNet | | |
61-
├── features | | |
62-
│├── 0 | | |
63-
││├── conv2d | 15.740us | 1.956ms | 1.972ms
64-
││├── convolution | 12.000us | 1.940ms | 1.957ms
65-
││├── _convolution | 36.590us | 1.928ms | 1.946ms
66-
││├── contiguous | 6.600us | 6.600us | 6.464us
67-
││└── cudnn_convolution | 1.885ms | 1.885ms | 1.906ms
68-
│├── 1 | | |
69-
││└── relu_ | 68.880us | 68.880us | 69.632us
70-
│├── 2 | | |
71-
││├── max_pool2d | 15.330us | 85.639us | 84.992us
72-
││└── max_pool2d_with_indices | 70.309us | 70.309us | 70.656us
73-
│├── 3 | | |
58+
Module | Self CPU total | CPU total | CUDA total | Occurrences
59+
------------------------------|----------------|-----------|------------|------------
60+
AlexNet | | | |
61+
├── features | | | |
62+
│├── 0 | | | |
63+
││├── conv2d | 13.370us | 1.671ms | 1.698ms | 1
64+
││├── convolution | 12.730us | 1.658ms | 1.685ms | 1
65+
││├── _convolution | 30.660us | 1.645ms | 1.673ms | 1
66+
││├── contiguous | 6.970us | 6.970us | 7.136us | 1
67+
││└── cudnn_convolution | 1.608ms | 1.608ms | 1.638ms | 1
68+
│├── 1 | | | |
69+
││└── relu_ | 62.430us | 62.430us | 63.264us | 1
70+
│├── 2 | | | |
71+
││├── max_pool2d | 15.870us | 62.909us | 63.488us | 1
72+
││└── max_pool2d_with_indices | 47.039us | 47.039us | 49.152us | 1
7473
...
7574
```
7675

@@ -85,17 +84,17 @@ print(trace[2])
8584
print(event_lists_dict[trace[2].path][0])
8685
```
8786
```text
88-
--------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- ---------------
89-
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls
90-
--------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- ---------------
91-
conv2d 0.80% 15.740us 100.00% 1.956ms 1.956ms 25.32% 1.972ms 1.972ms 1
92-
convolution 0.61% 12.000us 99.20% 1.940ms 1.940ms 25.14% 1.957ms 1.957ms 1
93-
_convolution 1.87% 36.590us 98.58% 1.928ms 1.928ms 24.99% 1.946ms 1.946ms 1
94-
contiguous 0.34% 6.600us 0.34% 6.600us 6.600us 0.08% 6.464us 6.464us 1
95-
cudnn_convolution 96.37% 1.885ms 96.37% 1.885ms 1.885ms 24.47% 1.906ms 1.906ms 1
96-
--------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- ---------------
97-
Self CPU time total: 1.956ms
98-
CUDA time total: 7.787ms
87+
--------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- -----------------------------------
88+
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls Input Shapes
89+
--------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- -----------------------------------
90+
conv2d 0.80% 13.370us 100.00% 1.671ms 1.671ms 25.34% 1.698ms 1.698ms 1 []
91+
convolution 0.76% 12.730us 99.20% 1.658ms 1.658ms 25.15% 1.685ms 1.685ms 1 []
92+
_convolution 1.83% 30.660us 98.44% 1.645ms 1.645ms 24.97% 1.673ms 1.673ms 1 []
93+
contiguous 0.42% 6.970us 0.42% 6.970us 6.970us 0.11% 7.136us 7.136us 1 []
94+
cudnn_convolution 96.19% 1.608ms 96.19% 1.608ms 1.608ms 24.44% 1.638ms 1.638ms 1 []
95+
--------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- -----------------------------------
96+
Self CPU time total: 1.671ms
97+
CUDA time total: 6.701ms
9998
10099
```
101100

@@ -115,32 +114,32 @@ print(prof)
115114
```
116115

117116
```text
118-
Module | Self CPU total | CPU total | CUDA total
119-
---------------|----------------|-----------|-----------
120-
AlexNet | | |
121-
├── features | | |
122-
│├── 0 | | |
123-
│├── 1 | | |
124-
│├── 2 | | |
125-
│├── 3 | 2.846ms | 11.368ms | 0.000us
126-
│├── 4 | | |
127-
│├── 5 | | |
128-
│├── 6 | | |
129-
│├── 7 | | |
130-
│├── 8 | | |
131-
│├── 9 | | |
132-
│├── 10 | | |
133-
│├── 11 | | |
134-
│└── 12 | | |
135-
├── avgpool | | |
136-
└── classifier | 12.016ms | 12.206ms | 0.000us
137-
├── 0 | | |
138-
├── 1 | | |
139-
├── 2 | | |
140-
├── 3 | | |
141-
├── 4 | | |
142-
├── 5 | | |
143-
└── 6 | | |
117+
Module | Self CPU total | CPU total | CUDA total | Occurrences
118+
---------------|----------------|-----------|------------|------------
119+
AlexNet | | | |
120+
├── features | | | |
121+
│├── 0 | | | |
122+
│├── 1 | | | |
123+
│├── 2 | | | |
124+
│├── 3 | 3.189ms | 12.717ms | 0.000us | 1
125+
│├── 4 | | | |
126+
│├── 5 | | | |
127+
│├── 6 | | | |
128+
│├── 7 | | | |
129+
│├── 8 | | | |
130+
│├── 9 | | | |
131+
│├── 10 | | | |
132+
│├── 11 | | | |
133+
│└── 12 | | | |
134+
├── avgpool | | | |
135+
└── classifier | 13.403ms | 14.011ms | 0.000us | 1
136+
├── 0 | | | |
137+
├── 1 | | | |
138+
├── 2 | | | |
139+
├── 3 | | | |
140+
├── 4 | | | |
141+
├── 5 | | | |
142+
└── 6 | | | |
144143
145144
```
146145

torchprof/profile.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import namedtuple, defaultdict, OrderedDict
44

55
Trace = namedtuple("Trace", ["path", "leaf", "module"])
6-
Measure = namedtuple("Measure", ["self_cpu_total", "cpu_total", "cuda_total"])
6+
Measure = namedtuple("Measure", ["self_cpu_total", "cpu_total", "cuda_total", "occurrences"])
77

88

99
def walk_modules(module, name="", path=()):
@@ -133,16 +133,18 @@ def traces_to_display(traces, trace_events, show_events=False, paths=None):
133133
for event in events:
134134
current_tree[name][event.name] = {
135135
None: Measure(
136-
event.self_cpu_time_total,
137-
event.cpu_time_total,
138-
event.cuda_time_total,
136+
sum([e.self_cpu_time_total for e in events if e.name == event.name]),
137+
sum([e.cpu_time_total for e in events if e.name == event.name]),
138+
sum([e.cuda_time_total for e in events if e.name == event.name]),
139+
len([e for e in events if e.name == event.name])
139140
)
140141
}
141142
else:
142143
current_tree[name][None] = Measure(
143144
sum([e.self_cpu_time_total for e in events]),
144145
sum([e.cpu_time_total for e in events]),
145146
sum([e.cuda_time_total for e in events]),
147+
len(trace_events[path])
146148
)
147149
current_tree = current_tree[name]
148150
tree_lines = flatten_tree(tree)
@@ -155,10 +157,12 @@ def traces_to_display(traces, trace_events, show_events=False, paths=None):
155157
self_cpu_time = ""
156158
cpu_time = ""
157159
cuda_time = ""
160+
occurrences = ""
158161
if measures:
159162
self_cpu_time = tprofiler.format_time(measures.self_cpu_total)
160163
cpu_time = tprofiler.format_time(measures.cpu_total)
161164
cuda_time = tprofiler.format_time(measures.cuda_total)
165+
occurrences = str(measures.occurrences)
162166
pre = ""
163167
next_depths = [pl[0] for pl in tree_lines[idx + 1 :]]
164168
current = True
@@ -175,22 +179,24 @@ def traces_to_display(traces, trace_events, show_events=False, paths=None):
175179
pre = dt[3] + pre
176180
depth -= 1
177181
current = False
178-
format_lines.append([pre + name, self_cpu_time, cpu_time, cuda_time])
182+
format_lines.append([pre + name, self_cpu_time, cpu_time, cuda_time, occurrences])
179183

180184
# construct the table
181-
heading = ("Module", "Self CPU total", "CPU total", "CUDA total")
185+
heading = ("Module", "Self CPU total", "CPU total", "CUDA total", "Occurrences")
182186
max_lens = [max(map(len, col)) for col in zip(*([heading] + format_lines))]
183187
# create the heading
184188
disp = "{:<{}s}".format(heading[0], max_lens[0]) + " | "
185189
disp += "{:>{}s}".format(heading[1], max_lens[1]) + " | "
186190
disp += "{:>{}s}".format(heading[2], max_lens[2]) + " | "
187-
disp += "{:>{}s}".format(heading[3], max_lens[3]) + "\n"
191+
disp += "{:>{}s}".format(heading[3], max_lens[3]) + " | "
192+
disp += "{:>{}s}".format(heading[4], max_lens[4]) + "\n"
188193
disp += "-|-".join(["-" * mlen for mlen in max_lens]) + "\n"
189194
for line in format_lines:
190-
label, self_cpu_time, cpu_time, cuda_time = line
195+
label, self_cpu_time, cpu_time, cuda_time, occurrences = line
191196
disp += "{:<{}s}".format(label, max_lens[0]) + " | "
192197
disp += "{:>{}s}".format(self_cpu_time, max_lens[1]) + " | "
193198
disp += "{:>{}s}".format(cpu_time, max_lens[2]) + " | "
194-
disp += "{:>{}s}".format(cuda_time, max_lens[3]) + "\n"
199+
disp += "{:>{}s}".format(cuda_time, max_lens[3]) + " | "
200+
disp += "{:>{}s}".format(occurrences, max_lens[4]) + "\n"
195201

196202
return disp

0 commit comments

Comments
 (0)