Skip to content

Commit 807a33e

Browse files
committed
Add simple examples for torch and mlx optimization
1 parent 9644c61 commit 807a33e

File tree

5 files changed

+737
-0
lines changed

5 files changed

+737
-0
lines changed

examples/simple-mlx/evaluate.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import time
2+
import sys
3+
import pathlib
4+
import importlib
5+
import traceback
6+
import mlx.core as mx
7+
import mlx.nn as nn
8+
9+
10+
########################################################
11+
# Baseline
12+
########################################################
13+
class Model(nn.Module):
    """Baseline model: GEMM, divide by two, row-wise sum, then scale.

    Weights are drawn with ``mx.random.normal`` at construction time, so two
    instances built under the same seed hold identical parameters.
    """

    def __init__(self, input_size, hidden_size, scaling_factor):
        super().__init__()
        # Stored as (hidden_size, input_size); applied below as x @ W.T.
        self.weight = mx.random.normal(shape=(hidden_size, input_size))
        self.scaling_factor = scaling_factor

    def __call__(self, x):
        """
        Args:
            x (mx.array): Input tensor of shape (batch_size, input_size).
        Returns:
            mx.array: Output tensor of shape (batch_size, hidden_size).
        """
        out = mx.matmul(x, mx.transpose(self.weight))  # Gemm
        out = out / 2  # Divide
        out = mx.sum(out, axis=1, keepdims=True)  # Sum
        return out * self.scaling_factor  # Scaling
35+
36+
37+
########################################################
38+
# Weco Solution
39+
########################################################
40+
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    """Dynamically load and execute a Python module from a file path.

    Args:
        module_path: Path to the ``.py`` file to load.
        add_to_sys_modules: If True, register the module in ``sys.modules``
            under its stem name so later ``import``s by name resolve to it.

    Returns:
        The executed module object.

    Raises:
        ImportError: If an import spec cannot be created for ``module_path``.
    """
    # ``importlib.util`` is a submodule; a bare ``import importlib`` does not
    # guarantee it is accessible, so import it explicitly here.
    import importlib.util

    path = pathlib.Path(module_path)
    name = path.stem
    spec = importlib.util.spec_from_file_location(name, path)
    # spec (or its loader) is None for unloadable paths; fail with a clear
    # error instead of an AttributeError below.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)
    return mod
50+
51+
52+
########################################################
53+
# Benchmark
54+
########################################################
55+
def get_inputs(B, N):
    """Return a (B, N) float32 array of standard-normal samples.

    No device argument is needed: unlike PyTorch, MLX places arrays on the
    default (Metal) device automatically.
    """
    batch = mx.random.normal(shape=(B, N), dtype=mx.float32)
    return batch
58+
59+
60+
def bench(f, inputs, n_warmup, n_rep):
    """Measure the mean latency of ``f(inputs)`` and return it in milliseconds.

    Runs ``n_warmup`` untimed calls first, then averages ``n_rep`` timed calls.
    """
    # Warm-up: exclude one-time compilation/caching costs from the timing.
    for _ in range(n_warmup):
        mx.eval(f(inputs))  # force computation (MLX evaluates lazily)

    total_s = 0.0
    for _ in range(n_rep):
        # Start each repetition from a cleared Metal allocator cache.
        mx.metal.clear_cache()

        t0 = time.time()
        out = f(inputs)
        mx.eval(out)  # force computation
        mx.synchronize()  # wait until all device work has finished
        total_s += time.time() - t0

    # Mean seconds per rep, converted to milliseconds.
    return total_s / (n_rep * 1e-3)
79+
80+
81+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--solution-path", type=str, required=True)
    args = parser.parse_args()

    # init and input parameters: batch size, input size, hidden size, scale
    B, N, H, S = 128, 10, 20, 1.5

    # Run everything on the Metal GPU.
    mx.set_default_device(mx.gpu)

    # load solution module; abort with a non-zero status on any failure
    try:
        mx.random.seed(0)  # seed so candidate weights match the baseline's
        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
        solution_model = solution_module.Model(N, H, S)
        assert hasattr(solution_model, "__call__")
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        # sys.exit is always available; the exit() builtin comes from the
        # site module and is not guaranteed outside interactive sessions.
        sys.exit(1)

    mx.random.seed(0)  # same seed -> identical random weights as above
    baseline_model = Model(N, H, S)

    # measure correctness: average max-abs-diff over several random inputs
    n_correctness_trials = 10
    max_diff_avg = 0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(B, N)
        baseline_output = baseline_model(inputs)
        optimized_output = solution_model(inputs)
        max_diff = mx.max(mx.abs(optimized_output - baseline_output))
        mx.eval(max_diff)  # Force computation
        max_diff_avg += max_diff.item()  # Convert to Python scalar
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

    # measure performance
    inputs = get_inputs(B, N)
    n_warmup = 100
    n_rep = 500

    # baseline
    t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
    print(f"baseline time: {t_avg_baseline:.2f}ms")

    # optimized
    t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
    print(f"optimized time: {t_avg_optimized:.2f}ms")

    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")

0 commit comments

Comments (0)