import time
import torch
import inspect
import itertools

from functorch import pointwise_operator

# Single-threaded for stable, comparable timings; disable fusion-group
# inlining so small fused groups aren't undone and the fuser path is
# actually measured.
torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)
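
# `pointwise_operator` compiles a pure pointwise Python function into a single
# fused kernel. A minimal sketch of the call pattern this script uses below
# (illustrative only, not part of the benchmark):
#
#     def fma(a, b):
#         return a * b + b
#     fused_fma = pointwise_operator(fma)
#     fused_fma(torch.rand(32), torch.rand(32))  # one fused kernel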

def rand(*shape):
    # Uniform values in [1, 17): strictly positive and bounded away from zero
    # so log and division stay well-behaved.
    return torch.rand(*shape).mul(16).add(1)


# ------------------------------------------------------------------------------
# Shape test cases
# ------------------------------------------------------------------------------
def scalar():
    return (rand(1), rand(1))

def small():
    return (rand(32), rand(32))

def small_2d():
    return (rand(1, 32), rand(1, 32))

def small_broadcast():
    return (rand(4, 32), rand(32))

def medium():
    return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))

def medium_sliced():
    return (rand(32, 12, 64, 64)[..., ::2],
            rand(32, 12, 64, 64)[..., ::2])

def medium_transpose():
    return (rand(32, 12, 64, 64).transpose(-1, -2),
            rand(32, 12, 64, 64).transpose(-1, -2))

def medium_channels_last():
    return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
            rand(32, 3, 224, 224).to(memory_format=torch.channels_last))
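
# channels_last keeps the logical NCHW shape but lays the data out as NHWC in
# memory, so the channel dimension gets stride 1. A quick illustrative check
# (not part of the benchmark):
#
#     t = rand(32, 3, 224, 224).to(memory_format=torch.channels_last)
#     assert t.stride()[1] == 1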

def medium_broadcast():
    return (rand(32, 12, 64, 64), rand(64))

def medium_broadcast_channels_last():
    return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last),
            rand(3, 1, 1))

def large():
    return (rand(8192, 8192), rand(8192, 8192))

def large_transpose():
    return (rand(8192, 8192).transpose(0, 1),
            rand(8192, 8192).transpose(0, 1))

def pathological_broadcast():
    return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))
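
# Why "pathological": broadcasting (1, 32, 32, 2) against (1024, 1, 1, 2)
# yields a (1024, 32, 32, 2) output of ~2M elements, roughly 1000x larger than
# either 2048-element input, so almost all of the work is broadcast expansion.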

# ------------------------------------------------------------------------------
# Operator test cases
# ------------------------------------------------------------------------------
def add(a, b):
    return a + b

def sub(a, b):
    return a - b

def mul(a, b):
    return a * b

def div(a, b):
    return a / b

def relu(a):
    return a.relu()

def sigmoid(a):
    return a.sigmoid()

def tanh(a):
    return a.tanh()

def log(a):
    return a.log()

def exp(a):
    return a.exp()

def pow(a):  # shadows the builtin; fine for this standalone script
    return a ** 2

def fma(a, b):
    # Fused-multiply-add-shaped expression.
    return a * b + b

def hardswish(a):
    # Manual decomposition: x * relu6(x + 3) / 6.
    return a * (a + 3).clamp(0, 6) / 6

def native_hardswish(a):
    # The ATen kernel, reached through a private binding.
    return torch._C._nn.hardswish(a)

def softplus(a):
    # softplus(x) = log1p(exp(beta * x)) / beta, written out with beta = 1.
    return (a * 1.0).exp().log1p() / 1.0

def mish(a):
    # mish(x) = x * tanh(softplus(x)).
    return a * ((a * 1.0).exp().log1p() / 1.0).tanh()

shapes = [
    scalar,
    small,
    small_2d,
    small_broadcast,
    medium,
    medium_sliced,
    medium_transpose,
    medium_channels_last,
    medium_broadcast,
    medium_broadcast_channels_last,
    large,
    large_transpose,
    pathological_broadcast,
]

operators = [
    add,
    sub,
    mul,
    div,
    relu,
    sigmoid,
    tanh,
    log,
    exp,
    pow,
    fma,
    hardswish,
    native_hardswish,
]
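
# NOTE: softplus and mish are defined above but not included in `operators`;
# appending them here would add them to the sweep.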
# Debug overrides: uncomment a pair to narrow the sweep to a single case.
#shapes = [large_transpose]
#operators = [add]
#shapes = [scalar]
#operators = [add]

# Pre-flight correctness check: any (operator, shape) combination where
# pointwise_operator fails (or is known to hang) is recorded in `nope` so the
# timing loop below can report NaN for it instead of crashing.
nope = set()
for shape, operator in itertools.product(shapes, operators):
    nargs = len(inspect.signature(operator).parameters)
    args = shape()[:nargs]

    try:
        if shape == medium_transpose:
            raise RuntimeError("pointwise_operator hangs on medium_transpose")
        pw_op = pointwise_operator(operator)
        torch.testing.assert_allclose(operator(*args), pw_op(*args))
    except Exception:
        print(f"pointwise_operator failed on {operator.__name__}, {shape.__name__}")
        nope.add((operator, shape))

    ts_op = torch.jit.script(operator)
    torch.testing.assert_allclose(operator(*args), ts_op(*args))

def time_cpu(fn, args, iters):
    # Wall-clock timing for CPU runs.
    s = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    e = time.perf_counter()
    return e - s

def time_cuda(fn, args, iters):
    # CUDA-event timing; elapsed_time() reports milliseconds, so divide by 1e3
    # to return seconds, matching time_cpu.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1e3
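
# An alternative to these hand-rolled timers (not used here) is
# torch.utils.benchmark.Timer, which handles warmup and adaptive iteration
# counts itself; a minimal sketch:
#
#     from torch.utils.benchmark import Timer
#     t = Timer(stmt="fn(*args)", globals={"fn": add, "args": small()})
#     print(t.blocked_autorange().median)  # seconds per call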

def benchmark_with_timer(fn, args, timer):
    timer(fn, args, 3)  # warmup
    calibration = timer(fn, args, 1)
    # Target roughly one second of total measurement. Guard against a zero or
    # longer-than-a-second calibration run, which would otherwise give a
    # ZeroDivisionError or iters == 0.
    iters = max(int(1.0 / max(calibration, 1e-9)), 1)
    return timer(fn, args, iters) / iters

def benchmark(fn, args):
    # Pick the timer that matches the inputs' device.
    timer = time_cpu if args[0].device.type == "cpu" else time_cuda
    return benchmark_with_timer(fn, args, timer)

def micros(s):
    # Seconds -> microseconds, formatted to one decimal place.
    return f"{s * 1e6:.1f}"
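
# Illustrative: benchmark(add, small()) returns seconds per call, and
# micros(1.5e-6) == "1.5".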

results = []
for shape, operator in itertools.product(shapes, operators):
    nargs = len(inspect.signature(operator).parameters)
    args = shape()[:nargs]

    # Eager baseline.
    result = benchmark(operator, args)
    print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))

    # pointwise_operator: combinations that hang or failed the pre-flight
    # check above are reported as NaN.
    try:
        if shape == medium_transpose:
            raise RuntimeError("pointwise_operator hangs on medium_transpose")
        if (operator, shape) in nope:
            raise RuntimeError("pointwise_operator failed the correctness check on this combination")
        pw_op = pointwise_operator(operator)
        result = benchmark(pw_op, args)
        print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
    except Exception:
        print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))]))

    # TorchScript, exercising the fuser.
    ts_op = torch.jit.script(operator)
    result = benchmark(ts_op, args)
    print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))

# TODO: additional sweeps:
#   - cpu
#   - parallel cpu
#   - cuda
#   - casts
#   - inplace?
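
# A possible CUDA sweep for the TODO above (untested sketch): `benchmark`
# already dispatches to time_cuda based on the inputs' device, so moving the
# inputs to the GPU is enough:
#
#     if torch.cuda.is_available():
#         for shape, operator in itertools.product(shapes, operators):
#             nargs = len(inspect.signature(operator).parameters)
#             args = [t.cuda() for t in shape()[:nargs]]
#             print(",".join(["eager", "cuda", operator.__name__,
#                             shape.__name__, micros(benchmark(operator, args))]))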