Skip to content

Commit 4e7733a

Browse files
align the lint check
1 parent 1192cb6 commit 4e7733a

File tree

2 files changed

+60
-25
lines changed

2 files changed

+60
-25
lines changed

test/microbench/adaptive_avg_pool2d.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import time
21
import argparse
2+
import time
33

44
import torch
55
from torch.profiler import profile, ProfilerActivity
@@ -46,17 +46,20 @@ def Adaptive_AVGPool2d(shape, dtype, channels_last, backward, device):
4646
if backward:
4747
output[0].backward(grad)
4848

49+
4950
def run_profile(shape, dtype, channels_last, backward, device, num_iter):
5051
with profile(
51-
activities=[ProfilerActivity.CPU,
52-
ProfilerActivity.XPU if device == "xpu" else ProfilerActivity.CUDA,
52+
activities=[
53+
ProfilerActivity.CPU,
54+
ProfilerActivity.XPU if device == "xpu" else ProfilerActivity.CUDA,
5355
],
5456
record_shapes=True,
5557
) as prof:
5658
for i in range(num_iter):
5759
Adaptive_AVGPool2d(shape, dtype, channels_last, backward, device)
5860
print(prof.key_averages().table(sort_by=f"{device}_time_total"))
5961

62+
6063
def run_e2e(shape, dtype, channels_last, backward, device, num_iter):
6164
if device in ["xpu", "cuda"]:
6265
torch.xpu.synchronize() if device == "xpu" else torch.cuda.synchronize()
@@ -69,6 +72,7 @@ def run_e2e(shape, dtype, channels_last, backward, device, num_iter):
6972
e2e_time = (t2 - t1) / num_iter
7073
print("E2E total time:", f"{float(e2e_time):.20f}")
7174

75+
7276
def benchmark(args):
7377
for shape in shape_list:
7478
for dtype in [torch.bfloat16, torch.float16, torch.float32]:
@@ -109,13 +113,14 @@ def benchmark(args):
109113
args.num_iter,
110114
)
111115

116+
112117
def parse_args():
113118
parser = argparse.ArgumentParser(description="OP Benchmark")
114119
parser.add_argument(
115120
"--device",
116121
type=str,
117-
default='xpu',
118-
help='Device to run on (e.g., "cpu", "cuda", "xpu")'
122+
default="xpu",
123+
help='Device to run on (e.g., "cpu", "cuda", "xpu")',
119124
)
120125
group = parser.add_mutually_exclusive_group()
121126
group.add_argument(
@@ -125,6 +130,7 @@ def parse_args():
125130
parser.add_argument("--num-iter", type=int, default=20, help="Number of iterations")
126131
return parser.parse_args()
127132

133+
128134
if __name__ == "__main__":
129135
args = parse_args()
130136
benchmark(args)
Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import time
21
import argparse
2+
import time
3+
34
import torch
45
from torch.profiler import profile, ProfilerActivity
56

@@ -13,36 +14,43 @@ def Index_fill(input, indices, dim, device):
1314
else:
1415
output = input.index_fill(dim, indices, 2)
1516

17+
1618
def run_profile(input, indices, dim, cache_r, cache_w, device, num_iter):
1719
with profile(
18-
activities=[ProfilerActivity.CPU,
19-
ProfilerActivity.XPU if device == 'xpu' else ProfilerActivity.CUDA],
20+
activities=[
21+
ProfilerActivity.CPU,
22+
ProfilerActivity.XPU if device == "xpu" else ProfilerActivity.CUDA,
23+
],
2024
record_shapes=True,
2125
) as prof:
2226
for i in range(num_iter):
2327
cache_r = cache_w * i
2428
Index_fill(input, indices, dim, device)
25-
print(prof.key_averages().table(sort_by="{}_time_total".format(device)))
29+
print(prof.key_averages().table(sort_by=f"{device}_time_total"))
30+
2631

2732
def run_e2e(input, indices, dim, cache_r, cache_w, device, num_iter):
28-
if device in ['xpu', 'cuda']:
29-
torch.xpu.synchronize() if device == 'xpu' else torch.cuda.synchronize()
33+
if device in ["xpu", "cuda"]:
34+
torch.xpu.synchronize() if device == "xpu" else torch.cuda.synchronize()
3035
t1 = time.time()
3136
for i in range(num_iter):
3237
cache_r = cache_w * i
3338
Index_fill(input, indices, dim, device)
34-
if device in ['xpu', 'cuda']:
35-
torch.xpu.synchronize() if device == 'xpu' else torch.cuda.synchronize()
39+
if device in ["xpu", "cuda"]:
40+
torch.xpu.synchronize() if device == "xpu" else torch.cuda.synchronize()
3641
t2 = time.time()
3742
e2e_time = (t2 - t1) / num_iter
3843
print("E2E total time:", f"{float(e2e_time):.20f}")
3944

45+
4046
def benchmark(args):
4147
for shape in shape_list:
4248
for dtype in [torch.bfloat16, torch.float16, torch.float32]:
4349
for dim in [0, 1]:
4450
input = torch.zeros(shape, dtype=dtype, device=args.device)
45-
indices = torch.linspace(0, 1022, steps=512, device=args.device).to(torch.long)
51+
indices = torch.linspace(0, 1022, steps=512, device=args.device).to(
52+
torch.long
53+
)
4654
y_0 = torch.ones((512, 1024), dtype=dtype, device=args.device)
4755
y_1 = torch.randn((1024, 512), dtype=dtype, device=args.device)
4856
cache_r = torch.randn((1024 * 1024 * 1024), device=args.device)
@@ -62,24 +70,45 @@ def benchmark(args):
6270
backward,
6371
)
6472
if not args.e2e_only:
65-
run_profile(input, indices, dim, cache_r, cache_w, args.device, args.num_iter)
73+
run_profile(
74+
input,
75+
indices,
76+
dim,
77+
cache_r,
78+
cache_w,
79+
args.device,
80+
args.num_iter,
81+
)
6682

6783
if not args.profile_only:
68-
run_e2e(input, indices, dim, cache_r, cache_w, args.device, args.num_iter)
84+
run_e2e(
85+
input,
86+
indices,
87+
dim,
88+
cache_r,
89+
cache_w,
90+
args.device,
91+
args.num_iter,
92+
)
93+
6994

7095
def parse_args():
71-
parser = argparse.ArgumentParser(description='OP Benchmark')
72-
parser.add_argument('--device', type=str, default='xpu',
73-
help='Device to run on (e.g., "cpu", "cuda", "xpu")')
96+
parser = argparse.ArgumentParser(description="OP Benchmark")
97+
parser.add_argument(
98+
"--device",
99+
type=str,
100+
default="xpu",
101+
help='Device to run on (e.g., "cpu", "cuda", "xpu")',
102+
)
74103
group = parser.add_mutually_exclusive_group()
75-
group.add_argument('--profile-only', action='store_true',
76-
help='Only Run profile timing')
77-
group.add_argument('--e2e-only', action='store_true',
78-
help='Only Run E2E timing')
79-
parser.add_argument('--num-iter', type=int, default=20,
80-
help='Number of iterations')
104+
group.add_argument(
105+
"--profile-only", action="store_true", help="Only Run profile timing"
106+
)
107+
group.add_argument("--e2e-only", action="store_true", help="Only Run E2E timing")
108+
parser.add_argument("--num-iter", type=int, default=20, help="Number of iterations")
81109
return parser.parse_args()
82110

111+
83112
if __name__ == "__main__":
84113
args = parse_args()
85114
benchmark(args)

0 commit comments

Comments
 (0)