Skip to content

Commit e5197a4

Browse files
enhance pooling and loss related cases
1 parent d8b3966 commit e5197a4

17 files changed

+984
-670
lines changed

test/microbench/im2col.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ def run_profile(shape, dtype, backward, device, num_iter):
2323
ProfilerActivity.XPU if device == 'xpu' else ProfilerActivity.CUDA],
2424
record_shapes=True,
2525
) as prof:
26-
for _ in range(num_iter):
26+
for i in range(num_iter):
2727
Im2col(shape, dtype, backward, device)
2828
print(prof.key_averages().table(sort_by="{}_time_total".format(device)))
2929

3030
def run_e2e(shape, dtype, backward, device, num_iter):
3131
if device in ['xpu', 'cuda']:
3232
torch.xpu.synchronize() if device == 'xpu' else torch.cuda.synchronize()
3333
t1 = time.time()
34-
for _ in range(num_iter):
34+
for i in range(num_iter):
3535
Im2col(shape, dtype, backward, device)
3636
if device in ['xpu', 'cuda']:
3737
torch.xpu.synchronize() if device == 'xpu' else torch.cuda.synchronize()
Lines changed: 71 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
import time
2-
2+
import argparse
33
import torch
44
import torch.nn as nn
55
from torch.profiler import profile, ProfilerActivity
66

7-
device = "xpu"
8-
backward = True
9-
num_iter = 20
107
shape_list = [(8733, 8733), (8733, 513), (513, 8733), (8192, 8192)]
11-
12-
cache_r = torch.randn(1024 * 1024 * 1024, device=device)
13-
cache_w = torch.randn(1024 * 1024 * 1024, device=device)
8+
backward = True
149

1510

1611
def _do_test(loss, input, target, dtype, device):
@@ -20,51 +15,76 @@ def _do_test(loss, input, target, dtype, device):
2015

2116
return output, grad_inputs
2217

23-
24-
for shape in shape_list:
25-
for dtype in [torch.bfloat16, torch.float16, torch.float32]:
26-
M, N = shape[0], shape[1]
27-
input = torch.randn((M, N), requires_grad=True)
28-
target = torch.empty((M, N)).random_(2)
29-
for reduce in ["none", "mean", "sum"]:
30-
loss = nn.BCELoss(reduce=reduce)
31-
m = nn.Sigmoid()
32-
input = m(input).to(dtype=dtype, device=device)
33-
target = target.to(dtype=dtype, device=device)
34-
# warm up
18+
def run_profile(loss, input, target, dtype, backward, cache_r, cache_w, device, num_iter):
    """Profile ``num_iter`` calls of ``_do_test`` and print the averaged table.

    Args:
        loss: Loss module forwarded to ``_do_test``.
        input: Input tensor forwarded to ``_do_test``.
        target: Target tensor forwarded to ``_do_test``.
        dtype: Data type forwarded to ``_do_test``.
        backward: Accepted for signature parity with the other benchmarks;
            not read by this function.
        cache_r: Scratch tensor name that is rebound on every iteration.
        cache_w: Tensor read on every iteration to produce ``cache_r``.
        device: Device string; ``'xpu'`` selects XPU profiling, ``'cpu'``
            selects CPU-only profiling, anything else selects CUDA profiling.
        num_iter: Number of profiled iterations.
    """
    # Only request an accelerator activity when an accelerator is targeted;
    # asking for CUDA activity on a CPU-only run is pointless.
    activities = [ProfilerActivity.CPU]
    if device == 'xpu':
        activities.append(ProfilerActivity.XPU)
    elif device != 'cpu':
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities, record_shapes=True) as prof:
        for _ in range(num_iter):
            # NOTE(review): presumably touches a large buffer to evict cached
            # data between iterations -- confirm with the benchmark owners.
            cache_r = cache_w + 1
            _do_test(loss, input, target, dtype, device)
    print(prof.key_averages().table(sort_by="{}_time_total".format(device)))
3628

