|
1 | 1 | import time
|
2 |
| - |
| 2 | +import argparse |
3 | 3 | import torch
|
4 | 4 | from torch.profiler import profile, ProfilerActivity
|
5 | 5 |
|
6 |
| -device = "xpu" |
7 |
| -backward = True |
8 |
| -num_iter = 20 |
# Benchmark configurations: each tuple is (T, N, C, S) =
# (input sequence length, batch size, number of classes, max target length).
shape_list = [(32, 32, 32, 16), (128, 128, 128, 128), (8, 8, 4, 8)]
# Run the backward pass in every iteration as well as the forward pass.
backward = True
|
12 | 10 |
|
13 |
| -def _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, dtype): |
14 |
| - log_probs_dpcpp = log_probs.to("xpu") |
15 |
| - log_probs_dpcpp.requires_grad_(True) |
16 |
| - targets_dpcpp = targets.to("xpu") |
17 |
| - input_lengths_dpcpp = input_lengths.to("xpu") |
18 |
| - target_lengths_dpcpp = target_lengths.to("xpu") |
19 |
| - |
20 |
| - # warm up |
| 11 | +def _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, backward): |
21 | 12 | loss_dpcpp = torch.nn.functional.ctc_loss(
|
22 |
| - log_probs_dpcpp, targets_dpcpp, input_lengths_dpcpp, target_lengths_dpcpp |
| 13 | + log_probs, targets, input_lengths, target_lengths |
23 | 14 | )
|
24 |
| - loss_dpcpp.backward() |
| 15 | + if backward: |
| 16 | + loss_dpcpp.backward() |
25 | 17 |
|
def run_profile(log_probs, targets, input_lengths, target_lengths, backward, device, num_iter):
    """Profile ``num_iter`` CTC-loss iterations with torch.profiler.

    Prints a per-op summary table sorted by the selected device's total time
    (e.g. "xpu_time_total" / "cuda_time_total" / "cpu_time_total").
    """
    # Only trace the accelerator that matches the requested device. The
    # original unconditionally added XPU-or-CUDA, asking the profiler to
    # trace an accelerator even when device == "cpu".
    activities = [ProfilerActivity.CPU]
    if device == "xpu":
        activities.append(ProfilerActivity.XPU)
    elif device == "cuda":
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, record_shapes=True) as prof:
        for _ in range(num_iter):
            _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, backward)
    print(prof.key_averages().table(sort_by="{}_time_total".format(device)))
47 | 27 |
|
def run_e2e(log_probs, targets, input_lengths, target_lengths, backward, device, num_iter):
    """Measure and print average wall-clock (end-to-end) time per iteration.

    Synchronizes the accelerator before starting and after finishing the
    timed loop so the measurement covers all queued kernels.
    """
    def _sync():
        # The original used a conditional expression purely for its side
        # effects (`a() if cond else b()`); plain ifs behave identically
        # and read as statements.
        if device == "xpu":
            torch.xpu.synchronize()
        elif device == "cuda":
            torch.cuda.synchronize()

    _sync()
    t1 = time.time()
    for _ in range(num_iter):
        _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, backward)
    _sync()
    t2 = time.time()
    e2e_time = (t2 - t1) / num_iter
    print("E2E total time:", f"{float(e2e_time):.20f}")
|
63 | 39 |
|
def benchmark(args):
    """Run the CTC-loss benchmark over every configuration in ``shape_list``.

    Args:
        args: parsed CLI namespace with ``device``, ``num_iter``,
            ``profile_only`` and ``e2e_only`` attributes.
    """
    for shape in shape_list:
        for dtype in [torch.float32]:
            T, N, C, S = shape
            # Seed once per configuration for reproducible inputs.
            # (The original also built `torch.Generator()` objects that were
            # never passed to any op — removed as dead code.)
            torch.manual_seed(15)
            log_probs = (
                torch.randn(T, N, C, dtype=dtype, device=args.device)
                .log_softmax(2)
                .detach()
            )
            # CTC targets must lie in [1, C): class 0 is the blank and values
            # >= C are invalid. The original drew from [1, N), which produces
            # out-of-range labels whenever N > C (e.g. shape (8, 8, 4, 8)).
            targets = torch.randint(1, C, (N, S), dtype=torch.long, device=args.device)
            input_lengths = torch.full((N,), T, dtype=torch.long, device=args.device)
            target_lengths = torch.randint(1, S, (N,), dtype=torch.long, device=args.device)

            # Grad is only needed when the backward pass is benchmarked; the
            # original enabled it unconditionally and then again when
            # `backward` was set.
            if backward:
                log_probs.requires_grad_(True)

            # Warm up once (kernel compilation, allocator caches).
            _test_loss_ctc(log_probs, targets, input_lengths, target_lengths, backward)
            print(
                "shape:",
                shape,
                "; datatype:",
                dtype,
                "; backward:",
                backward,
            )
            if not args.e2e_only:
                run_profile(log_probs, targets, input_lengths, target_lengths,
                            backward, args.device, args.num_iter)
            if not args.profile_only:
                run_e2e(log_probs, targets, input_lengths, target_lengths,
                        backward, args.device, args.num_iter)
def parse_args():
    """Parse command-line options for the CTC-loss benchmark."""
    parser = argparse.ArgumentParser(description='OP Benchmark')
    parser.add_argument('--device', type=str, default='xpu',
                        help='Device to run on (e.g., "cpu", "cuda", "xpu")')
    # --profile-only and --e2e-only are mutually exclusive; neither set
    # means both timing modes run.
    timing_mode = parser.add_mutually_exclusive_group()
    timing_mode.add_argument('--profile-only', action='store_true',
                             help='Only Run profile timing')
    timing_mode.add_argument('--e2e-only', action='store_true',
                             help='Only Run E2E timing')
    parser.add_argument('--num-iter', type=int, default=20,
                        help='Number of iterations')
    return parser.parse_args()
64 | 89 |
|
65 |
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the benchmark sweep.
    benchmark(parse_args())