
Commit 8c00879

tamarPal authored and committed
fix: increase graph nodes for Megrez-MoE warmup
Megrez-MoE creates many intermediate tensors during MoE FFN construction:
- sigmoid, add, reshape (3x), get_rows, sum_rows, div, view_2d, and mul_mat operations
- ggml_top_k internally calls ggml_argsort + ggml_view_4d (2 more tensors per layer)
- each of the 30 MoE layers creates ~35 intermediate tensors during graph construction

During warmup, the graph is built 3 times with different batch sizes, requiring enough memory pool space for all intermediate tensors.

Add a 4096-node overhead for LLM_ARCH_MEGREZ_MOE to accommodate these intermediate tensors (30 layers × 35 tensors/layer ≈ 1050 nodes, padded to 4096 for a safety margin).

This fixes the "not enough space in the context's memory pool" error during warmup, allowing Megrez-MoE to run without the --no-warmup flag.

Tested:
- All 39 tests pass
- Megrez-MoE works with warmup enabled (no crashes)
- Other models (e.g., Gemma-2) are unaffected
- Verified with outputs up to 100 tokens
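For context, the node budget matters because the graph is assembled inside a fixed-size ggml meta context: every operation listed above becomes a tensor that consumes overhead in that pool, so an underestimated budget produces exactly the "not enough space in the context's memory pool" error. The snippet below is a minimal sketch of that relationship using the usual ggml sizing pattern; make_graph_ctx is a hypothetical helper, not code from this repository, and the exact wiring in llama.cpp may differ.

#include "ggml.h"

// Hypothetical helper mirroring how a graph's meta context is typically sized in
// ggml-based code: every node built during graph construction must fit inside this
// no_alloc pool, so the pool size is derived from the maximum node count.
static struct ggml_context * make_graph_ctx(size_t max_nodes) {
    const size_t buf_size =
        ggml_tensor_overhead()*max_nodes +            // metadata for each tensor/node
        ggml_graph_overhead_custom(max_nodes, false); // the cgraph structure itself

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // metadata only; tensor data lives in backend buffers
    };

    return ggml_init(params);
}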
1 parent 256414a commit 8c00879

File tree

1 file changed: +15 -1 lines changed


src/llama-context.cpp

Lines changed: 15 additions & 1 deletion
@@ -1362,7 +1362,21 @@ void llama_context::output_reorder() {
 //
 
 uint32_t llama_context::graph_max_nodes() const {
-    return std::max<uint32_t>(1024u, 8u*model.n_tensors());
+    uint32_t base_nodes = std::max<uint32_t>(1024u, 8u*model.n_tensors());
+
+    // Megrez-MoE creates many intermediate tensors in build_mergez_moe_ffn for each layer:
+    // - sigmoid, add (bias), reshape (3x), get_rows, sum_rows, div, view_2d, mul_mat (per expert)
+    // - ggml_top_k internally calls ggml_argsort + ggml_view_4d (2 more tensors per layer)
+    // Each MoE layer needs ~30-40 intermediate tensors during graph construction
+    // With 30 MoE layers, this adds significant overhead to the graph (30 layers * 35 tensors = ~1050)
+    // During warmup, the graph is built 3 times with different batch sizes
+    if (model.arch == LLM_ARCH_MEGREZ_MOE) {
+        // Add substantial overhead: ~35 intermediate tensors per MoE layer * 30 layers = ~1050 nodes
+        // Double it to 4096 for safety margin during warmup's triple graph construction
+        base_nodes += 4096;
+    }
+
+    return base_nodes;
 }
 
 llm_graph_result * llama_context::get_gf_res_reserve() const {
