@@ -376,6 +376,16 @@ number of `sampler`.
376
376
"""
377
377
(model:: Model )(args... ) = first (evaluate!! (model, args... ))
378
378
379
+ """
380
+ use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
381
+
382
+ Return `true` if evaluation of a model using `context` and `varinfo` should
383
+ wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
384
+ """
385
+ function use_threadsafe_eval (context:: AbstractContext , varinfo:: AbstractVarInfo )
386
+ return Threads. nthreads () > 1
387
+ end
388
+
379
389
"""
380
390
evaluate!!(model::Model[, rng, varinfo, sampler, context])
381
391
@@ -388,10 +398,10 @@ The method resets the log joint probability of `varinfo` and increases the evalu
388
398
number of `sampler`.
389
399
"""
390
400
function evaluate!! (model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext )
391
- if Threads . nthreads () == 1
392
- return evaluate_threadunsafe !! (model, varinfo, context)
401
+ return if use_threadsafe_eval (context, varinfo)
402
+ evaluate_threadsafe !! (model, varinfo, context)
393
403
else
394
- return evaluate_threadsafe !! (model, varinfo, context)
404
+ evaluate_threadunsafe !! (model, varinfo, context)
395
405
end
396
406
end
397
407
0 commit comments