Skip to content

Commit 9bd1dee

Browse files
committed
fix deadlock
1 parent 1a23c2e commit 9bd1dee

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

src/Types.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ function ConcretePJRTArray(
215215
end
216216

217217
Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
218+
Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data)
218219
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
219220
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
220221
x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data))
@@ -412,6 +413,7 @@ function ConcreteIFRTArray(
412413
end
413414

414415
Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data)
416+
Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data)
415417
XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data)
416418
function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber})
417419
return XLA.device(x.data)

src/probprog/Modeling.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,13 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where
199199

200200
compiled_fn = @compile optimize = :probprog simulate_internal(rng, f, args...)
201201

202-
old_gc_state = GC.enable(false)
203-
try
202+
seed_buffer = only(rng.seed.data).buffer
203+
GC.@preserve seed_buffer begin
204204
trace, _, _ = compiled_fn(rng, f, args...)
205-
finally
206-
GC.enable(old_gc_state)
205+
206+
while !isready(trace)
207+
yield()
208+
end
207209
end
208210

209211
trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1]))
@@ -278,11 +280,13 @@ function generate(
278280

279281
compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...)
280282

281-
old_gc_state = GC.enable(false)
282-
try
283+
seed_buffer = only(rng.seed.data).buffer
284+
GC.@preserve seed_buffer constraint begin
283285
trace, _, _ = compiled_fn(rng, constraint_ptr, args...)
284-
finally
285-
GC.enable(old_gc_state)
286+
287+
while !isready(trace)
288+
yield()
289+
end
286290
end
287291

288292
trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1]))

0 commit comments

Comments
 (0)