Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions src/aiconfigurator/sdk/perf_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5700,12 +5700,11 @@ def get_sol(
Get the SOL time for All2All communication.

All2All transfers token data between GPUs:
- dispatch: each GPU sends (num_tokens * topk / ep_size) tokens to other GPUs
- combine: reverse direction
- prepare: lightweight metadata exchange

The total data transferred per GPU is proportional to
num_tokens * topk * hidden_size * (1 - 1/ep_size), since each GPU keeps 1/ep_size locally.
- prepare: lightweight metadata exchange (topk * 4 bytes per token)
- dispatch: each token sent once per unique remote rank (deduplication).
remote_ranks = min(topk, ep_size) - 1, bytes use quant_mode precision.
- combine: each remote expert returns one result in bfloat16.
remote_ranks = min(topk, ep_size) - 1, bytes always use 2 (bf16).
"""
is_inter_node = node_num > 1

Expand All @@ -5714,23 +5713,25 @@ def get_sol(
else:
bw = self.system_spec["node"]["intra_node_bw"]

remote_ranks = min(topk, moe_ep_size) - 1

if op_name == "alltoall_prepare":
data_bytes = num_tokens * topk * 4 # token routing indices, ~4 bytes per entry
elif "combine" in op_name:
# combine: results returned in bfloat16 regardless of quant mode
data_bytes = num_tokens * remote_ranks * hidden_size * 2
else:
# dispatch/combine: transfer token activations
# dispatch: per-rank deduplication, use quant_mode precision
data_bytes = (
num_tokens
* topk
* remote_ranks
* hidden_size
* quant_mode.value.memory
* (1.0 - 1.0 / moe_ep_size) # fraction sent to remote GPUs
)

sol_comm = data_bytes / bw * 1000 # ms
p2p_latency_ms = self.system_spec["node"]["p2p_latency"] * 1000
sol_overhead = p2p_latency_ms
sol_time = sol_comm + sol_overhead
return sol_time, sol_comm, sol_overhead
sol_time = sol_comm
return sol_time, sol_comm, 0.0

def get_empirical_from_sol(
num_tokens: int,
Expand Down
Git LFS file not shown
165 changes: 164 additions & 1 deletion tools/sanity_check/validate_database.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,169 @@
"for op in op_list:\n",
" visualize_nccl(database, operation=op)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def visualize_trtllm_alltoall(database):\n",
" \"\"\"Visualize TRT-LLM AlltoAll communication latency and sol%.\n",
"\n",
" Layout: rows = op phases (prepare/dispatch/combine/combine_lp) with latency and sol%,\n",
" cols = ep_sizes, lines = moe_dtype (quant mode).\n",
"\n",
" \"\"\"\n",
" alltoall_data = database._trtllm_alltoall_data\n",
" if not alltoall_data.loaded:\n",
" print(\"No trtllm alltoall data available (data not loaded)\")\n",
" return\n",
" if not alltoall_data:\n",
" print(\"No trtllm alltoall data available (empty)\")\n",
" return\n",
"\n",
" color_list = [\n",
" \"red\",\n",
" \"blue\",\n",
" \"green\",\n",
" \"orange\",\n",
" \"purple\",\n",
" \"brown\",\n",
" \"pink\",\n",
" \"gray\",\n",
" \"olive\",\n",
" \"cyan\",\n",
" ]\n",
"\n",
" # Map kernel_source -> moe_backend for query API\n",
" kernel_to_backend = {\n",
" \"NVLinkTwoSided\": \"wideep\",\n",
" \"NVLinkOneSided\": None,\n",
" }\n",
"\n",
" for kernel_source in alltoall_data:\n",
" if kernel_source not in kernel_to_backend:\n",
" continue\n",
" moe_backend = kernel_to_backend[kernel_source]\n",
" kernel_data = alltoall_data[kernel_source]\n",
"\n",
" op_names_set = set()\n",
" ep_sizes_set = set()\n",
" quant_modes_set = set()\n",
"\n",
" for op_name in kernel_data:\n",
" op_names_set.add(op_name)\n",
" for qm in kernel_data[op_name]:\n",
" quant_modes_set.add(qm)\n",
" for nn in kernel_data[op_name][qm]:\n",
" for hs in kernel_data[op_name][qm][nn]:\n",
" for tk in kernel_data[op_name][qm][nn][hs]:\n",
" for ne in kernel_data[op_name][qm][nn][hs][tk]:\n",
" for ep in kernel_data[op_name][qm][nn][hs][tk][ne]:\n",
" ep_sizes_set.add(ep)\n",
"\n",
" op_order = [\n",
" \"alltoall_prepare\",\n",
" \"alltoall_dispatch\",\n",
" \"alltoall_combine\",\n",
" \"alltoall_combine_low_precision\",\n",
" ]\n",
" op_names = [op for op in op_order if op in op_names_set]\n",
" ep_sizes = sorted(ep_sizes_set)\n",
" quant_modes = sorted(quant_modes_set, key=str)\n",
"\n",
" if not op_names or not ep_sizes:\n",
" continue\n",
"\n",
" sol_ops = {\"alltoall_dispatch\", \"alltoall_combine\", \"alltoall_combine_low_precision\"}\n",
" rows = []\n",
" for op in op_names:\n",
" rows.append((op, \"latency\"))\n",
" if op in sol_ops:\n",
" rows.append((op, \"sol %\"))\n",
"\n",
" n_rows = len(rows)\n",
" n_cols = max(len(ep_sizes), 1)\n",
" fig, axes = plt.subplots(\n",
" n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False,\n",
" )\n",
" fig.suptitle(\n",
" f\"{database.system.upper()} - {database.backend.upper()} {database.version}\"\n",
" f\" - TRT-LLM AlltoAll Sanity Chart ({kernel_source})\",\n",
" fontsize=14,\n",
" fontweight=\"bold\",\n",
" )\n",
"\n",
" for ri, (op_name, metric) in enumerate(rows):\n",
" for ci, target_ep in enumerate(ep_sizes):\n",
" ax = axes[ri][ci]\n",
" cid = 0\n",
" for qm in quant_modes:\n",
" # Discover available num_tokens for this (qm, op, ep) combo\n",
" token_set, hs, tk, ne_val, nn_val = set(), 0, 0, 0, 0\n",
" if qm in kernel_data.get(op_name, {}):\n",
" for nn in kernel_data[op_name][qm]:\n",
" for h in kernel_data[op_name][qm][nn]:\n",
" for t in kernel_data[op_name][qm][nn][h]:\n",
" for ne in kernel_data[op_name][qm][nn][h][t]:\n",
" if target_ep not in kernel_data[op_name][qm][nn][h][t][ne]:\n",
" continue\n",
" hs, tk, ne_val, nn_val = h, t, ne, nn\n",
" token_set.update(kernel_data[op_name][qm][nn][h][t][ne][target_ep].keys())\n",
"\n",
" if not token_set:\n",
" cid += 1\n",
" continue\n",
"\n",
" tokens = sorted(token_set)\n",
" label = qm.name if hasattr(qm, \"name\") else str(qm)\n",
"\n",
" query_kwargs = dict(\n",
" op_name=op_name, hidden_size=hs, topk=tk,\n",
" num_experts=ne_val, moe_ep_size=target_ep,\n",
" quant_mode=qm, node_num=nn_val, moe_backend=moe_backend,\n",
" )\n",
"\n",
" if metric == \"latency\":\n",
" vals = [\n",
" float(database.query_trtllm_alltoall(\n",
" num_tokens=nt, database_mode=DatabaseMode.SILICON, **query_kwargs,\n",
" ))\n",
" for nt in tokens\n",
" ]\n",
" else:\n",
" vals = []\n",
" for nt in tokens:\n",
" sol_time = database.query_trtllm_alltoall(\n",
" num_tokens=nt, database_mode=DatabaseMode.SOL_FULL, **query_kwargs,\n",
" )[0]\n",
" db_time = database.query_trtllm_alltoall(\n",
" num_tokens=nt, database_mode=DatabaseMode.SILICON, **query_kwargs,\n",
" )\n",
" vals.append(sol_time / db_time if db_time > 0 else 0)\n",
"\n",
" ax.plot(\n",
" tokens,\n",
" vals,\n",
" color=color_list[cid % len(color_list)],\n",
" label=label,\n",
" marker=\".\",\n",
" markersize=3,\n",
" )\n",
" cid += 1\n",
"\n",
" ax.set_title(f\"{op_name}\\nep_size={target_ep}\")\n",
" ax.set_xlabel(\"num_tokens\")\n",
" ax.set_ylabel(metric)\n",
" ax.legend(fontsize=\"small\")\n",
"\n",
" plt.tight_layout(rect=[0, 0, 1, 0.96])\n",
" plt.show()\n",
"\n",
"\n",
"visualize_trtllm_alltoall(database)"
]
}
],
"metadata": {
Expand All @@ -1028,7 +1191,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
Loading