Efficient way to copy SyntaxTree for concurrent lowering of subtrees #83

@aviatesk

I'm working on a PR to parallelize JETLS (aviatesk/JETLS.jl#241), where I would like to concurrently lower multiple subtrees from a single top-level file. To that end, I'm looking into how to lower SyntaxTrees in parallel safely.
Since JL doesn't really have any global state, I initially assumed that JL.lower could simply be called from multiple threads. However, I hadn't considered the state shared between parent and child syntax trees. If you naively call JL.lower in parallel on statement-level subtrees of a SyntaxTree parsed with rule = :all, you'll run into various timing-dependent failures such as concurrency violations or stack overflows.

For example, consider a script like the one below, which first creates a toplevel syntax tree st0_top and then lowers its children in parallel. I want to do something similar in JETLS:

jlconcurrent.jl

using JuliaLowering: JuliaLowering as JL
using JuliaSyntax: JuliaSyntax as JS

function parsedstream(s::AbstractString; rule::Symbol=:all)
    stream = JS.ParseStream(s)
    JS.parse!(stream; rule)
    return stream
end

jlparse(s::AbstractString; rule::Symbol=:all, kwargs...) = jlparse(parsedstream(s; rule); kwargs...)
jlparse(parsed_stream::JS.ParseStream; filename::AbstractString=@__FILE__, first_line::Int=1) =
    JS.build_tree(JL.SyntaxTree, parsed_stream; filename, first_line)

txt = """
server = Server();
uri = filename2uri(abspath("src/JETLS.jl"));
txt = read("src/JETLS.jl", String);
JETLS.cache_file_info!(server.state, uri, 1, txt);
"""

st0_top = jlparse(txt);

waitall([
    Threads.@spawn JL.lower(Main, st0_top[1])
    Threads.@spawn JL.lower(Main, st0_top[2])
    Threads.@spawn JL.lower(Main, st0_top[3])
    Threads.@spawn JL.lower(Main, st0_top[4])
])

If you run this with multiple threads, various errors occur depending on the timing:

julia/packages/JETLS on  avi/multithreading [⇣$!] via ஃ 1.12.0
➜ julia --startup-file=no --threads=auto jlconcurrent.jl
ERROR: LoadError: TaskFailedException

    nested task error: ConcurrencyViolationError("Vector has invalid state. Don't modify internal fields incorrectly, or resize without correct locks")
    Stacktrace:
      [1] (::Base.var"#_growend!##0#_growend!##1"{Vector{UnitRange{Int64}}, Int64, Int64, Int64, Int64, Int64, Memory{UnitRange{Int64}}, MemoryRef{UnitRange{Int64}}})()
        @ Base ./array.jl:1133
      [2] _growend!
        @ ./array.jl:1131 [inlined]
      [3] _push!
        @ ./array.jl:1289 [inlined]
      [4] push!
        @ ./array.jl:1286 [inlined]
      [5] newnode!
        @ ~/.julia/packages/JuliaLowering/A0OwQ/src/syntax_graph.jl:98 [inlined]
      [6] _makenode(graph::JuliaLowering.SyntaxGraph{Dict{Symbol, Any}}, srcref::JuliaLowering.SyntaxTree{JuliaLowering.SyntaxGraph{Dict{Symbol, Any}}}, proto::JuliaSyntax.Kind, children::Vector{Int64}; attrs::@Kwargs{})
        @ JuliaLowering ~/.julia/packages/JuliaLowering/A0OwQ/src/ast.jl:123
      [7] _makenode
        @ ~/.julia/packages/JuliaLowering/A0OwQ/src/ast.jl:122 [inlined]
      [8] #_makenode#28
        @ ~/.julia/packages/JuliaLowering/A0OwQ/src/ast.jl:131 [inlined]
      [9] _makenode
        @ ~/.julia/packages/JuliaLowering/A0OwQ/src/ast.jl:130 [inlined]
     [10] macro expansion
        @ ~/.julia/packages/JuliaLowering/A0OwQ/src/ast.jl:268 [inlined]
     [11] expand_forms_2(ctx::JuliaLowering.DesugaringContext{JuliaLowering.SyntaxGraph{Dict{Symbol, Any}}}, ex::JuliaLowering.SyntaxTree{JuliaLowering.SyntaxGraph{Dict{Symbol, Any}}}, docs::Nothing)
        @ JuliaLowering ~/.julia/packages/JuliaLowering/A0OwQ/src/desugaring.jl:4516
     [12] expand_forms_2
        @ ~/.julia/packages/JuliaLowering/A0OwQ/src/desugaring.jl:4348 [inlined]
     [13] expand_forms_2
        @ ~/.julia/packages/JuliaLowering/A0OwQ/src/desugaring.jl:4602 [inlined]
     [14] lower(mod::Module, ex0::JuliaLowering.SyntaxTree{JuliaLowering.SyntaxGraph{Dict{Symbol, Any}}}; expr_compat_mode::Bool, world::UInt64)
        @ JuliaLowering ~/.julia/packages/JuliaLowering/A0OwQ/src/eval.jl:3
     [15] lower(mod::Module, ex0::JuliaLowering.SyntaxTree{JuliaLowering.SyntaxGraph{Dict{Symbol, Any}}})
        @ JuliaLowering ~/.julia/packages/JuliaLowering/A0OwQ/src/eval.jl:1
     [16] (::var"#8#9")()
        @ Main ~/julia/packages/JETLS/jlconcurrent.jl:24

...and 2 more exceptions.

Stacktrace:
 [1] _wait_multiple(waiting_tasks::Vector{Task}, throwexc::Bool, all::Bool, failfast::Bool)
   @ Base ./task.jl:503
 [2] waitall(tasks::Vector{Task})
   @ Base ./task.jl:404
 [3] top-level scope
   @ ~/julia/packages/JETLS/jlconcurrent.jl:0
 [4] macro expansion
   @ threadingconstructs.jl:523 [inlined]
 [5] include(mod::Module, _path::String)
   @ Base ./Base.jl:305
 [6] exec_options(opts::Base.JLOptions)
   @ Base ./client.jl:321
 [7] _start()
   @ Base ./client.jl:554
in expression starting at ~/julia/packages/JETLS/jlconcurrent.jl:23

The tricky part here is that JL.lower(Main, st0_top[1]) modifies the state of st0_top itself, leading to classic race conditions if other children are being lowered concurrently.

I think st0_top[1] sharing internal state with st0_top is the right choice for a serial computing environment, so for parallel lowering, we probably need to make copies.
I've confirmed that this issue can be avoided by calling JL.lower on deepcopy(st0_top[i]) instead of st0_top[i].
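A minimal sketch of that workaround, assuming a condensed version of the jlparse helper from jlconcurrent.jl and a shorter two-statement input chosen only for illustration (the input string and the names `tasks` and `lowered` are not from the original script):

```julia
using JuliaLowering: JuliaLowering as JL
using JuliaSyntax: JuliaSyntax as JS

# Condensed version of the jlparse helper from jlconcurrent.jl.
function jlparse(s::AbstractString; rule::Symbol=:all,
                 filename::AbstractString="none", first_line::Int=1)
    stream = JS.ParseStream(s)
    JS.parse!(stream; rule)
    return JS.build_tree(JL.SyntaxTree, stream; filename, first_line)
end

# Illustrative input with two top-level statements (an assumption, not the JETLS text).
st0_top = jlparse("x = 1\ny = x + 1\n")

# Each task lowers its own deep copy, so lowering never mutates st0_top itself.
tasks = map(1:2) do i
    Threads.@spawn JL.lower(Main, deepcopy(st0_top[i]))
end
lowered = fetch.(tasks)
```

Run it with `julia --threads=auto` so the tasks can actually execute in parallel.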
However, deepcopy(st0_top[i]) is quite expensive: it takes about as long as JL.lower(Main, st0_top[i]) itself (around 20ms).
What's more, deepcopy(st0_top) seems to be even more expensive than creating st0_top from scratch:

julia> s = read("src/diagnostics.jl", String);

julia> @benchmark jlparse($s)
BenchmarkTools.Trial: 2868 samples with 1 evaluation per sample.
 Range (min … max):  1.503 ms … 9.615 ms  ┊ GC (min … max): 0.00% … 82.20%
 Time  (median):     1.567 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.742 ms ± 756.634 μs  ┊ GC (mean ± σ):  9.16% ± 14.12%

  █▇▄▂
  ████▇▃▅ ▃         ▃▃       ▃ ▃▄▃▃▃▄▅▄ ▅▅▅▄▅▅▆▅▅▅▄▅▄▄▆▆▅▆▅▆▅ █
  1.5 ms       Histogram: log(frequency) by time       5.4 ms <

 Memory estimate: 2.24 MiB, allocs estimate: 36298.

julia> st0_top = jlparse(s);

julia> @benchmark deepcopy($st0_top)
BenchmarkTools.Trial: 258 samples with 1 evaluation per sample.
 Range (min … max):  18.472 ms … 29.972 ms  ┊ GC (min … max): 0.00% … 36.63%
 Time  (median):     19.224 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   19.436 ms ±  1.392 ms  ┊ GC (mean ± σ):  1.18% ±  4.83%

   ▂ ▇▃▇▇█
  ▇████████▇█▃▄ ▂                                    ▂  ▃ ▂ ▂ ▃
  18.5 ms         Histogram: frequency by time        26.1 ms <

 Memory estimate: 3.09 MiB, allocs estimate: 67931.

So, counter-intuitive as it may seem, my understanding is that the most efficient way to do concurrent lowering right now is to parse the same source::AbstractString separately in each parallel task and then apply JL.lower to that task's own tree.
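As a rough sketch of the parse-per-task approach, again assuming a condensed jlparse helper and an illustrative two-statement source (the names `src`, `tasks`, and `lowered` are made up for this example):

```julia
using JuliaLowering: JuliaLowering as JL
using JuliaSyntax: JuliaSyntax as JS

# Condensed version of the jlparse helper from jlconcurrent.jl.
function jlparse(s::AbstractString; rule::Symbol=:all,
                 filename::AbstractString="none", first_line::Int=1)
    stream = JS.ParseStream(s)
    JS.parse!(stream; rule)
    return JS.build_tree(JL.SyntaxTree, stream; filename, first_line)
end

# Illustrative source text (an assumption, not the JETLS file).
src = "x = 1\ny = x + 1\n"

tasks = map(1:2) do i
    Threads.@spawn begin
        # Every task builds a private tree, so no state is shared across tasks.
        st0_top = jlparse(src)
        JL.lower(Main, st0_top[i])
    end
end
lowered = fetch.(tasks)
```

The trade-off is that every task pays the full parse cost and holds its own tree in memory, which is where the memory-pressure concern comes from.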

Of course, this level of latency might not be an issue in practice, and it might be fine to create a SyntaxTree from scratch for each message handling. Or, introducing a thread-safe cache for lowered syntax trees in the LS could further improve performance.
However, considering memory pressure, especially in a parallel environment, I'd like to know if there's a more efficient way to copy a SyntaxTree.
Could we add an API to create an isolated copy of a SyntaxTree? Ideally this would be cheaper than building one from scratch, for example a copy(::SyntaxTree) implementation that uses less memory than jlparse(s).
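For the thread-safe cache idea mentioned above, one possible shape is a lock-protected dictionary. This is only a sketch: `LoweredCache` and `cached!` are hypothetical names, not part of JETLS or JuliaLowering, and `length(src)` stands in for the real work of parsing `src` and calling JL.lower.

```julia
# Hypothetical cache type; not part of JETLS or JuliaLowering.
struct LoweredCache{K,V}
    lck::ReentrantLock
    entries::Dict{K,V}
end
LoweredCache{K,V}() where {K,V} = LoweredCache{K,V}(ReentrantLock(), Dict{K,V}())

# Return the cached value for `key`, calling `compute()` on a miss.
# The lock is held only around dictionary access, so misses on different keys
# compute in parallel. Two tasks missing on the same key may both compute;
# the value stored first wins.
function cached!(compute, cache::LoweredCache, key)
    hit = lock(() -> get(cache.entries, key, nothing), cache.lck)
    hit === nothing || return hit
    value = compute()
    return lock(() -> get!(cache.entries, key, value), cache.lck)
end

cache = LoweredCache{String,Any}()
sources = ["a = 1", "b = 22", "a = 1"]

# `length(src)` is a placeholder for parsing `src` and lowering the result.
tasks = map(sources) do src
    Threads.@spawn cached!(() -> length(src), cache, src)
end
results = fetch.(tasks)
```

Keying on the source text keeps the sketch simple; a language server would more likely key on something like a document URI plus version, which is an assumption here rather than something JETLS is known to do.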
