diff --git a/src/aiconfigurator/sdk/perf_database.py b/src/aiconfigurator/sdk/perf_database.py index ea08e455..e3445e46 100755 --- a/src/aiconfigurator/sdk/perf_database.py +++ b/src/aiconfigurator/sdk/perf_database.py @@ -5700,12 +5700,12 @@ def get_sol( Get the SOL time for All2All communication. All2All transfers token data between GPUs: - - dispatch: each GPU sends (num_tokens * topk / ep_size) tokens to other GPUs - - combine: reverse direction - - prepare: lightweight metadata exchange - - The total data transferred per GPU is proportional to - num_tokens * topk * hidden_size * (1 - 1/ep_size), since each GPU keeps 1/ep_size locally. + - prepare: lightweight metadata exchange (topk * 4 bytes per token) + - dispatch: each token sent once per unique remote rank (deduplication). + remote_ranks = min(topk, num_experts, ep_size - 1), bytes use quant_mode precision. + - combine: standard returns results in bfloat16 (2 B/elem); + low-precision variant returns results in fp4 (0.5 B/elem). + remote_ranks = min(topk, num_experts, ep_size - 1). """ is_inter_node = node_num > 1 @@ -5714,23 +5714,20 @@ def get_sol( else: bw = self.system_spec["node"]["intra_node_bw"] + remote_ranks = min(topk, num_experts, moe_ep_size - 1) + if op_name == "alltoall_prepare": data_bytes = num_tokens * topk * 4 # token routing indices, ~4 bytes per entry + elif "combine" in op_name: + bytes_per_element = 0.5 if "low_precision" in op_name else 2 + data_bytes = num_tokens * remote_ranks * hidden_size * bytes_per_element else: - # dispatch/combine: transfer token activations - data_bytes = ( - num_tokens - * topk - * hidden_size - * quant_mode.value.memory - * (1.0 - 1.0 / moe_ep_size) # fraction sent to remote GPUs - ) + # dispatch: per-rank deduplication, use quant_mode precision + data_bytes = num_tokens * remote_ranks * hidden_size * quant_mode.value.memory sol_comm = data_bytes / bw * 1000 # ms - p2p_latency_ms = self.system_spec["node"]["p2p_latency"] * 1000 - sol_overhead = p2p_latency_ms - sol_time = sol_comm + sol_overhead - return sol_time, sol_comm, sol_overhead + sol_time = sol_comm + return sol_time, sol_comm, 0.0 def get_empirical_from_sol( num_tokens: int, diff --git a/src/aiconfigurator/systems/data/gb200/trtllm/1.2.0rc6/trtllm_alltoall_perf.txt b/src/aiconfigurator/systems/data/gb200/trtllm/1.2.0rc6/trtllm_alltoall_perf.txt index 4e4d7921..4666744e 100644 --- a/src/aiconfigurator/systems/data/gb200/trtllm/1.2.0rc6/trtllm_alltoall_perf.txt +++ b/src/aiconfigurator/systems/data/gb200/trtllm/1.2.0rc6/trtllm_alltoall_perf.txt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f51fcf2451066814d6452f2347415b4891fd40baebc94fb160a637c41d179a1b -size 221416 +oid sha256:6915f42dff1b1e1e22625c11bdd89c575fe2e79333f0f9d1726393019864e933 +size 241498 diff --git a/tools/sanity_check/validate_database.ipynb b/tools/sanity_check/validate_database.ipynb index d14bc198..78ca9f15 100644 --- a/tools/sanity_check/validate_database.ipynb +++ b/tools/sanity_check/validate_database.ipynb @@ -52,16 +52,12 @@ " sol_time, sol_math, sol_mem = database.query_gemm(\n", " m=m, n=n, k=k, quant_mode=quant_mode, database_mode=DatabaseMode.SOL_FULL\n", " )\n", - " db_time = database.query_gemm(\n", - " m=m, n=n, k=k, quant_mode=quant_mode, database_mode=DatabaseMode.SILICON\n", - " )\n", + " db_time = database.query_gemm(m=m, n=n, k=k, quant_mode=quant_mode, database_mode=DatabaseMode.SILICON)\n", " percentage_of_math = sol_math / db_time\n", " percentage_of_mem = sol_mem / db_time\n", " sol_math_list.append(percentage_of_math)\n", " sol_mem_list.append(percentage_of_mem)\n", - " ax[0, i].plot(\n", - " m_list, sol_math_list, color=color_list[color_id], label=f\"{quant_mode} math\"\n", - " )\n", + " ax[0, i].plot(m_list, sol_math_list, color=color_list[color_id], label=f\"{quant_mode} math\")\n", " ax[1, i].plot(\n", " m_list,\n", " sol_mem_list,\n", @@ -206,7 +202,7 @@ " prefix_len = int(s * prefix_factor)\n", " sol_time, sol_math, sol_mem = database.query_context_attention(\n", " b=b,\n", - " s=s-prefix_len,\n", + " s=s - prefix_len,\n", " n=n,\n", " n_kv=n_kv,\n", " kvcache_quant_mode=kvcache_quant_mode,\n", @@ -216,7 +212,7 @@ " )\n", " db_time = database.query_context_attention(\n", " b=b,\n", - " s=s-prefix_len,\n", + " s=s - prefix_len,\n", " n=n,\n", " n_kv=n_kv,\n", " kvcache_quant_mode=kvcache_quant_mode,\n", @@ -393,9 +389,7 @@ " sol_math_list.append(percentage_of_math)\n", " sol_mem_list.append(percentage_of_mem)\n", " ax[0, i].plot(s_list, sol_math_list, color=color_list[color_id], label=f\"b{b} math\")\n", - " ax[1, i].plot(\n", - " s_list, sol_mem_list, color=color_list[color_id], linestyle=\"--\", label=f\"b{b} mem\"\n", - " )\n", + " ax[1, i].plot(s_list, sol_mem_list, color=color_list[color_id], linestyle=\"--\", label=f\"b{b} mem\")\n", " ax[0, i].set_title(f\"kvcache={kvcache_quant_mode} n={n}, n_kv={n_kv}\")\n", " ax[0, i].set_xlabel(\"s\")\n", " ax[0, i].set_ylabel(\"math sol %\")\n", @@ -447,7 +441,7 @@ " prefix_len = int(s * prefix_scale)\n", " sol_time, sol_math, sol_mem = database.query_context_mla(\n", " b=b,\n", - " s=s-prefix_len,\n", + " s=s - prefix_len,\n", " num_heads=n_q,\n", " kvcache_quant_mode=kvcache_quant_mode,\n", " fmha_quant_mode=quant_mode,\n", @@ -456,7 +450,7 @@ " )\n", " db_time = database.query_context_mla(\n", " b=b,\n", - " s=s-prefix_len,\n", + " s=s - prefix_len,\n", " num_heads=n_q,\n", " kvcache_quant_mode=kvcache_quant_mode,\n", " fmha_quant_mode=quant_mode,\n", @@ -626,9 +620,7 @@ " sol_math_list.append(percentage_of_math)\n", " sol_mem_list.append(percentage_of_mem)\n", " ax[0, i].plot(s_list, sol_math_list, color=color_list[color_id], label=f\"b{b} math\")\n", - " ax[1, i].plot(\n", - " s_list, sol_mem_list, color=color_list[color_id], linestyle=\"--\", label=f\"b{b} mem\"\n", - " )\n", + " ax[1, i].plot(s_list, sol_mem_list, color=color_list[color_id], linestyle=\"--\", label=f\"b{b} mem\")\n", " ax[0, i].set_title(f\"kvcache={kvcache_quant_mode} n_q_per_gpu={num_q}\")\n", " ax[0, i].set_xlabel(\"s\")\n", " ax[0, i].set_ylabel(\"math sol %\")\n", @@ -660,8 +652,8 @@ " kv_cache_dtype = common.KVCacheQuantMode.float16\n", " fmha_quant_mode = common.FMHAQuantMode.float16\n", "\n", - " context_s_list = [128, 256, 512, 1024, 2048, 4096, 8192, 16384,65536]\n", - " generation_s_list = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768,65536*4]\n", + " context_s_list = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 65536]\n", + " generation_s_list = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536 * 4]\n", "\n", " fig, ax = plt.subplots(2, 2, figsize=(14, 8))\n", "\n", @@ -757,7 +749,7 @@ " plt.show()\n", "\n", "\n", - "#visualize_dsa_module(database)" + "# visualize_dsa_module(database)" ] }, { @@ -792,7 +784,11 @@ " \"olive\",\n", " \"cyan\",\n", " ]\n", - " fig, ax = plt.subplots(2*len(workload_distributions), len(tp_ep_list), figsize=(5 * len(tp_ep_list), 5 * 2*len(workload_distributions)))\n", + " fig, ax = plt.subplots(\n", + " 2 * len(workload_distributions),\n", + " len(tp_ep_list),\n", + " figsize=(5 * len(tp_ep_list), 5 * 2 * len(workload_distributions)),\n", + " )\n", " for workload_distribution_id, workload_distribution in enumerate(workload_distributions):\n", " for i, (tp, ep) in enumerate(tp_ep_list):\n", " for color_id, quant_mode in enumerate(database._moe_data.keys()):\n", @@ -843,10 +839,10 @@ " sol_math_list.append(percentage_of_math)\n", " sol_mem_list.append(percentage_of_mem)\n", "\n", - " ax[workload_distribution_id*2, i].plot(\n", + " ax[workload_distribution_id * 2, i].plot(\n", " m_list, sol_math_list, color=color_list[color_id], label=f\"{quant_mode} math\"\n", " )\n", - " ax[workload_distribution_id*2+1, i].plot(\n", + " ax[workload_distribution_id * 2 + 1, i].plot(\n", " m_list,\n", " sol_mem_list,\n", " color=color_list[color_id],\n", @@ -858,15 +854,17 @@ " else:\n", " workload_distribution_title = workload_distribution\n", "\n", - " ax[workload_distribution_id*2, i].set_title(f\"{workload_distribution_title} \\ntopk={topk} e={num_experts} tp={tp} ep={ep}\")\n", - " ax[workload_distribution_id*2, i].set_xlabel(\"s\")\n", - " ax[workload_distribution_id*2, i].set_ylabel(\"math sol %\")\n", + " ax[workload_distribution_id * 2, i].set_title(\n", + " f\"{workload_distribution_title} \\ntopk={topk} e={num_experts} tp={tp} ep={ep}\"\n", + " )\n", + " ax[workload_distribution_id * 2, i].set_xlabel(\"s\")\n", + " ax[workload_distribution_id * 2, i].set_ylabel(\"math sol %\")\n", " # ax[0,i].set_ylim(0,1)\n", - " ax[workload_distribution_id*2, i].legend()\n", - " ax[workload_distribution_id*2+1, i].set_xlabel(\"s\")\n", - " ax[workload_distribution_id*2+1, i].set_ylabel(\"mem sol %\")\n", + " ax[workload_distribution_id * 2, i].legend()\n", + " ax[workload_distribution_id * 2 + 1, i].set_xlabel(\"s\")\n", + " ax[workload_distribution_id * 2 + 1, i].set_ylabel(\"mem sol %\")\n", " # ax[1,i].set_ylim(0,1)\n", - " ax[workload_distribution_id*2+1, i].legend()\n", + " ax[workload_distribution_id * 2 + 1, i].legend()\n", " plt.tight_layout()\n", " plt.show()\n", "\n", @@ -990,9 +988,7 @@ " sol_time, sol_math, sol_mem = database.query_nccl(\n", " quant_mode, num_gpu, operation, m, database_mode=DatabaseMode.SOL_FULL\n", " )\n", - " db_time = database.query_nccl(\n", - " quant_mode, num_gpu, operation, m, database_mode=DatabaseMode.SILICON\n", - " )\n", + " db_time = database.query_nccl(quant_mode, num_gpu, operation, m, database_mode=DatabaseMode.SILICON)\n", " percentage_of_sol = sol_time / db_time\n", " sol_list.append(percentage_of_sol)\n", " ax[i].plot(m_list, sol_list, color=color_list[i], label=f\"{num_gpu}\")\n", @@ -1010,6 +1006,191 @@ "for op in op_list:\n", " visualize_nccl(database, operation=op)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def visualize_trtllm_alltoall(database):\n", + " \"\"\"Visualize TRT-LLM AlltoAll communication latency and sol%.\n", + "\n", + " Layout: rows = op phases (prepare/dispatch/combine/combine_lp) with latency and sol%,\n", + " cols = ep_sizes, lines = moe_dtype (quant mode).\n", + "\n", + " \"\"\"\n", + " alltoall_data = database._trtllm_alltoall_data\n", + " if not alltoall_data.loaded:\n", + " print(\"No trtllm alltoall data available (data not loaded)\")\n", + " return\n", + " if not alltoall_data:\n", + " print(\"No trtllm alltoall data available (empty)\")\n", + " return\n", + "\n", + " color_list = [\n", + " \"red\",\n", + " \"blue\",\n", + " \"green\",\n", + " \"orange\",\n", + " \"purple\",\n", + " \"brown\",\n", + " \"pink\",\n", + " \"gray\",\n", + " \"olive\",\n", + " \"cyan\",\n", + " ]\n", + "\n", + " # Map kernel_source -> moe_backend for query API\n", + " kernel_to_backend = {\n", + " \"NVLinkTwoSided\": \"wideep\",\n", + " \"NVLinkOneSided\": None,\n", + " }\n", + "\n", + " for kernel_source in alltoall_data:\n", + " if kernel_source not in kernel_to_backend:\n", + " continue\n", + " moe_backend = kernel_to_backend[kernel_source]\n", + " kernel_data = alltoall_data[kernel_source]\n", + "\n", + " op_names_set = set()\n", + " ep_sizes_set = set()\n", + " quant_modes_set = set()\n", + "\n", + " for op_name in kernel_data:\n", + " op_names_set.add(op_name)\n", + " for qm in kernel_data[op_name]:\n", + " quant_modes_set.add(qm)\n", + " for nn in kernel_data[op_name][qm]:\n", + " for hs in kernel_data[op_name][qm][nn]:\n", + " for tk in kernel_data[op_name][qm][nn][hs]:\n", + " for ne in kernel_data[op_name][qm][nn][hs][tk]:\n", + " for ep in kernel_data[op_name][qm][nn][hs][tk][ne]:\n", + " ep_sizes_set.add(ep)\n", + "\n", + " op_order = [\n", + " \"alltoall_prepare\",\n", + " \"alltoall_dispatch\",\n", + " \"alltoall_combine\",\n", + " \"alltoall_combine_low_precision\",\n", + " ]\n", + " op_names = [op for op in op_order if op in op_names_set]\n", + " ep_sizes = sorted(ep_sizes_set)\n", + " quant_modes = sorted(quant_modes_set, key=str)\n", + "\n", + " if not op_names or not ep_sizes:\n", + " continue\n", + "\n", + " sol_ops = {\"alltoall_dispatch\", \"alltoall_combine\", \"alltoall_combine_low_precision\"}\n", + " rows = []\n", + " for op in op_names:\n", + " rows.append((op, \"latency\"))\n", + " if op in sol_ops:\n", + " rows.append((op, \"sol %\"))\n", + "\n", + " n_rows = len(rows)\n", + " n_cols = max(len(ep_sizes), 1)\n", + " fig, axes = plt.subplots(\n", + " n_rows,\n", + " n_cols,\n", + " figsize=(5 * n_cols, 4 * n_rows),\n", + " squeeze=False,\n", + " )\n", + " fig.suptitle(\n", + " f\"{database.system.upper()} - {database.backend.upper()} {database.version}\"\n", + " f\" - TRT-LLM AlltoAll Sanity Chart ({kernel_source})\",\n", + " fontsize=14,\n", + " fontweight=\"bold\",\n", + " )\n", + "\n", + " for ri, (op_name, metric) in enumerate(rows):\n", + " for ci, target_ep in enumerate(ep_sizes):\n", + " ax = axes[ri][ci]\n", + " cid = 0\n", + " for qm in quant_modes:\n", + " # Collect token sets keyed by shape tuple so\n", + " # each plotted line uses its own exact lookup.\n", + " shape_map = {} # (nn, h, t, ne) -> token_set\n", + " if qm in kernel_data.get(op_name, {}):\n", + " for nn in kernel_data[op_name][qm]:\n", + " for h in kernel_data[op_name][qm][nn]:\n", + " for t in kernel_data[op_name][qm][nn][h]:\n", + " for ne in kernel_data[op_name][qm][nn][h][t]:\n", + " if target_ep not in kernel_data[op_name][qm][nn][h][t][ne]:\n", + " continue\n", + " key = (nn, h, t, ne)\n", + " if key not in shape_map:\n", + " shape_map[key] = set()\n", + " shape_map[key].update(kernel_data[op_name][qm][nn][h][t][ne][target_ep].keys())\n", + "\n", + " if not shape_map:\n", + " cid += 1\n", + " continue\n", + "\n", + " for (nn_val, hs, tk, ne_val), token_set in shape_map.items():\n", + " tokens = sorted(token_set)\n", + " base_label = qm.name if hasattr(qm, \"name\") else str(qm)\n", + " label = f\"{base_label} hs={hs} tk={tk} ne={ne_val}\" if len(shape_map) > 1 else base_label\n", + "\n", + " query_kwargs = dict(\n", + " op_name=op_name,\n", + " hidden_size=hs,\n", + " topk=tk,\n", + " num_experts=ne_val,\n", + " moe_ep_size=target_ep,\n", + " quant_mode=qm,\n", + " node_num=nn_val,\n", + " moe_backend=moe_backend,\n", + " )\n", + "\n", + " if metric == \"latency\":\n", + " vals = [\n", + " float(\n", + " database.query_trtllm_alltoall(\n", + " num_tokens=nt,\n", + " database_mode=DatabaseMode.SILICON,\n", + " **query_kwargs,\n", + " )\n", + " )\n", + " for nt in tokens\n", + " ]\n", + " else:\n", + " vals = []\n", + " for nt in tokens:\n", + " sol_time = database.query_trtllm_alltoall(\n", + " num_tokens=nt,\n", + " database_mode=DatabaseMode.SOL_FULL,\n", + " **query_kwargs,\n", + " )[0]\n", + " db_time = database.query_trtllm_alltoall(\n", + " num_tokens=nt,\n", + " database_mode=DatabaseMode.SILICON,\n", + " **query_kwargs,\n", + " )\n", + " vals.append(sol_time / db_time if db_time > 0 else 0)\n", + "\n", + " ax.plot(\n", + " tokens,\n", + " vals,\n", + " color=color_list[cid % len(color_list)],\n", + " label=label,\n", + " marker=\".\",\n", + " markersize=3,\n", + " )\n", + " cid += 1\n", + "\n", + " ax.set_title(f\"{op_name}\\nep_size={target_ep}\")\n", + " ax.set_xlabel(\"num_tokens\")\n", + " ax.set_ylabel(metric)\n", + " ax.legend(fontsize=\"small\")\n", + "\n", + " plt.tight_layout(rect=[0, 0, 1, 0.96])\n", + " plt.show()\n", + "\n", + "\n", + "alltoall_database = get_database(system=\"gb200\", backend=\"trtllm\", version=\"1.2.0rc6\")\n", + "visualize_trtllm_alltoall(alltoall_database)" + ] } ], "metadata": { @@ -1028,7 +1209,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.12.3" } }, "nbformat": 4,