Skip to content

Commit bdbaf32

Browse files
torfjeldegithub-actions[bot]yebai
authored
Use view whenever possible (#272)
* use views whenever possible * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * dont view literals * fixed the failing tests * added a bunch of get_sections to tests to avoid unnecessary warnings * formatting * added comment to describe maybe_unwrap_view Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]>
1 parent 222091e commit bdbaf32

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

src/compiler.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function isassumption(expr::Union{Symbol,Expr})
2727
true
2828
else
2929
# Evaluate the LHS
30-
$expr === missing
30+
$(maybe_view(expr)) === missing
3131
end
3232
end
3333
end
@@ -36,6 +36,16 @@ end
3636
# failsafe: a literal is never an assumption
3737
isassumption(expr) = :(false)
3838

39+
# If we're working with, say, a `Symbol`, then we're not going to `view`.
40+
maybe_view(x) = x
41+
maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x)))
42+
43+
# If the result of a `view` is a zero-dim array then it's just a
44+
# single element. Likely the rest is expecting type `eltype(x)`, hence
45+
# we extract the value rather than passing the array.
46+
maybe_unwrap_view(x) = x
47+
maybe_unwrap_view(x::SubArray{<:Any,0}) = x[1]
48+
3949
"""
4050
isliteral(expr)
4151
@@ -325,7 +335,7 @@ function generate_tilde(left, right)
325335
$(DynamicPPL.tilde_observe!)(
326336
__context__,
327337
$(DynamicPPL.check_tilde_rhs)($right),
328-
$left,
338+
$(maybe_view(left)),
329339
$vn,
330340
$inds,
331341
__varinfo__,
@@ -360,7 +370,7 @@ function generate_dot_tilde(left, right)
360370
$left .= $(DynamicPPL.dot_tilde_assume!)(
361371
__context__,
362372
$(DynamicPPL.unwrap_right_left_vns)(
363-
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
373+
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
364374
)...,
365375
$inds,
366376
__varinfo__,
@@ -369,7 +379,7 @@ function generate_dot_tilde(left, right)
369379
$(DynamicPPL.dot_tilde_observe!)(
370380
__context__,
371381
$(DynamicPPL.check_tilde_rhs)($right),
372-
$left,
382+
$(maybe_view(left)),
373383
$vn,
374384
$inds,
375385
__varinfo__,

0 commit comments

Comments
 (0)