"""
Simple offline inference script.

Example commands:

Single node:
    python scripts/generate.py

Tensor parallel:
    See https://ml-explore.github.io/mlx/build/html/usage/distributed.html#enabling-rdma

    mlx.distributed_config --verbose \
        --hosts macmini1,macmini2 \
        --over thunderbolt --backend jaccl \
        --auto-setup --output hosts.json

    mlx.launch \
        --backend jaccl \
        --env MLX_METAL_FAST_SYNCH=1 \
        --hostfile hosts.json \
        scripts/generate.py
"""

import argparse
import time

import mlx.core as mx

from parallax.server.cache_manager import CacheManager
from parallax.server.request import InitialRequest
from parallax.server.sampling.sampler import SamplingBatchInfo
from parallax.server.sampling.sampling_params import SamplingParams
from parallax.server.shard_loader import MLXModelLoader
from parallax.utils.utils import create_causal_mask, get_layer_types

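# Tensor-parallel state used by print_rank(); the defaults cover single-process
# runs and are overwritten in main() after mx.distributed.init().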
tp_size = 1
tp_rank = 0


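# Print helper that prefixes messages with the tensor-parallel rank when the
# script runs on more than one process.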
def print_rank(message):
    if tp_size == 1:
        print(message)
    else:
        print(f"[Rank {tp_rank}] {message}")


def main():
    parser = argparse.ArgumentParser(description="Simple offline inference script")
    parser.add_argument(
        "--model", type=str, default="Qwen/Qwen3-32B-MLX-4bit", help="Model path or HF repo"
    )
    parser.add_argument("--prompt", type=str, default="Hi", help="Prompt for inference")
    parser.add_argument(
        "--max-tokens", type=int, default=1024, help="Maximum number of tokens to generate"
    )
    parser.add_argument("--topk", type=int, default=1, help="Top-k sampling parameter")
    parser.add_argument("--temp", type=float, default=1.0, help="Temperature for sampling")
    args = parser.parse_args()

    # TP Initialization
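    # mx.distributed.init() returns the default process group: under mlx.launch
    # it spans the configured hosts, otherwise it is a single-rank group.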
    global tp_size, tp_rank
    group = mx.distributed.init()
    tp_rank = group.rank()
    tp_size = group.size()

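    # Keep Metal buffers wired (resident) up to the device's recommended
    # working-set size to avoid paging during generation.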
    mx.set_wired_limit(mx.metal.device_info()["max_recommended_working_set_size"])

    # 1. Load Model
    print_rank(f"Loading model from {args.model}...")

    loader = MLXModelLoader(args.model)
    model, config, tokenizer = loader.load()

    # 2. Initialize CacheManager
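    # KV-cache geometry is read from the Hugging Face config; head_dim falls
    # back to hidden_size // num_attention_heads when not given explicitly.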
    num_layers = config.get("num_hidden_layers")
    num_kv_heads = config.get("num_key_value_heads")
    head_dim = config.get("head_dim") or config.get("hidden_size") // config.get(
        "num_attention_heads"
    )

    # Check for DeepSeek-style head dims
    qk_nope_head_dim = config.get("qk_nope_head_dim")
    qk_rope_head_dim = config.get("qk_rope_head_dim")
    if qk_nope_head_dim is not None and qk_rope_head_dim is not None:
        head_dim = qk_nope_head_dim + qk_rope_head_dim

    v_head_dim = config.get("v_head_dim")
    layer_types = get_layer_types(config, 0, num_layers)

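    # Block-based (paged-attention style) KV cache: slots are allocated in
    # fixed-size blocks of 32 tokens and KV heads are sharded across TP ranks.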
    cache_manager = CacheManager(
        num_layers=num_layers,
        num_kv_heads=num_kv_heads // tp_size,  # Shard heads
        head_dim=head_dim,
        dtype=model.dtype,
        block_size=32,
        cache_memory_fraction=0.1,
        head_dim_v=v_head_dim,
        layer_types=layer_types,
    )

    # 3. Tokenize and Create Request
    messages = [{"role": "user", "content": args.prompt}]

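    # Use the tokenizer's chat template when one is available; otherwise fall
    # back to the raw prompt string.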
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
        full_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        full_prompt = args.prompt

    prompt_tokens = tokenizer.encode(full_prompt)
    sampling_params = SamplingParams(temperature=args.temp, top_k=args.topk)
    request = InitialRequest(
        prompt=full_prompt,
        input_ids=prompt_tokens,
        sampling_params=sampling_params,
        max_new_tokens=args.max_tokens,
    )

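    # Collect stop tokens from both the tokenizer and the model config; either
    # source may store a single id or a list of ids.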
    eos_token_ids = []
    if tokenizer.eos_token_id is not None:
        if isinstance(tokenizer.eos_token_id, list):
            eos_token_ids.extend(tokenizer.eos_token_id)
        else:
            eos_token_ids.append(tokenizer.eos_token_id)
    config_eos = config.get("eos_token_id")
    if config_eos is not None:
        if isinstance(config_eos, list):
            for e in config_eos:
                if e not in eos_token_ids:
                    eos_token_ids.append(e)
        elif config_eos not in eos_token_ids:
            eos_token_ids.append(config_eos)

    eos_token_ids = set(eos_token_ids)

    # 4. Prefill
    print_rank(f"Full prompt:\n {full_prompt}")

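    # A tiny all_sum acts as a barrier so every rank reaches this point before
    # the prefill starts.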
    if tp_size > 1:
        mx.eval(mx.distributed.all_sum(mx.ones(1)))
        print_rank("Forced sync before prefill")

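    # Reserve enough cache blocks to hold the entire prompt before prefill.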
    success, _ = cache_manager.allocate_request(request.request_id, request.prompt_len)
    if not success:
        print_rank("Failed to allocate cache")
        return

    input_ids = mx.array([request.input_ids])
    block_table = mx.array([cache_manager.get_block_table(request.request_id)], dtype=mx.int32)
    context_lengths = mx.array([request.prompt_len], dtype=mx.int32)

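    # Map every prompt position to its physical cache slot:
    # slot = physical_block * block_size + offset_within_block.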
    block_size = cache_manager.block_size
    slot_mapping = []
    for i in range(request.prompt_len):
        block_idx = i // block_size
        block_offset = i % block_size
        physical_block = cache_manager.get_block_table(request.request_id)[block_idx]
        slot_mapping.append(physical_block * block_size + block_offset)
    slot_mapping = mx.array(slot_mapping, dtype=mx.int64)

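    # Dense causal mask over the prompt; decode steps later pass mask=None since
    # each step feeds a single token that may attend to the whole cached context.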
    mask = create_causal_mask(request.prompt_len, request.prompt_len, model.dtype)

    prefill_start = time.perf_counter()

    logits = model(
        input_ids,
        cache=cache_manager.get_caches(),
        mask=mask,
        block_tables=block_table,
        context_lengths=context_lengths,
        slot_mapping=slot_mapping,
    )

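    # Build per-request sampling state and sample the first new token from the
    # prefill logits.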
    sampling_info = SamplingBatchInfo.from_reqs([request])

    next_token_id = model.logits_to_tokens(logits, context_lengths, sampling_info)

    token_id = int(next_token_id[0])
    request.commit_new_token(token_id)

    prefill_time = time.perf_counter() - prefill_start
    print_rank(f"Token 1 (Prefill) time: {prefill_time * 1000:.2f} ms")

    # 5. Decode Loop
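    # Each step reserves one more cache slot, feeds the model only the last
    # sampled token, and stops on EOS or when the token budget runs out.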
    total_decode_time = 0
    for i in range(args.max_tokens - 1):
        decode_step_start = time.perf_counter()

        success = cache_manager.append_slot(request.request_id)
        if not success:
            print_rank("\nOOM during decoding")
            break

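        # Refresh the block table and context length to account for the newly
        # appended slot, then run a single-token decode step.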
        block_table = mx.array([cache_manager.get_block_table(request.request_id)], dtype=mx.int32)
        context_lengths = mx.array(
            [cache_manager.get_context_length(request.request_id)], dtype=mx.int32
        )
        logits = model(
            mx.expand_dims(next_token_id, axis=0),
            cache=cache_manager.get_caches(),
            mask=None,
            block_tables=block_table,
            context_lengths=context_lengths,
        )

        next_token_id = model.logits_to_tokens(logits, mx.array([1]), sampling_info)

        token_id = int(next_token_id[0])
        if token_id in eos_token_ids:
            break
        request.commit_new_token(token_id)

        decode_step_time = time.perf_counter() - decode_step_start
        total_decode_time += decode_step_time
        print_rank(f"Token {i + 2} time: {decode_step_time * 1000:.2f} ms")

    print_rank("\nGenerated Content:")
    print_rank(tokenizer.decode(request.output_ids))

    # Summary Statistics
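    # Prefill throughput = prompt tokens / prefill time; decode throughput =
    # generated tokens / total decode time.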
    prompt_tps = request.prompt_len / prefill_time
    generation_tps = len(request.output_ids) / total_decode_time if total_decode_time > 0 else 0
    peak_mem = mx.get_peak_memory() / 1024**3

    print_rank("-" * 20)
    print_rank(f"Prompt: {request.prompt_len} tokens, {prompt_tps:.3f} tokens-per-sec")
    print_rank(f"Generation: {len(request.output_ids)} tokens, {generation_tps:.3f} tokens-per-sec")
    print_rank(f"Peak memory: {peak_mem:.3f} GB")
    cache_manager.free_request(request.request_id)


if __name__ == "__main__":
    main()