
Commit 4a5ed8b

feat(engine): TP on mac by RDMA over Thunderbolt 5 (#378)
1 parent d32b885 commit 4a5ed8b

File tree

16 files changed: +564 -230 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -11,3 +11,4 @@ build/
 *.key
 .cache
 .vscode/
+hosts.json

scripts/generate.py

Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
"""
Simple offline inference script

Example command:

single node:
python scripts/generate.py

tensor parallel:
https://ml-explore.github.io/mlx/build/html/usage/distributed.html#enabling-rdma

mlx.distributed_config --verbose \
    --hosts macmini1,macmini2 \
    --over thunderbolt --backend jaccl \
    --auto-setup --output hosts.json

mlx.launch \
    --backend jaccl \
    --env MLX_METAL_FAST_SYNCH=1 \
    --hostfile hosts.json \
    scripts/generate.py
"""

import argparse
import time

import mlx.core as mx

from parallax.server.cache_manager import CacheManager
from parallax.server.request import InitialRequest
from parallax.server.sampling.sampler import SamplingBatchInfo
from parallax.server.sampling.sampling_params import SamplingParams
from parallax.server.shard_loader import MLXModelLoader
from parallax.utils.utils import create_causal_mask, get_layer_types

tp_size = 1
tp_rank = 0


def print_rank(message):
    if tp_size == 1:
        print(message)
    else:
        print(f"[Rank {tp_rank}] {message}")


def main():
    parser = argparse.ArgumentParser(description="Simple offline inference script")
    parser.add_argument(
        "--model", type=str, default="Qwen/Qwen3-32B-MLX-4bit", help="Model path or HF repo"
    )
    parser.add_argument("--prompt", type=str, default="Hi", help="Prompt for inference")
    parser.add_argument(
        "--max-tokens", type=int, default=1024, help="Maximum number of tokens to generate"
    )
    parser.add_argument("--topk", type=int, default=1, help="Top-k sampling parameter")
    parser.add_argument("--temp", type=float, default=1.0, help="Temperature for sampling")
    args = parser.parse_args()

    # TP Initialization
    global tp_size, tp_rank
    group = mx.distributed.init()
    tp_rank = group.rank()
    tp_size = group.size()

    mx.set_wired_limit(mx.metal.device_info()["max_recommended_working_set_size"])

    # 1. Load Model
    print_rank(f"Loading model from {args.model}...")

    loader = MLXModelLoader(
        args.model,
    )
    model, config, tokenizer = loader.load()

    # 2. Initialize CacheManager
    num_layers = config.get("num_hidden_layers")
    num_kv_heads = config.get("num_key_value_heads")
    head_dim = config.get("head_dim") or config.get("hidden_size") // config.get(
        "num_attention_heads"
    )

    # Check for DeepSeek style head dims
    qk_nope_head_dim = config.get("qk_nope_head_dim")
    qk_rope_head_dim = config.get("qk_rope_head_dim")
    if qk_nope_head_dim is not None and qk_rope_head_dim is not None:
        head_dim = qk_nope_head_dim + qk_rope_head_dim

    v_head_dim = config.get("v_head_dim")
    layer_types = get_layer_types(config, 0, num_layers)

    cache_manager = CacheManager(
        num_layers=num_layers,
        num_kv_heads=num_kv_heads // tp_size,  # Shard heads
        head_dim=head_dim,
        dtype=model.dtype,
        block_size=32,
        cache_memory_fraction=0.1,
        head_dim_v=v_head_dim,
        layer_types=layer_types,
    )

    # 3. Tokenize and Create Request
    messages = [{"role": "user", "content": args.prompt}]

    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
        full_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        full_prompt = args.prompt

    prompt_tokens = tokenizer.encode(full_prompt)
    sampling_params = SamplingParams(temperature=args.temp, top_k=args.topk)
    request = InitialRequest(
        prompt=full_prompt,
        input_ids=prompt_tokens,
        sampling_params=sampling_params,
        max_new_tokens=args.max_tokens,
    )

    eos_token_ids = []
    if tokenizer.eos_token_id is not None:
        if isinstance(tokenizer.eos_token_id, list):
            eos_token_ids.extend(tokenizer.eos_token_id)
        else:
            eos_token_ids.append(tokenizer.eos_token_id)
    config_eos = config.get("eos_token_id")
    if config_eos is not None:
        if isinstance(config_eos, list):
            for e in config_eos:
                if e not in eos_token_ids:
                    eos_token_ids.append(e)
        elif config_eos not in eos_token_ids:
            eos_token_ids.append(config_eos)

    eos_token_ids = set(eos_token_ids)

    # 4. Prefill
    print_rank(f"Full prompt:\n {full_prompt}")

    if tp_size > 1:
        mx.eval(mx.distributed.all_sum(mx.ones(1)))
        print_rank("Forced sync before prefill")

    success, _ = cache_manager.allocate_request(request.request_id, request.prompt_len)
    if not success:
        print_rank("Failed to allocate cache")
        return

    input_ids = mx.array([request.input_ids])
    block_table = mx.array([cache_manager.get_block_table(request.request_id)], dtype=mx.int32)
    context_lengths = mx.array([request.prompt_len], dtype=mx.int32)

    block_size = cache_manager.block_size
    slot_mapping = []
    for i in range(request.prompt_len):
        block_idx = i // block_size
        block_offset = i % block_size
        physical_block = cache_manager.get_block_table(request.request_id)[block_idx]
        slot_mapping.append(physical_block * block_size + block_offset)
    slot_mapping = mx.array(slot_mapping, dtype=mx.int64)

    mask = create_causal_mask(request.prompt_len, request.prompt_len, model.dtype)

    prefill_start = time.perf_counter()

    logits = model(
        input_ids,
        cache=cache_manager.get_caches(),
        mask=mask,
        block_tables=block_table,
        context_lengths=context_lengths,
        slot_mapping=slot_mapping,
    )

    sampling_info = SamplingBatchInfo.from_reqs([request])

    next_token_id = model.logits_to_tokens(logits, context_lengths, sampling_info)

    token_id = int(next_token_id[0])
    request.commit_new_token(token_id)

    prefill_time = time.perf_counter() - prefill_start
    print_rank(f"Token 1 (Prefill) time: {prefill_time * 1000:.2f} ms")

    # 5. Decode Loop
    total_decode_time = 0
    for i in range(args.max_tokens - 1):
        decode_step_start = time.perf_counter()

        success = cache_manager.append_slot(request.request_id)
        if not success:
            print_rank("\nOOM during decoding")
            break

        block_table = mx.array([cache_manager.get_block_table(request.request_id)], dtype=mx.int32)
        context_lengths = mx.array(
            [cache_manager.get_context_length(request.request_id)], dtype=mx.int32
        )
        logits = model(
            mx.expand_dims(next_token_id, axis=0),
            cache=cache_manager.get_caches(),
            mask=None,
            block_tables=block_table,
            context_lengths=context_lengths,
        )

        next_token_id = model.logits_to_tokens(logits, mx.array([1]), sampling_info)

        token_id = int(next_token_id[0])
        if token_id in eos_token_ids:
            break
        request.commit_new_token(token_id)

        decode_step_time = time.perf_counter() - decode_step_start
        total_decode_time += decode_step_time
        print_rank(f"Token {i + 2} time: {decode_step_time * 1000:.2f} ms")

    print_rank("\nGenerated Content:")
    print_rank(tokenizer.decode(request.output_ids))

    # Summary Statistics
    prompt_tps = request.prompt_len / prefill_time
    generation_tps = len(request.output_ids) / total_decode_time if total_decode_time > 0 else 0
    peak_mem = mx.get_peak_memory() / 1024**3

    print_rank("-" * 20)
    print_rank(f"Prompt: {request.prompt_len} tokens, {prompt_tps:.3f} tokens-per-sec")
    print_rank(f"Generation: {len(request.output_ids)} tokens, {generation_tps:.3f} tokens-per-sec")
    print_rank(f"Peak memory: {peak_mem:.3f} GB")
    cache_manager.free_request(request.request_id)


if __name__ == "__main__":
    main()
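The slot-mapping loop in the prefill step above is the core of the paged KV-cache addressing: a token's position is split into a logical block index and an in-block offset, the logical block is translated to a physical block through the request's block table, and the slot is physical_block * block_size + offset. A minimal standalone sketch of that arithmetic, with hypothetical block-table contents:

```python
# Standalone illustration of the slot-mapping arithmetic from the prefill loop above.
# The block size matches the script's default; the block-table values are made up.
block_size = 32
block_table = [7, 3]  # logical block 0 -> physical block 7, logical block 1 -> physical block 3


def slot_for(position: int) -> int:
    block_idx = position // block_size    # which logical block the token falls into
    block_offset = position % block_size  # offset inside that block
    return block_table[block_idx] * block_size + block_offset


print(slot_for(5))   # 7 * 32 + 5 = 229
print(slot_for(40))  # 3 * 32 + 8 = 104
```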

src/backend/benchmark/benchmark_serving.py

Lines changed: 33 additions & 25 deletions
@@ -435,7 +435,7 @@ def calculate_metrics(
         total_output=sum(actual_output_lens),
         request_throughput=completed / dur_s,
         request_goodput=good_completed / dur_s,
-        output_throughput=sum(actual_output_lens) / dur_s,
+        output_throughput=np.mean([1.0 / x for x in tpots]),
         total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
         mean_ttft_ms=np.mean(ttfts or 0)
         * 1000,  # ttfts is empty if streaming is not supported by backend

@@ -484,6 +484,7 @@
     ignore_eos: bool,
     goodput_config_dict: Dict[str, float],
     max_concurrency: Optional[int],
+    skip_test: bool = False,
     time_serving: Optional[float] = None,
     report_interval: float = 30.0,
 ):

@@ -492,31 +493,32 @@
     else:
         raise ValueError(f"Unknown backend: {backend}")

-    print("Starting initial single prompt test run...")
-    test_prompt, test_prompt_len, test_output_len, test_mm_content = input_requests[0]
-    if backend != "openai-chat" and test_mm_content is not None:
-        # multi-modal benchmark is only available on OpenAI Chat backend.
-        raise ValueError("Multi-modal content is only supported on 'openai-chat' backend.")
-    test_input = RequestFuncInput(
-        model=model_id,
-        model_name=model_name,
-        prompt=test_prompt,
-        api_url=api_url,
-        prompt_len=test_prompt_len,
-        output_len=test_output_len,
-        logprobs=logprobs,
-        best_of=best_of,
-        multi_modal_content=test_mm_content,
-        ignore_eos=ignore_eos,
-    )
-    test_output = await request_func(request_func_input=test_input)
-    if not test_output.success:
-        raise ValueError(
-            "Initial test run failed - Please make sure benchmark arguments "
-            f"are correctly specified. Error: {test_output.error}"
+    if not skip_test:
+        print("Starting initial single prompt test run...")
+        test_prompt, test_prompt_len, test_output_len, test_mm_content = input_requests[0]
+        if backend != "openai-chat" and test_mm_content is not None:
+            # multi-modal benchmark is only available on OpenAI Chat backend.
+            raise ValueError("Multi-modal content is only supported on 'openai-chat' backend.")
+        test_input = RequestFuncInput(
+            model=model_id,
+            model_name=model_name,
+            prompt=test_prompt,
+            api_url=api_url,
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            logprobs=logprobs,
+            best_of=best_of,
+            multi_modal_content=test_mm_content,
+            ignore_eos=ignore_eos,
         )
-    else:
-        print("Initial test run completed. Starting main benchmark run...")
+        test_output = await request_func(request_func_input=test_input)
+        if not test_output.success:
+            raise ValueError(
+                "Initial test run failed - Please make sure benchmark arguments "
+                f"are correctly specified. Error: {test_output.error}"
+            )
+        else:
+            print("Initial test run completed. Starting main benchmark run...")

     if profile:
         print("Starting profiler...")

@@ -1011,6 +1013,7 @@ def main(args: argparse.Namespace):
         ignore_eos=args.ignore_eos,
         goodput_config_dict=goodput_config_dict,
         max_concurrency=args.max_concurrency,
+        skip_test=args.skip_test,
         time_serving=args.time_serving,
         report_interval=args.report_interval,
     )

@@ -1192,6 +1195,11 @@
         help="Use Torch Profiler. The endpoint must be launched with "
         "VLLM_TORCH_PROFILER_DIR to enable profiler.",
     )
+    parser.add_argument(
+        "--skip-test",
+        action="store_true",
+        help="Skip the initial single prompt test run.",
+    )
     parser.add_argument(
         "--save-result",
         action="store_true",
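Note that the calculate_metrics change above redefines output_throughput: it was total generated tokens divided by benchmark wall-clock time, and is now the mean of each request's reciprocal time-per-output-token, i.e. the average per-request decode speed, which ignores concurrency and idle time. A small numerical sketch of the difference, using made-up values:

```python
import numpy as np

# Hypothetical per-request stats for two requests.
tpots = [0.02, 0.05]            # mean time-per-output-token: 50 tok/s and 20 tok/s
actual_output_lens = [100, 40]  # tokens generated per request
dur_s = 5.0                     # wall-clock benchmark duration in seconds

old_output_throughput = sum(actual_output_lens) / dur_s    # 28.0: aggregate tokens/sec over the run
new_output_throughput = np.mean([1.0 / x for x in tpots])  # 35.0: average per-request decode rate

print(old_output_throughput, new_output_throughput)
```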

src/parallax/launch.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
         max_sequence_length=args.max_sequence_length,
         param_mem_ratio=args.param_mem_ratio,
         kvcache_mem_ratio=args.kvcache_mem_ratio,
-        shared_state=shared_state.dict,  # Pass dict to subprocess
+        shared_state=shared_state.dict,
         log_level=args.log_level,
         conn=conn1,
     )

src/parallax/metal/paged_attention/kernel.py

Lines changed: 2 additions & 2 deletions
@@ -190,8 +190,8 @@ def mk_int(val):
         verbose=False,
     )

-    mx.eval(outputs)
-    return key_cache, value_cache
+    mx.async_eval(outputs)
+    return


 def paged_attention(
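The kernel change above swaps the blocking mx.eval(outputs) for mx.async_eval(outputs) and drops the return value, so the cache-write kernel is dispatched without stalling the host. A minimal sketch of the difference between the two calls (the arrays here are placeholders, not the actual cache tensors):

```python
import mlx.core as mx

a = mx.random.normal((2048, 2048))
b = a @ a  # recorded lazily; nothing is computed yet

mx.async_eval(b)  # enqueue the computation and return immediately
# ... the host can keep building and dispatching more work here ...
mx.eval(b)        # block only when the materialized value is actually needed
```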

src/parallax/models/gpt_oss.py

Lines changed: 25 additions & 0 deletions
@@ -5,6 +5,7 @@
 from typing import Any, List, Optional

 import mlx.core as mx
+from mlx.nn.layers.distributed import shard_inplace, shard_linear
 from mlx_lm.models.base import create_causal_mask, scaled_dot_product_attention
 from mlx_lm.models.gpt_oss import AttentionBlock as MLXGPTOSSAttention
 from mlx_lm.models.gpt_oss import ModelArgs

@@ -158,6 +159,30 @@ def __call__(
         out = h + r
         return out

+    def shard(self):
+        group = mx.distributed.init()
+        N = group.size()
+        r = group.rank()
+        # Shard the self attention
+        self.self_attn.q_proj = shard_linear(self.self_attn.q_proj, "all-to-sharded", group=group)
+        self.self_attn.k_proj = shard_linear(self.self_attn.k_proj, "all-to-sharded", group=group)
+        self.self_attn.v_proj = shard_linear(self.self_attn.v_proj, "all-to-sharded", group=group)
+        self.self_attn.o_proj = shard_linear(self.self_attn.o_proj, "sharded-to-all", group=group)
+        num_attention_heads = self.self_attn.num_attention_heads // N
+        self.self_attn.sinks = self.self_attn.sinks[
+            num_attention_heads * r : num_attention_heads * (r + 1)
+        ]
+        self.self_attn.num_attention_heads = num_attention_heads
+        self.self_attn.num_key_value_heads = self.self_attn.num_key_value_heads // N
+
+        # Shard the MLP
+        shard_inplace(self.mlp.experts.gate_proj, "all-to-sharded", group=group)
+        shard_inplace(self.mlp.experts.up_proj, "all-to-sharded", group=group)
+        shard_inplace(self.mlp.experts.down_proj, "sharded-to-all", group=group)
+        if r > 0:
+            # Zero the down_proj bias on non-zero ranks so the bias is only added once.
+            self.mlp.experts.down_proj.bias = mx.zeros_like(self.mlp.experts.down_proj.bias)
+
     @classmethod
     def get_architecture(cls):
         """Get the architecture name for the block."""
