Skip to content

Commit 3e53432

Browse files
authored
[backends] add mojo matmul (meta-pytorch#461)
1 parent b39e28e commit 3e53432

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

benchmarks/mojo_matmul/run.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""
2+
Benchmark mojo_matmul with modular nightly.
3+
To install modular nightly:
4+
pip install --pre modular --index-url https://dl.modular.com/public/nightly/python/simple/
5+
"""
6+
7+
import argparse
8+
import json
9+
import logging
10+
import os
11+
import sys
12+
13+
from os.path import abspath, exists
14+
from typing import Dict, List
15+
16+
17+
def setup_tritonbench_cwd():
18+
original_dir = abspath(os.getcwd())
19+
20+
for tritonbench_dir in (
21+
".",
22+
"../../../tritonbench",
23+
):
24+
if exists(tritonbench_dir):
25+
break
26+
27+
if exists(tritonbench_dir):
28+
tritonbench_dir = abspath(tritonbench_dir)
29+
os.chdir(tritonbench_dir)
30+
sys.path.append(tritonbench_dir)
31+
return original_dir
32+
33+
setup_tritonbench_cwd()
34+
35+
import torch
36+
import max.graph as mg
37+
38+
from max import engine, driver
39+
from max.graph import TensorValue, ops, DeviceRef, TensorType, Graph
40+
from max.graph.type import Shape, ShapeLike, DType
41+
42+
from tritonbench.operators import load_opbench_by_name
43+
from tritonbench.utils.triton_op import register_benchmark
44+
from tritonbench.utils.parser import get_parser
45+
46+
from typing import Callable
47+
48+
def promote_mojo_tensor_to_fp32(mojo_tensor, dtype):
49+
input_type = TensorType(dtype=dtype, shape=mojo_tensor.shape, device=DeviceRef.GPU())
50+
with mg.Graph("mojo_to_fp32", input_types=(input_type, )) as g:
51+
(inp, ) = g.inputs
52+
out = ops.cast(inp, dtype=DType.float32)
53+
g.output(out)
54+
session = engine.InferenceSession(devices=[driver.Accelerator()])
55+
model = session.load(g)
56+
output = model.execute(mojo_tensor)
57+
return output
58+
59+
def demote_numpy_to_mojo_tensor_dtype(numpy_array, dtype):
60+
with mg.Graph("mojo_to_dtype") as g:
61+
inp = ops.constant(numpy_array, dtype=DType.float32, device=DeviceRef.GPU())
62+
out = ops.cast(inp, dtype=dtype)
63+
g.output(out)
64+
session = engine.InferenceSession(devices=[driver.Accelerator()])
65+
model = session.load(g)
66+
output = model.execute()
67+
return output[0]
68+
69+
MOJO_DTYPE_MAPPING = {
70+
"bf16": DType.bfloat16,
71+
"fp32": DType.float32,
72+
"fp16": DType.float16,
73+
}
74+
MOJO_DEVICE_MAPPING = {
75+
"cuda": DeviceRef.GPU,
76+
"cpu": DeviceRef.CPU,
77+
}
78+
MOJO_DRIVER_DEVICE_MAPPING = {
79+
"cuda": driver.Accelerator,
80+
"cpu": driver.CPU,
81+
}
82+
83+
def mojo_matmul(operator, a, b, bias) -> Callable:
84+
precision = operator.precision
85+
device = operator.device
86+
mojo_dtype = MOJO_DTYPE_MAPPING[precision]
87+
mojo_device = MOJO_DEVICE_MAPPING[device]
88+
mojo_driver_device = MOJO_DRIVER_DEVICE_MAPPING[device]
89+
a_numpy = a.cpu().float().numpy()
90+
b_numpy = b.T.cpu().float().numpy()
91+
a_mojo_cuda = driver.Tensor.from_numpy(a_numpy).to(mojo_driver_device())
92+
b_mojo_cuda = driver.Tensor.from_numpy(b_numpy).to(mojo_driver_device())
93+
a_mojo_bf16 = demote_numpy_to_mojo_tensor_dtype(a_numpy, mojo_dtype)
94+
b_mojo_bf16 = demote_numpy_to_mojo_tensor_dtype(b_numpy, mojo_dtype)
95+
input_types = (
96+
TensorType(dtype=mojo_dtype, shape=a_numpy.shape, device=mojo_device()),
97+
TensorType(dtype=mojo_dtype, shape=b_numpy.shape, device=mojo_device()),
98+
)
99+
with mg.Graph("mojo_matmul", input_types=input_types) as g:
100+
a_val, b_val = g.inputs
101+
c_val = ops.matmul(a_val, b_val.T)
102+
g.output(c_val)
103+
session = engine.InferenceSession(devices=[driver.Accelerator()])
104+
model = session.load(g)
105+
outputs = model.execute(a_mojo_bf16, b_mojo_bf16)
106+
output_func = lambda: model.execute(a_mojo_bf16, b_mojo_bf16)
107+
return output_func
108+
109+
if __name__ == "__main__":
110+
args = ["--op", "gemm", "--only", "aten_matmul,mojo_matmul", "--precision", "bf16", "--m", "512", "--n", "8192", "--k", "5376"] + sys.argv[1:]
111+
gemm_opbench_cls = load_opbench_by_name("gemm")
112+
parser = get_parser(args)
113+
tb_args, extra_args = parser.parse_known_args(args)
114+
gemm_opbench = gemm_opbench_cls(tb_args, extra_args)
115+
gemm_opbench.add_benchmark(bm_func_name="mojo_matmul", bm_callable=mojo_matmul)
116+
gemm_opbench.run()
117+
metrics = gemm_opbench.output
118+
print(metrics)
119+
# TODO: promote the output to fp32 for numerics check
120+
# y_torch = torch.from_numpy(promote_mojo_tensor_to_fp32(outputs[0], dtype=DType.bfloat16)[0].to_numpy())

0 commit comments

Comments
 (0)