From 750eaeb72d448768fc15973d8adf65737f06e597 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 24 Aug 2025 16:15:55 +0100 Subject: [PATCH 1/2] add dependence vector code --- JuliaBUGS/src/source_gen.jl | 158 +++++++++++++++++++++++++++++++++--- 1 file changed, 148 insertions(+), 10 deletions(-) diff --git a/JuliaBUGS/src/source_gen.jl b/JuliaBUGS/src/source_gen.jl index e6130a32d..8a83702a4 100644 --- a/JuliaBUGS/src/source_gen.jl +++ b/JuliaBUGS/src/source_gen.jl @@ -191,30 +191,56 @@ function _sort_fissioned_stmts(stmt_dep_graph, fissioned_stmts, stmt_ids) end function _reconstruct_model_def_from_sorted_fissioned_stmts(sorted_fissioned_stmts) - args = [] - for (loops, stmt) in sorted_fissioned_stmts - if loops == () - push!(args, first(stmt)) + args = Any[] + + i = 1 + N = length(sorted_fissioned_stmts) + while i <= N + loops_i, stmti = sorted_fissioned_stmts[i] + # collect consecutive statements with identical loop nests + group_stmts = Any[] + j = i + while j <= N + loops_j, stmtj = sorted_fissioned_stmts[j] + if loops_j == loops_i + append!(group_stmts, stmtj) + j += 1 + else + break + end + end + + if loops_i == () + # top-level sequential statements + append!(args, group_stmts) else - push!(args, __gen_loop_expr(loops, first(stmt))) + push!(args, __gen_loop_expr(loops_i, group_stmts)) end + i = j end + return Expr(:block, args...) end -function __gen_loop_expr(loop_vars, stmt) +# Overload to generate nested loops around a block of statements +function __gen_loop_expr(loop_vars, stmts::Vector) loop_var, l, h = loop_vars[1] if length(loop_vars) == 1 return MacroTools.@q for $(loop_var) in ($(l)):($(h)) - $(stmt) + $(Expr(:block, stmts...)) end else return MacroTools.@q for $(loop_var) in ($(l)):($(h)) - $(__gen_loop_expr(loop_vars[2:end], stmt)) + $(__gen_loop_expr(loop_vars[2:end], stmts)) end end end +# Backward-compatible helper to handle single statement +function __gen_loop_expr(loop_vars, stmt) + return __gen_loop_expr(loop_vars, Any[stmt]) +end + # add if statement in the lowered model def function _lower_model_def_to_represent_observe_stmts( reconstructed_model_def, @@ -327,7 +353,17 @@ function _generate_lowered_model_def( var_to_stmt_id = _build_var_to_stmt_id(model_def, g, evaluation_env, stmt_to_stmt_id) stmt_id_to_var = _build_stmt_id_to_var(var_to_stmt_id) coarse_graph = _build_coarse_dep_graph(g, stmt_to_stmt_id, var_to_stmt_id) - if Graphs.is_cyclic(coarse_graph) + # If there are cycles at the coarse statement level, try to resolve them + # by analyzing fine-grained dependence vectors. If all cycles are + # loop-carried with lexicographically non-negative distances within + # the same loop nest, they are sequentially valid and can be ignored + # for statement reordering. + ordering_graph, ok = _build_ordering_graph_via_dependence_vectors( + g, coarse_graph, var_to_stmt_id + ) + if !ok || Graphs.is_cyclic(ordering_graph) + # Either we detected lexicographically negative dependences or + # remaining cycles cannot be resolved by dependence vectors. return nothing, nothing end # show_coarse_graph(stmt_id_to_stmt, coarse_graph) @@ -337,8 +373,10 @@ function _generate_lowered_model_def( fissioned_stmts = _fully_fission_loop( model_def_removed_transformed_data, stmt_to_stmt_id, evaluation_env ) + # Use the filtered ordering graph (with loop-carried non-negative + # dependences removed) to sort fissioned statements. sorted_fissioned_stmts = _sort_fissioned_stmts( - coarse_graph, fissioned_stmts, stmt_to_stmt_id + ordering_graph, fissioned_stmts, stmt_to_stmt_id ) reconstructed_model_def = _reconstruct_model_def_from_sorted_fissioned_stmts( sorted_fissioned_stmts @@ -594,6 +632,106 @@ function _find_corresponding_fine_grained_edges( return fine_grained_edges end +# Determine the lexicographic relation between two iteration vectors (loop vars). +# Returns: +# - :zero -> loop-independent dependence (same iteration) +# - :positive -> loop-carried with lexicographically non-negative (and not all zero) +# - :negative -> lexicographically negative (invalid for sequential order) +# - :unknown -> cannot compare (different loop nests or empty) +function _lex_dependence_relation(src_lv::NamedTuple, dst_lv::NamedTuple) + # Require identical loop nests (same keys in the same order) + src_keys = Tuple(keys(src_lv)) + dst_keys = Tuple(keys(dst_lv)) + if src_keys != dst_keys + return :unknown + end + if length(src_keys) == 0 + return :unknown + end + + # Compute difference vector dst - src in lexicographic order + first_nonzero = 0 + for k in src_keys + d = Int(getfield(dst_lv, k)) - Int(getfield(src_lv, k)) + if d != 0 + first_nonzero = d + break + end + end + if first_nonzero == 0 + return :zero + elseif first_nonzero > 0 + return :positive + else + return :negative + end +end + +# Classify a fine-grained edge by its dependence vector category +function _classify_fine_edge( + g::JuliaBUGS.BUGSGraph, src_vn::VarName, dst_vn::VarName +) + src_lv = g[src_vn].loop_vars + dst_lv = g[dst_vn].loop_vars + rel = _lex_dependence_relation(src_lv, dst_lv) + return rel +end + +# Build an ordering graph for statements by removing edges that are purely +# loop-carried with lexicographically non-negative dependence vectors. If any +# edge has a lexicographically negative dependence, the graph is invalid. +# IMPORTANT: Only removes positive edges for self-dependencies to avoid unsafe reorderings. +function _build_ordering_graph_via_dependence_vectors( + g::JuliaBUGS.BUGSGraph, + coarse_graph::Graphs.SimpleDiGraph, + var_to_stmt_id::Dict{VarName,Int}, +) + ordering_graph = Graphs.SimpleDiGraph(Graphs.nv(coarse_graph)) + + # Iterate all coarse edges and decide whether to keep them for ordering + for e in Graphs.edges(coarse_graph) + src_stmt_id = Graphs.src(e) + dst_stmt_id = Graphs.dst(e) + fine_edges = _find_corresponding_fine_grained_edges( + g, var_to_stmt_id, src_stmt_id, dst_stmt_id + ) + + # If we can't find the fine edges, be conservative: keep the edge. + if isempty(fine_edges) + Graphs.add_edge!(ordering_graph, src_stmt_id, dst_stmt_id) + continue + end + + # Check all fine edges to classify the coarse edge + all_positive = true + for (src_vn, dst_vn) in fine_edges + rel = _classify_fine_edge(g, src_vn, dst_vn) + if rel === :negative + # Invalid sequential order due to negative dependence + return ordering_graph, false + end + if rel !== :positive + all_positive = false + end + end + + # Decision logic based on whether it's a self-edge or cross-statement edge + if src_stmt_id == dst_stmt_id + # Self-edge: only drop if ALL fine edges are positive (loop-carried) + # This safely breaks recursion cycles like x[t] ~ f(x[t-1]) + if !all_positive + Graphs.add_edge!(ordering_graph, src_stmt_id, dst_stmt_id) + end + else + # Cross-statement edge: ALWAYS keep to preserve ordering constraints + # This prevents unsafe reorderings where consumers run before producers + Graphs.add_edge!(ordering_graph, src_stmt_id, dst_stmt_id) + end + end + + return ordering_graph, true +end + struct CollectSortedNodes{ET} <: CompilerPass sorted_nodes::Vector{<:VarName} env::ET From 23888072ec7be2665b8f23514a6c1b11630d09b0 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 29 Aug 2025 16:44:07 +0100 Subject: [PATCH 2/2] code and test --- JuliaBUGS/docs/src/source_gen.md | 141 ++++++++++++++++ JuliaBUGS/src/compiler_pass.jl | 3 + JuliaBUGS/src/source_gen.jl | 272 ++++++++++++++++++++++++++++--- JuliaBUGS/test/source_gen.jl | 110 +++++++++++++ 4 files changed, 499 insertions(+), 27 deletions(-) diff --git a/JuliaBUGS/docs/src/source_gen.md b/JuliaBUGS/docs/src/source_gen.md index 6a5ded5be..34d3195eb 100644 --- a/JuliaBUGS/docs/src/source_gen.md +++ b/JuliaBUGS/docs/src/source_gen.md @@ -294,6 +294,147 @@ If the statements that form a cycle are all from the same loop (potentially at d Otherwise, the program need to be rewritten. +## Implementation Overview (Current) + +At a high level, the current implementation follows a conservative, correctness‑first pipeline. It favors simple, explainable transformations and stops with diagnostics when safety cannot be guaranteed. + +- Build coarse statement graph: Construct a statement‑level dependence graph from the compiled BUGS graph. Nodes are top‑level statements; edges indicate that any variable produced by one statement is used by another. + +- Remove transformed data: Copy the model AST and remove statements whose variables are all compile‑time computable (in‑degree or out‑degree zero at the coarse level). This keeps the runtime program focused on values that must be evaluated. + +- Fully fission loops: Split each loop so that every resulting statement runs in its own loop nest. This simplifies subsequent ordering by ensuring we can reason per‑statement and per‑loop nest. The loop nests are retained as metadata with each fissioned statement. + +- Dependence vectors at fine granularity: For every coarse edge, find corresponding fine‑grained variable edges and classify their dependence relation using lexicographic comparison of loop indices: + - Zero: loop‑independent (same iteration) + - Positive: loop‑carried, lexicographically non‑negative + - Negative: lexicographically negative (unsafe for sequential order) + - Unknown: cannot be compared (e.g., different loop nests or missing loop information) + +- Filter into an ordering graph: Build an ordering graph for statements by: + - Dropping self‑edges when all corresponding fine edges are Positive (safe loop‑carried self‑recurrence such as x[t] depends on x[t‑1]). + - Retaining all cross‑statement edges by default to preserve producer→consumer ordering. + - Aborting if any Negative dependence is observed (unsafe), or recording Unknown dependences for diagnostics. + +- Resolve remaining cycles conservatively: If cycles remain in the ordering graph, attempt limited loop fusion within strongly‑connected components (SCCs) that meet all of the following: + - All statements share the exact same loop nest (same variables and identical bounds). + - No Negative or Unknown fine‑grained dependences among the SCC members. + - The subgraph induced by Zero dependences is acyclic, providing a valid intra‑iteration order. + If this succeeds, expand clusters in topological order to obtain a global statement order. Otherwise, abort with diagnostics. + +- Reconstruct structured loops: After sorting the fissioned statements, group consecutive statements with identical loop nests and reconstruct a single nested `for` around a block of statements rather than emitting many tiny loops. This preserves structure while avoiding over‑fissioning in the final program. + +- Lower observations and flatten blocks: Insert observation guards/casts during lowering, and flatten intermediate `:block` nodes introduced by reconstruction so that analysis and codegen see a normalized statement sequence. + +Diagnostics are collected throughout and surfaced to help users rewrite programs when the transformation cannot be proven safe (e.g., negative or unknown dependences). + +### What works well now + +- State‑space patterns with self‑recurrence inside a single loop nest (e.g., x[t] depends on x[t‑1]) +- Cross‑coupled SSMs where multiple state arrays reference each other at lag 1, provided they share the same time loop +- Grid SSMs with independent per‑row recurrences, plus observation loops reading current state + +### What we intentionally reject (for now) + +- Inter‑loop cycles that require general loop fusion across different loop nests (e.g., even/odd split loops over the same domain but structured as separate loops) +- Data‑dependent indexing that produces Unknown dependences across loop nests +- Any pattern that induces Negative dependences under lexicographic ordering + +These cases either need manual refactoring into a single loop with a clear per‑iteration ordering, or future research‑grade transforms beyond our current scope. + +### Reference entry points (for developers) + +The following locations contain the mechanics described above: + +- Coarse graph, fission, and reconstruction: `JuliaBUGS/src/source_gen.jl:347` +- Grouping statements into shared loop nests: `JuliaBUGS/src/source_gen.jl:193` +- Loop construction around statement blocks: `JuliaBUGS/src/source_gen.jl:226` +- Fine‑grained dependence classification: `JuliaBUGS/src/source_gen.jl:667` +- Ordering graph via dependence vectors: `JuliaBUGS/src/source_gen.jl:710` +- Limited SCC loop fusion (identical loop nests): `JuliaBUGS/src/source_gen.jl:788` +- Sorting by explicit statement order: `JuliaBUGS/src/source_gen.jl:922` +- Block flattening in analysis/codegen: `JuliaBUGS/src/compiler_pass.jl:38`, `JuliaBUGS/src/source_gen.jl:532` + +There is a small SSM‑focused test harness and benchmarks accompanying this work. See `test_ssm.jl` for representative models that should succeed or fail with diagnostics, and `bench_ssm*.jl` for performance comparisons between graph traversal and generated sequential code. + +## State‑Space Models (SSM) Support + +This section summarizes how the current pipeline recognizes and transforms common SSM patterns into correct sequential code. + +### Recognized Patterns + +- Single time loop with self‑recurrence: + ```julia + x[1] ~ Normal(0, 1) + for t in 2:T + x[t] ~ Normal(x[t-1], sigma_x) + end + ``` + +- Lagged observations (read previous state): + ```julia + y[1] ~ Normal(x[1], sigma_y) + for t in 2:T + y[t] ~ Normal(x[t-1], sigma_y) + end + ``` + +- Cross‑coupled states within the same time loop (mutual lag‑1): + ```julia + x[1] ~ Normal(0, 1); y[1] ~ Normal(0, 1) + for t in 2:T + x[t] ~ Normal(y[t-1], sigma_x) + y[t] ~ Normal(x[t-1], sigma_y) + end + ``` + +- Grid SSMs with independent rows/series and a shared time dimension: + ```julia + for i in 1:I + x[i,1] ~ Normal(0, 1) + for t in 2:T + x[i,t] ~ Normal(x[i,t-1], sigma) + end + end + for i in 1:I, t in 1:T + y[i,t] ~ Normal(x[i,t], sigma_y) + end + ``` + +### Transformation Outline for SSMs + +1) Build the coarse statement graph and fully fission the input into per‑statement loop nests. + +2) Classify fine‑grained dependences between statements by comparing loop indices lexicographically: + - Positive self‑dependences (e.g., x[t] → x[t+1]) are considered safe within the same loop nest and are dropped for ordering. + - Cross‑statement edges are kept to preserve producer→consumer order (e.g., x[t] → y[t] or x[t-1] → y[t]). + - Any Negative dependence (e.g., x[t+1] used by x[t]) aborts with diagnostics. + +3) If an SCC remains cyclic but all members share the exact same loop nest, attempt conservative loop fusion: + - Verify no Negative/Unknown dependences inside the SCC. + - Use Zero‑dependence edges (loop‑independent) to order statements within each iteration; if none exist, any fixed order is acceptable because constraints are cross‑iteration only. + +4) Reconstruct: group consecutive statements that share a loop nest and emit a single nested `for` around a block of statements. + +For typical SSMs, this yields either: +- Separate time loops in producer→consumer order (e.g., first state update loop, then observation loop); or +- A single fused time loop when multiple state updates mutually depend on the previous time step (cross‑coupled case). + +### Examples (from tests) + +- Basic SSM and lagged observations: accepted and reconstructed into sequential loops. +- Cross‑coupled SSM: accepted; body contains both state updates per time step, unordered within iteration because constraints are cross‑iteration only. +- Grid SSM: accepted for independent rows; observations read current state. +- Negative dependence (reading future): rejected with diagnostics. +- Inter‑loop cycle requiring even/odd fusion across separate loops: rejected (manual refactor recommended into one time loop). + +### Authoring Tips for SSMs + +- Keep all state updates for a given time index inside a single time loop with identical bounds. +- Provide clear initial conditions (e.g., `x[1]`, `y[1]`). +- Avoid referencing “future” states (e.g., `x[t+1]` inside the body); these create Negative dependences. +- Prefer lag‑1 or other non‑negative lexicographic lags where the loop bounds make dependencies valid. +- Avoid splitting a single logical time loop into multiple separate loops that mutually depend on each other (e.g., even/odd passes). If needed, write one fused time loop explicitly. + We don't attempt to apply further transformations to the program, because it is a hard problem. We will use the following example to show why program transformations can be a difficult task. We will not attempt to implement the transformation demonstrated here. Consider this model, diff --git a/JuliaBUGS/src/compiler_pass.jl b/JuliaBUGS/src/compiler_pass.jl index 9333d5281..b780b21ea 100644 --- a/JuliaBUGS/src/compiler_pass.jl +++ b/JuliaBUGS/src/compiler_pass.jl @@ -35,6 +35,9 @@ function analyze_block( ) end end + elseif Meta.isexpr(statement, :block) + # Flatten nested blocks introduced by program reconstruction + analyze_block(pass, statement, loop_vars; warn_loop_bounds=warn_loop_bounds) else error("Unsupported expression in top level: $statement") end diff --git a/JuliaBUGS/src/source_gen.jl b/JuliaBUGS/src/source_gen.jl index 8a83702a4..60c290222 100644 --- a/JuliaBUGS/src/source_gen.jl +++ b/JuliaBUGS/src/source_gen.jl @@ -330,7 +330,7 @@ function __check_for_reserved_names(model_def::Expr) bad_variable_names = filter( variable_name -> startswith(string(variable_name), "__") && - endswith(string(variable_name), "__"), + endswith(string(variable_name), "__"), variable_names, ) if !isempty(bad_variable_names) @@ -345,7 +345,10 @@ function __check_for_reserved_names(model_def::Expr) end function _generate_lowered_model_def( - model_def::Expr, g::JuliaBUGS.BUGSGraph, evaluation_env::NamedTuple + model_def::Expr, + g::JuliaBUGS.BUGSGraph, + evaluation_env::NamedTuple; + diagnostics::Vector{String}=String[], ) __check_for_reserved_names(model_def) stmt_to_stmt_id = _build_stmt_to_stmt_id(model_def) @@ -353,31 +356,50 @@ function _generate_lowered_model_def( var_to_stmt_id = _build_var_to_stmt_id(model_def, g, evaluation_env, stmt_to_stmt_id) stmt_id_to_var = _build_stmt_id_to_var(var_to_stmt_id) coarse_graph = _build_coarse_dep_graph(g, stmt_to_stmt_id, var_to_stmt_id) - # If there are cycles at the coarse statement level, try to resolve them - # by analyzing fine-grained dependence vectors. If all cycles are - # loop-carried with lexicographically non-negative distances within - # the same loop nest, they are sequentially valid and can be ignored - # for statement reordering. - ordering_graph, ok = _build_ordering_graph_via_dependence_vectors( - g, coarse_graph, var_to_stmt_id - ) - if !ok || Graphs.is_cyclic(ordering_graph) - # Either we detected lexicographically negative dependences or - # remaining cycles cannot be resolved by dependence vectors. - return nothing, nothing - end - # show_coarse_graph(stmt_id_to_stmt, coarse_graph) + # Remove transformed data before fissioning model_def_removed_transformed_data = _copy_and_remove_stmt_with_degree_0( model_def, stmt_to_stmt_id, coarse_graph ) + # Fully fission now so we can reason about each statement's loop nest fissioned_stmts = _fully_fission_loop( model_def_removed_transformed_data, stmt_to_stmt_id, evaluation_env ) - # Use the filtered ordering graph (with loop-carried non-negative - # dependences removed) to sort fissioned statements. - sorted_fissioned_stmts = _sort_fissioned_stmts( - ordering_graph, fissioned_stmts, stmt_to_stmt_id + # If there are cycles at the coarse statement level, try to resolve them + # by analyzing fine-grained dependence vectors. If all cycles are + # loop-carried with lexicographically non-negative distances within + # the same loop nest, they are sequentially valid and we can either + # drop those edges (self or cross-statement) or fuse statements into + # a single loop with per-iteration ordering. + ordering_graph, ok = _build_ordering_graph_via_dependence_vectors( + g, coarse_graph, var_to_stmt_id; diagnostics=diagnostics ) + sorted_fissioned_stmts = nothing + if !ok || Graphs.is_cyclic(ordering_graph) + # Try to resolve remaining cycles by loop fusion within identical loop nests. + stmt_order = _attempt_resolve_cycles_via_loop_fusion( + g, + ordering_graph, + var_to_stmt_id, + fissioned_stmts, + stmt_to_stmt_id; + diagnostics=diagnostics, + ) + if stmt_order === nothing + if !isempty(diagnostics) + @warn "Source generation aborted due to unsafe/corner-case dependencies\n - $(join(diagnostics, "\n - "))" + end + return nothing, nothing + end + sorted_fissioned_stmts = _sort_fissioned_stmts_by_stmt_order( + stmt_order, fissioned_stmts, stmt_to_stmt_id + ) + else + # Use the filtered ordering graph (with loop-carried non-negative + # dependences removed) to sort fissioned statements. + sorted_fissioned_stmts = _sort_fissioned_stmts( + ordering_graph, fissioned_stmts, stmt_to_stmt_id + ) + end reconstructed_model_def = _reconstruct_model_def_from_sorted_fissioned_stmts( sorted_fissioned_stmts ) @@ -526,6 +548,10 @@ function __gen_logp_density_function_body_exprs(stmts::Vector, evaluation_env, e elseif Meta.isexpr(stmt, :if) new_if = _handle_if_expr(stmt, evaluation_env) push!(exprs, new_if) + elseif Meta.isexpr(stmt, :block) + # Flatten nested blocks (e.g., grouped statements inside loop bodies) + new_inner = __gen_logp_density_function_body_exprs(stmt.args, evaluation_env) + append!(exprs, new_inner) else error("Unsupported statement: $stmt") end @@ -668,9 +694,7 @@ function _lex_dependence_relation(src_lv::NamedTuple, dst_lv::NamedTuple) end # Classify a fine-grained edge by its dependence vector category -function _classify_fine_edge( - g::JuliaBUGS.BUGSGraph, src_vn::VarName, dst_vn::VarName -) +function _classify_fine_edge(g::JuliaBUGS.BUGSGraph, src_vn::VarName, dst_vn::VarName) src_lv = g[src_vn].loop_vars dst_lv = g[dst_vn].loop_vars rel = _lex_dependence_relation(src_lv, dst_lv) @@ -684,7 +708,8 @@ end function _build_ordering_graph_via_dependence_vectors( g::JuliaBUGS.BUGSGraph, coarse_graph::Graphs.SimpleDiGraph, - var_to_stmt_id::Dict{VarName,Int}, + var_to_stmt_id::Dict{VarName,Int}; + diagnostics::Vector{String}=String[], ) ordering_graph = Graphs.SimpleDiGraph(Graphs.nv(coarse_graph)) @@ -707,8 +732,16 @@ function _build_ordering_graph_via_dependence_vectors( for (src_vn, dst_vn) in fine_edges rel = _classify_fine_edge(g, src_vn, dst_vn) if rel === :negative - # Invalid sequential order due to negative dependence + push!( + diagnostics, + "Negative dependence prevents ordering: $(src_stmt_id) -> $(dst_stmt_id) via $(src_vn) -> $(dst_vn)", + ) return ordering_graph, false + elseif rel === :unknown + push!( + diagnostics, + "Unknown dependence (different loop nests or missing info): $(src_stmt_id) -> $(dst_stmt_id) via $(src_vn) -> $(dst_vn)", + ) end if rel !== :positive all_positive = false @@ -723,8 +756,8 @@ function _build_ordering_graph_via_dependence_vectors( Graphs.add_edge!(ordering_graph, src_stmt_id, dst_stmt_id) end else - # Cross-statement edge: ALWAYS keep to preserve ordering constraints - # This prevents unsafe reorderings where consumers run before producers + # Cross-statement edge: keep by default; it may later be relaxed if + # the component can be safely fused by _attempt_resolve_cycles_via_loop_fusion. Graphs.add_edge!(ordering_graph, src_stmt_id, dst_stmt_id) end end @@ -732,6 +765,191 @@ function _build_ordering_graph_via_dependence_vectors( return ordering_graph, true end +# Build a mapping from statement id to its fissioned loop nest (tuple of (var, lb, ub)). +function _build_stmt_to_loops_map(fissioned_stmts, stmt_ids) + stmt_to_loops = Dict{Int,Any}() + for (loops, stmt) in fissioned_stmts + sid = stmt_ids[first(stmt)] + stmt_to_loops[sid] = loops + end + return stmt_to_loops +end + +_loop_var_names(loops) = map(lvh -> lvh[1], collect(loops)) + +# Attempt to resolve cycles by fusing statements that: +# - are in the same SCC +# - share identical loop variable names and identical bounds (same loop nest) +# - have no lexicographically negative fine-grained dependences among them +# Ordering inside the fused loop is determined by zero-dependence edges. +# Returns a vector of statement ids in a globally valid order, or nothing if not possible. +function _attempt_resolve_cycles_via_loop_fusion( + g::JuliaBUGS.BUGSGraph, + ordering_graph::Graphs.SimpleDiGraph, + var_to_stmt_id::Dict{VarName,Int}, + fissioned_stmts, + stmt_ids::IdDict{Expr,Int}; + diagnostics::Vector{String}=String[], +) + stmt_to_loops = _build_stmt_to_loops_map(fissioned_stmts, stmt_ids) + + # Identify SCCs + sccs = Graphs.strongly_connected_components(ordering_graph) + + # Track which SCCs we will fuse and their internal orders + fuseable = Dict{Int,Vector{Int}}() # scc_index => ordered stmt ids inside SCC + + for (scc_idx, nodes) in enumerate(sccs) + if length(nodes) <= 1 + continue + end + + # Require all statements in SCC to have identical loop nests (names and bounds) + loops_first = get(stmt_to_loops, nodes[1], nothing) + if loops_first === nothing + push!(diagnostics, "Cannot fuse SCC $(scc_idx): missing loop nest metadata") + return nothing + end + names_first = _loop_var_names(loops_first) + same_loops = true + for n in nodes[2:end] + loops_n = get(stmt_to_loops, n, nothing) + if loops_n === nothing + push!( + diagnostics, + "Cannot fuse SCC $(scc_idx): missing loop nest metadata for statement $(n)", + ) + return nothing + end + if _loop_var_names(loops_n) != names_first || loops_n != loops_first + same_loops = false + break + end + end + if !same_loops + push!( + diagnostics, + "Cannot fuse SCC $(scc_idx): statements have different loop nests", + ) + return nothing + end + + # Build a subgraph with edges only for zero-dependence (loop-independent) relations + zero_graph = Graphs.SimpleDiGraph(length(nodes)) + idx_of = Dict(n => i for (i, n) in enumerate(nodes)) + + # Check all fine edges among nodes for negativity/unknown; collect zero edges + for u in nodes, v in nodes + if u == v + continue + end + # find all fine-grained edges mapping u->v + fine_edges = _find_corresponding_fine_grained_edges(g, var_to_stmt_id, u, v) + if isempty(fine_edges) + continue + end + # classify + has_zero = false + for (src_vn, dst_vn) in fine_edges + rel = _classify_fine_edge(g, src_vn, dst_vn) + if rel === :negative + push!( + diagnostics, + "Cannot fuse SCC $(scc_idx): negative dependence inside SCC ($(u) -> $(v))", + ) + return nothing + elseif rel === :unknown + # Edges across different loop nests or missing loop info + # make this SCC unsafe to fuse; abort. + push!( + diagnostics, + "Cannot fuse SCC $(scc_idx): unknown dependence inside SCC ($(u) -> $(v))", + ) + return nothing + elseif rel === :zero + has_zero = true + end + end + if has_zero + Graphs.add_edge!(zero_graph, idx_of[u], idx_of[v]) + end + end + + # zero_graph must be acyclic to yield an intra-iteration order + if Graphs.is_cyclic(zero_graph) + push!( + diagnostics, + "Cannot fuse SCC $(scc_idx): intra-iteration order (zero-dep edges) is cyclic", + ) + return nothing + end + local_order = [nodes[i] for i in Graphs.topological_sort(zero_graph)] + # If zero_graph has no edges, keep original node order as a fallback + if isempty(local_order) + local_order = copy(nodes) + end + fuseable[scc_idx] = local_order + end + + # If there are cycles but none were fuseable, abort + any_fused = any(length(v) > 1 for v in values(fuseable)) + if !any_fused + push!(diagnostics, "No fuseable SCCs found; cycles remain") + return nothing + end + + # Build a condensed cluster graph: each SCC becomes a cluster; for fuseable SCCs + # we will drop internal edges and expand in the computed local order later. + cluster_graph = Graphs.SimpleDiGraph(length(sccs)) + # Map stmt -> cluster index + stmt_to_cluster = Dict{Int,Int}() + for (ci, ns) in enumerate(sccs) + for n in ns + stmt_to_cluster[n] = ci + end + end + # Add inter-cluster edges + for e in Graphs.edges(ordering_graph) + cu = stmt_to_cluster[Graphs.src(e)] + cv = stmt_to_cluster[Graphs.dst(e)] + if cu != cv + Graphs.add_edge!(cluster_graph, cu, cv) + end + end + + # Topologically sort clusters + cluster_order = Graphs.topological_sort(cluster_graph) + # Expand clusters into a flat statement order + stmt_order = Int[] + for cid in cluster_order + nodes = sccs[cid] + if haskey(fuseable, cid) + append!(stmt_order, fuseable[cid]) + else + # size-1 SCC or non-fuseable SCC (should not exist here if cycles remain) + append!(stmt_order, nodes) + end + end + return stmt_order +end + +# Sort fissioned statements according to an explicit statement id order +function _sort_fissioned_stmts_by_stmt_order( + stmt_order::Vector{Int}, fissioned_stmts, stmt_ids +) + order_pos = Dict{Int,Int}(sid => i for (i, sid) in enumerate(stmt_order)) + # Filter only statements that appear in order (some transformed-data removed ones may be absent) + items = [] + for (loops, stmt) in fissioned_stmts + sid = stmt_ids[first(stmt)] + if haskey(order_pos, sid) + push!(items, (order_pos[sid], loops, stmt)) + end + end + sort!(items; by=x -> x[1]) + return [(loops, stmt) for (_, loops, stmt) in items] +end + struct CollectSortedNodes{ET} <: CompilerPass sorted_nodes::Vector{<:VarName} env::ET diff --git a/JuliaBUGS/test/source_gen.jl b/JuliaBUGS/test/source_gen.jl index 322ee5c52..b0cb06094 100644 --- a/JuliaBUGS/test/source_gen.jl +++ b/JuliaBUGS/test/source_gen.jl @@ -68,3 +68,113 @@ end model = compile(model_def, data) end + +@testset "state-space models (SSM) transformation" begin + # Helper: run semantic analysis + graph + source-generation and return success flag + diags + function _gen_ok(model_def, data) + eval_env = JuliaBUGS.semantic_analysis(model_def, data) + g = JuliaBUGS.create_graph(model_def, eval_env) + diags = String[] + lowered, reconstructed = JuliaBUGS._generate_lowered_model_def( + model_def, g, eval_env; diagnostics=diags + ) + return lowered !== nothing, diags + end + + # 1) Basic SSM with self-recursion and observations + model_def1 = @bugs begin + x[1] ~ Normal(0, 1) + for t in 2:T + x[t] ~ Normal(x[t - 1], sigma_x) + end + for t in 1:T + y[t] ~ Normal(x[t], sigma_y) + end + end + ok1, _ = _gen_ok(model_def1, (T=10, sigma_x=0.5, sigma_y=0.3)) + @test ok1 + + # 2) Lagged observations depend on previous state + model_def2 = @bugs begin + x[1] ~ Normal(0, 1) + for t in 2:T + x[t] ~ Normal(x[t - 1], sigma_x) + end + y[1] ~ Normal(x[1], sigma_y) + for t in 2:T + y[t] ~ Normal(x[t - 1], sigma_y) + end + end + ok2, _ = _gen_ok(model_def2, (T=10, sigma_x=0.5, sigma_y=0.3)) + @test ok2 + + # 3) Cross-coupled SSM (mutual lag-1) in a single time loop + model_def3 = @bugs begin + x[1] ~ Normal(0, 1) + y[1] ~ Normal(0, 1) + for t in 2:T + x[t] ~ Normal(y[t - 1], sigma_x) + y[t] ~ Normal(x[t - 1], sigma_y) + end + end + ok3, _ = _gen_ok(model_def3, (T=10, sigma_x=0.5, sigma_y=0.3)) + @test ok3 + + # 4) Invalid negative dependence (read future state) + model_def4 = @bugs begin + for t in 1:(T - 1) + x[t] ~ Normal(x[t + 1], sigma) + end + x[T] ~ Normal(0, 1) + end + ok4, _ = _gen_ok(model_def4, (T=10, sigma=0.7)) + @test !ok4 + + # 5) Grid SSM: independent per-row recurrences, observations at current time + model_def5 = @bugs begin + for i in 1:I + x[i, 1] ~ Normal(0, 1) + for t in 2:T + x[i, t] ~ Normal(x[i, t - 1], sigma) + end + end + for i in 1:I + for t in 1:T + y[i, t] ~ Normal(x[i, t], sigma_y) + end + end + end + ok5, _ = _gen_ok(model_def5, (I=3, T=10, sigma=0.7, sigma_y=0.3)) + @test ok5 + + # 6) Inter-loop cycle (even/odd) requiring general fusion across separate loops -> reject + model_def6 = @bugs begin + sumX[1] = x[1] + for i in 2:N + sumX[i] = sumX[i - 1] + x[i] + end + for k in 1:div(N, 2) # even indices + x[2 * k] ~ Normal(sumX[2 * k - 1], tau) + end + for k in 1:(div(N, 2) - 1) # odd indices + x[2 * k + 1] ~ Gamma(sumX[2 * k], tau) + end + end + ok6, _ = _gen_ok( + model_def6, (N=10, tau=1.2, x=Union{Float64,Missing}[1.0; fill(missing, 9)]) + ) + @test !ok6 + + # 7) Data-dependent indexing induces unknown/cyclic deps -> reject + model_def7 = @bugs begin + z[1] = 0.0 + z[2] = x[1] + 0.0 + y[1] = 0.0 + y[2] = x[3] + 0.0 + for i in 1:3 + x[i] = y[a[i]] + z[b[i]] + end + end + ok7, _ = _gen_ok(model_def7, (a=[2, 2, 1], b=[2, 1, 2])) + @test !ok7 +end