Skip to content

Commit 2195892

Browse files
malfetamathewc
authored andcommitted
[MPS] Test bf16 perf of few unary and binary ops (pytorch#150382)
Pull Request resolved: pytorch#150382 Approved by: https://github.com/Skylion007
1 parent b309b30 commit 2195892

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/bench_mps_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def bench_binary(
7272

7373
def main() -> None:
7474
dtypes = [torch.float16, torch.float32]
75+
if torch.backends.mps.is_macos_or_newer(14, 0):
76+
dtypes.append(torch.bfloat16)
7577
# Profile unary ops
7678
rc = []
7779
for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes):
@@ -83,8 +85,8 @@ def main() -> None:
8385
ops = [torch.fmax, torch.add]
8486
for op, dtype in itertools.product(ops, dtypes):
8587
rc.extend(bench_binary(op, dt_a=dtype))
86-
for op in ops:
87-
rc.extend(bench_binary(op, dt_b=torch.float16))
88+
if dtype == torch.float32:
89+
rc.extend(bench_binary(op, dt_b=torch.float16))
8890
Compare(rc).print()
8991

9092

0 commit comments

Comments
 (0)