|
| 1 | +import time |
| 2 | +import sys |
| 3 | +import pathlib |
| 4 | +import importlib |
| 5 | +import traceback |
| 6 | +import mlx.core as mx |
| 7 | +import mlx.nn as nn |
| 8 | + |
| 9 | + |
| 10 | +######################################################## |
| 11 | +# Baseline |
| 12 | +######################################################## |
class Model(nn.Module):
    """
    Baseline model: matrix multiplication (Gemm), division by 2,
    row-wise summation, and scaling, applied in sequence.
    """

    def __init__(self, input_size, hidden_size, scaling_factor):
        super().__init__()
        # Weight stored as (hidden, input); the forward pass multiplies by its transpose.
        self.weight = mx.random.normal(shape=(hidden_size, input_size))
        self.scaling_factor = scaling_factor

    def __call__(self, x):
        """
        Args:
            x (mx.array): Input tensor of shape (batch_size, input_size).
        Returns:
            mx.array: Output tensor of shape (batch_size, 1) — the sum over
                the hidden axis is kept as a single column (keepdims=True).
        """
        x = mx.matmul(x, mx.transpose(self.weight))  # Gemm
        x = x / 2  # Divide
        x = mx.sum(x, axis=1, keepdims=True)  # Sum over hidden axis
        x = x * self.scaling_factor  # Scaling
        return x
| 35 | + |
| 36 | + |
| 37 | +######################################################## |
| 38 | +# Weco Solution |
| 39 | +######################################################## |
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    """Import a Python module from an arbitrary file path.

    Args:
        module_path: Path to a ``.py`` file.
        add_to_sys_modules: If True, register the module in ``sys.modules``
            under its stem name so it can later be re-imported by name.

    Returns:
        The fully executed module object.

    Raises:
        ImportError: If no import spec can be created for the path.
    """
    # `import importlib` alone does not guarantee the submodule is bound;
    # import it explicitly so importlib.util is always available here.
    import importlib.util

    path = pathlib.Path(module_path)
    name = path.stem
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {module_path!r}")
    module = importlib.util.module_from_spec(spec)
    if add_to_sys_modules:
        sys.modules[name] = module
    spec.loader.exec_module(module)
    return module
| 50 | + |
| 51 | + |
| 52 | +######################################################## |
| 53 | +# Benchmark |
| 54 | +######################################################## |
def get_inputs(B, N):
    """Draw a (B, N) batch of standard-normal float32 inputs.

    No device argument is taken: unlike PyTorch, MLX dispatches to the
    default (Metal) device automatically.
    """
    batch_shape = (B, N)
    return mx.random.normal(shape=batch_shape, dtype=mx.float32)
| 58 | + |
| 59 | + |
def bench(f, inputs, n_warmup, n_rep):
    """Return the average wall-clock time of ``f(inputs)`` in milliseconds.

    Runs ``n_warmup`` untimed evaluations (so lazy graphs get compiled),
    then averages ``n_rep`` timed, fully synchronized runs.

    Args:
        f: Callable taking one mx.array argument.
        inputs (mx.array): Input batch passed to ``f`` on every run.
        n_warmup: Number of untimed warm-up iterations.
        n_rep: Number of timed repetitions to average over.

    Returns:
        float: Mean latency per call, in milliseconds.
    """
    # Warm up: force evaluation each time (MLX is lazy).
    for _ in range(n_warmup):
        mx.eval(f(inputs))

    total_seconds = 0.0
    for _ in range(n_rep):
        # Clear cached Metal buffers so each rep starts from a comparable state.
        mx.metal.clear_cache()

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        mx.eval(f(inputs))  # force the lazy computation
        mx.synchronize()  # wait for all device work to complete
        total_seconds += time.perf_counter() - start

    # Average seconds per rep, converted to milliseconds.
    return (total_seconds / n_rep) * 1e3
| 79 | + |
| 80 | + |
| 81 | +if __name__ == "__main__": |
| 82 | + import argparse |
| 83 | + |
| 84 | + parser = argparse.ArgumentParser() |
| 85 | + parser.add_argument("--solution-path", type=str, required=True) |
| 86 | + args = parser.parse_args() |
| 87 | + |
| 88 | + # init and input parameters |
| 89 | + B, N, H, S = 128, 10, 20, 1.5 |
| 90 | + |
| 91 | + # Set the default device to 0 |
| 92 | + mx.set_default_device(mx.gpu) |
| 93 | + |
| 94 | + # load solution module |
| 95 | + try: |
| 96 | + mx.random.seed(0) |
| 97 | + solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False) |
| 98 | + solution_model = solution_module.Model(N, H, S) |
| 99 | + assert hasattr(solution_model, "__call__") |
| 100 | + except Exception: |
| 101 | + print(f"Candidate module initialization failed: {traceback.format_exc()}") |
| 102 | + exit(1) |
| 103 | + |
| 104 | + mx.random.seed(0) |
| 105 | + baseline_model = Model(N, H, S) |
| 106 | + |
| 107 | + # measure correctness |
| 108 | + n_correctness_trials = 10 |
| 109 | + max_diff_avg = 0 |
| 110 | + for _ in range(n_correctness_trials): |
| 111 | + inputs = get_inputs(B, N) |
| 112 | + baseline_output = baseline_model(inputs) |
| 113 | + optimized_output = solution_model(inputs) |
| 114 | + max_diff = mx.max(mx.abs(optimized_output - baseline_output)) |
| 115 | + mx.eval(max_diff) # Force computation |
| 116 | + max_diff_avg += max_diff.item() # Convert to Python scalar |
| 117 | + max_diff_avg /= n_correctness_trials |
| 118 | + print(f"max float diff between values of baseline and optimized model: {max_diff_avg}") |
| 119 | + |
| 120 | + # measure performance |
| 121 | + inputs = get_inputs(B, N) |
| 122 | + n_warmup = 100 |
| 123 | + n_rep = 500 |
| 124 | + |
| 125 | + # baseline |
| 126 | + t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep) |
| 127 | + print(f"baseline time: {t_avg_baseline:.2f}ms") |
| 128 | + |
| 129 | + # optimized |
| 130 | + t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep) |
| 131 | + print(f"optimized time: {t_avg_optimized:.2f}ms") |
| 132 | + |
| 133 | + print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x") |