37-
# go
38-
print(
39-
"shape:",
40-
(M, N),
41-
"; datatype:",
42-
dtype,
43-
"; reduce:",
44-
reduce,
45-
"; backward:",
46-
backward,
47-
)
48-
with profile(
49-
activities=[ProfilerActivity.CPU, ProfilerActivity.XPU],
50-
record_shapes=True,
51-
) as prof:
52-
for i in range(num_iter):
53-
cache_r = cache_w + 1
54-
output_xpu, grad_input_xpu = _do_test(
55-
loss, input, target, dtype, device
56-
)
57-
print(prof.key_averages().table(sort_by="xpu_time_total"))
29+
def run_e2e(loss, input, target, dtype, backward, cache_r, cache_w, device, num_iter):
    """Time ``num_iter`` calls of ``_do_test`` and print the mean per-iteration time.

    Args:
        loss: Loss module forwarded to ``_do_test``.
        input: Input tensor forwarded to ``_do_test``.
        target: Target tensor forwarded to ``_do_test``.
        dtype: Data type forwarded to ``_do_test``.
        backward: Accepted for signature parity with the other benchmarks;
            not read by this function.
        cache_r: Scratch tensor name that is rebound on every iteration.
        cache_w: Tensor read on every iteration to produce ``cache_r``.
        device: Device string; ``'xpu'`` and ``'cuda'`` are synchronised
            before and after the timed loop.
        num_iter: Number of timed iterations (must be non-zero, the total is
            divided by it).
    """

    def _synchronize():
        # Wait for queued device work so the timer brackets real execution.
        if device == 'xpu':
            torch.xpu.synchronize()
        elif device == 'cuda':
            torch.cuda.synchronize()

    _synchronize()
    # perf_counter is monotonic and high resolution, unlike time.time().
    t1 = time.perf_counter()
    for _ in range(num_iter):
        # NOTE(review): presumably touches a large buffer to evict cached
        # data between iterations -- confirm with the benchmark owners.
        cache_r = cache_w + 1
        _do_test(loss, input, target, dtype, device)
    _synchronize()
    t2 = time.perf_counter()
    e2e_time = (t2 - t1) / num_iter
    print("E2E total time:", f"{float(e2e_time):.20f}")
5841

59-
# E2E time
60-
torch.xpu.synchronize()
61-
t1 = time.time()
62-
for i in range(num_iter):
63-
cache_r = cache_w + 1
64-
output_xpu, grad_input_xpu = _do_test(
65-
loss, input, target, dtype, device
42+
def benchmark(args):
    """Run the BCELoss benchmark for every shape, dtype and reduction mode.

    For each combination the case is warmed up once through ``_do_test``, a
    header line is printed, and then the profiler run and/or the end-to-end
    timing run are executed depending on the CLI flags.

    Args:
        args: Parsed CLI namespace providing ``device``, ``num_iter``,
            ``e2e_only`` and ``profile_only``.
    """
    for shape in shape_list:
        M, N = shape[0], shape[1]
        for dtype in [torch.bfloat16, torch.float16, torch.float32]:
            raw = torch.randn((M, N), requires_grad=True)
            labels = torch.empty((M, N)).random_(2)
            cache_r = torch.randn(1024 * 1024 * 1024, device=args.device)
            cache_w = torch.randn(1024 * 1024 * 1024, device=args.device)

            # Prepare the operands once per (shape, dtype). Doing this inside
            # the reduction loop would feed sigmoid(sigmoid(...)) to the
            # second and third reduction modes.
            probs = nn.Sigmoid()(raw).to(dtype=dtype, device=args.device)
            target = labels.to(dtype=dtype, device=args.device)

            for reduction in ["none", "mean", "sum"]:
                # `reduce` is the deprecated boolean argument; a non-empty
                # string there is truthy and silently selects 'mean'. The
                # mode string has to go through `reduction`.
                loss = nn.BCELoss(reduction=reduction)

                # warm up
                _do_test(loss, probs, target, dtype, args.device)

                # go
                print(
                    "shape:",
                    (M, N),
                    "; datatype:",
                    dtype,
                    "; reduce:",
                    reduction,
                    "; backward:",
                    backward,
                )
                if not args.e2e_only:
                    run_profile(loss, probs, target, dtype, backward,
                                cache_r, cache_w, args.device, args.num_iter)

                if not args.profile_only:
                    run_e2e(loss, probs, target, dtype, backward,
                            cache_r, cache_w, args.device, args.num_iter)
74+
75+
def parse_args():
    """Parse the benchmark command line.

    Returns:
        argparse.Namespace with ``device`` (str, default ``'xpu'``),
        ``profile_only`` / ``e2e_only`` (mutually exclusive flags) and
        ``num_iter`` (positive int, default 20).
    """

    def positive_int(text):
        # The E2E timing divides by the iteration count, so zero or negative
        # values are rejected at the CLI instead of failing later.
        value = int(text)
        if value <= 0:
            raise ValueError("iteration count must be positive")
        return value

    parser = argparse.ArgumentParser(description='OP Benchmark')
    parser.add_argument('--device', type=str, default='xpu',
                        help='Device to run on (e.g., "cpu", "cuda", "xpu")')
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--profile-only', action='store_true',
                       help='Only Run profile timing')
    group.add_argument('--e2e-only', action='store_true',
                       help='Only Run E2E timing')
    parser.add_argument('--num-iter', type=positive_int, default=20,
                        help='Number of iterations')
    return parser.parse_args()
87+
88+
if __name__ == "__main__":
    # Script entry point: read the CLI options, then run every benchmark case.
    args = parse_args()
    benchmark(args)

test/microbench/loss.ctc_loss.py

Lines changed: 72 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,92 @@
11
import time
2-
2+
import argparse
33
import torch
44
from torch.profiler import profile, ProfilerActivity
55

6-
device = "xpu"
7-
backward = True
8-
num_iter = 20
96
# T,N,C,S
107
shape_list = [(32, 32, 32, 16), (128, 128, 128, 128), (8, 8, 4, 8)]
8+
backward = True
119

1210

13-
def _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, dtype):
14-
log_probs_dpcpp = log_probs.to("xpu")
15-
log_probs_dpcpp.requires_grad_(True)
16-
targets_dpcpp = targets.to("xpu")
17-
input_lengths_dpcpp = input_lengths.to("xpu")
18-
target_lengths_dpcpp = target_lengths.to("xpu")
19-
20-
# warm up
11+
def _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, backward):
2112
loss_dpcpp = torch.nn.functional.ctc_loss(
22-
log_probs_dpcpp, targets_dpcpp, input_lengths_dpcpp, target_lengths_dpcpp
13+
log_probs, targets, input_lengths, target_lengths
2314
)
24-
loss_dpcpp.backward()
15+
if backward:
16+
loss_dpcpp.backward()
2517

26-
# go
27-
print(
28-
"shape:",
29-
(shape[0], shape[1], shape[2], shape[3]),
30-
"; datatype:",
31-
dtype,
32-
"; backward:",
33-
backward,
34-
)
18+
def run_profile(log_probs, targets, input_lengths, target_lengths, backward, device, num_iter):
    """Profile ``num_iter`` calls of ``_test_loss_ctc`` and print the averaged table.

    Args:
        log_probs: Forwarded to ``_test_loss_ctc``.
        targets: Forwarded to ``_test_loss_ctc``.
        input_lengths: Forwarded to ``_test_loss_ctc``.
        target_lengths: Forwarded to ``_test_loss_ctc``.
        backward: Forwarded to ``_test_loss_ctc``.
        device: Device string; ``'xpu'`` selects XPU profiling, ``'cpu'``
            selects CPU-only profiling, anything else selects CUDA profiling.
        num_iter: Number of profiled iterations.
    """
    # Only request an accelerator activity when an accelerator is targeted;
    # asking for CUDA activity on a CPU-only run is pointless.
    activities = [ProfilerActivity.CPU]
    if device == 'xpu':
        activities.append(ProfilerActivity.XPU)
    elif device != 'cpu':
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities, record_shapes=True) as prof:
        for _ in range(num_iter):
            _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, backward)
    print(prof.key_averages().table(sort_by="{}_time_total".format(device)))
4727

48-
# E2E time
49-
torch.xpu.synchronize()
28+
def run_e2e(log_probs, targets, input_lengths, target_lengths, backward, device, num_iter):
    """Time ``num_iter`` calls of ``_test_loss_ctc`` and print the mean per-iteration time.

    Args:
        log_probs: Forwarded to ``_test_loss_ctc``.
        targets: Forwarded to ``_test_loss_ctc``.
        input_lengths: Forwarded to ``_test_loss_ctc``.
        target_lengths: Forwarded to ``_test_loss_ctc``.
        backward: Forwarded to ``_test_loss_ctc``.
        device: Device string; ``'xpu'`` and ``'cuda'`` are synchronised
            before and after the timed loop.
        num_iter: Number of timed iterations (must be non-zero, the total is
            divided by it).
    """

    def _synchronize():
        # Wait for queued device work so the timer brackets real execution.
        if device == 'xpu':
            torch.xpu.synchronize()
        elif device == 'cuda':
            torch.cuda.synchronize()

    _synchronize()
    # perf_counter is monotonic and high resolution, unlike time.time().
    t1 = time.perf_counter()
    for _ in range(num_iter):
        _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, backward)
    _synchronize()
    t2 = time.perf_counter()
    e2e_time = (t2 - t1) / num_iter
    print("E2E total time:", f"{float(e2e_time):.20f}")
6339

40+
def benchmark(args):
    """Run the CTC loss benchmark for every ``(T, N, C, S)`` shape.

    For each shape the inputs are generated on ``args.device`` from a fixed
    seed, the case is warmed up once, a header line is printed, and then the
    profiler run and/or the end-to-end timing run are executed depending on
    the CLI flags.

    Args:
        args: Parsed CLI namespace providing ``device``, ``num_iter``,
            ``e2e_only`` and ``profile_only``.
    """
    for shape in shape_list:
        T, N, C, S = shape[0], shape[1], shape[2], shape[3]
        for dtype in [torch.float32]:
            # Reseed per case so every shape sees reproducible inputs.
            torch.manual_seed(15)
            log_probs = (
                torch.randn(T, N, C, dtype=dtype, device=args.device)
                .log_softmax(2)
                .detach()
                .requires_grad_()
            )
            # Targets are class indices, so they must lie in [1, C): index 0
            # is the default blank and the upper bound is the number of
            # classes, not the batch size N.
            targets = torch.randint(1, C, (N, S), dtype=torch.long, device=args.device)
            input_lengths = torch.full((N,), T, dtype=torch.long, device=args.device)
            target_lengths = torch.randint(1, S, (N,), dtype=torch.long, device=args.device)

            # warm up
            _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, backward)
            # go
            print(
                "shape:",
                (shape[0], shape[1], shape[2], shape[3]),
                "; datatype:",
                dtype,
                "; backward:",
                backward,
            )
            if not args.e2e_only:
                run_profile(log_probs, targets, input_lengths, target_lengths,
                            backward, args.device, args.num_iter)

            if not args.profile_only:
                run_e2e(log_probs, targets, input_lengths, target_lengths,
                        backward, args.device, args.num_iter)
76+
77+
def parse_args():
    """Parse the benchmark command line.

    Returns:
        argparse.Namespace with ``device`` (str, default ``'xpu'``),
        ``profile_only`` / ``e2e_only`` (mutually exclusive flags) and
        ``num_iter`` (positive int, default 20).
    """

    def positive_int(text):
        # The E2E timing divides by the iteration count, so zero or negative
        # values are rejected at the CLI instead of failing later.
        value = int(text)
        if value <= 0:
            raise ValueError("iteration count must be positive")
        return value

    parser = argparse.ArgumentParser(description='OP Benchmark')
    parser.add_argument('--device', type=str, default='xpu',
                        help='Device to run on (e.g., "cpu", "cuda", "xpu")')
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--profile-only', action='store_true',
                       help='Only Run profile timing')
    group.add_argument('--e2e-only', action='store_true',
                       help='Only Run E2E timing')
    parser.add_argument('--num-iter', type=positive_int, default=20,
                        help='Number of iterations')
    return parser.parse_args()
6489

65-
for shape in shape_list:
66-
for dtype in [torch.float32]:
67-
T, N, C, S = shape[0], shape[1], shape[2], shape[3]
68-
g_cpu = torch.Generator()
69-
g_cpu.manual_seed(15)
70-
torch.manual_seed(15)
71-
log_probs = (
72-
torch.randn(T, N, C, dtype=dtype).log_softmax(2).detach().requires_grad_()
73-
)
74-
targets = torch.randint(1, N, (N, S), dtype=torch.long, generator=g_cpu)
75-
input_lengths = torch.full((N,), T, dtype=torch.long)
76-
target_lengths = torch.randint(1, S, (N,), dtype=torch.long, generator=g_cpu)
77-
_test_loss_ctc(log_probs, targets, input_lengths, target_lengths, dtype)
78-
g_cpu = torch.Generator()
79-
g_cpu.manual_seed(15)
80-
torch.manual_seed(15)
90+
if __name__ == "__main__":
    # Script entry point: read the CLI options, then run every benchmark case.
    args = parse_args()
    benchmark(args)

0 commit comments

Comments
 (0)