Skip to content

Commit 8749131

Browse files
committed
Refactor and test inspace
1 parent d5e9866 commit 8749131

File tree

3 files changed

+61
-66
lines changed

3 files changed

+61
-66
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export AbstractVarInfo,
4444
#VarName
4545
VarName,
4646
inspace,
47+
subsumes,
4748
# Compiler
4849
ModelGen,
4950
@model,

src/varname.jl

Lines changed: 45 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -73,64 +73,59 @@ Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol
7373

7474

7575
"""
76-
inspace(vn::Union{VarName, Symbol, Expr}, space::Tuple)
76+
inspace(vn::Union{VarName, Symbol}, space::Tuple)
7777
7878
Check whether `vn`'s variable symbol is in `space`.
7979
"""
80-
inspace(::VarName, ::Tuple{}) = true
81-
inspace(vn::VarName, space::Tuple)::Bool = getsym(vn) in space || _in(string(vn), space)
80+
inspace(vn, space::Tuple{}) = true # empty space is treated as universal set
8281
inspace(vn, space::Tuple) = vn in space
82+
inspace(vn::VarName, space::Tuple{}) = true
83+
inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space)
8384

84-
_in(::String, ::Tuple{}) = false
85-
_in(vn_str::String, space::Tuple)::Bool = _in(vn_str, Base.tail(space))
86-
function _in(vn_str::String, space::Tuple{Expr,Vararg})::Bool
87-
# Collect expressions from space
88-
expr = first(space)
89-
# Filter `(` and `)` out and get a string representation of `exprs`
90-
expr_str = replace(string(expr), r"\(|\)" => "")
91-
# Check if `vn_str` is in `expr_strs`
92-
valid = occursin(expr_str, vn_str)
93-
return valid || _in(vn_str, Base.tail(space))
85+
_in(vn::VarName, s::Symbol) = getsym(vn) == s
86+
_in(vn::VarName, s::VarName) = subsumes(s, vn)
87+
88+
89+
"""
90+
subsumes(u::VarName, v::VarName)
91+
92+
Check whether the variable name `v` describes a sub-range of the variable `u`. Supported
93+
indexing:
94+
95+
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.
96+
- Array of scalar: `x[[1, 2], 3]` subsumes `x[1, 3]`, `x[1:3]` subsumes `x[2][1]`, etc.
97+
(basically everything that fulfills `issubset`).
98+
- Slices: `x[2, :]` subsumes `x[2, 10][1]`, etc.
99+
100+
Currently _not_ supported are:
101+
102+
- Boolean indexing, literal `CartesianIndex` (these could be added, though)
103+
- Linear indexing of multidimensional arrays: `x[4]` does not subsume `x[2, 2]` for `x` a matrix
104+
- Trailing ones: `x[2, 1]` does not subsume `x[2]` for `x` a vector
105+
"""
106+
function subsumes(u::VarName, v::VarName)
107+
return getsym(u) == getsym(v) && subsumes(u.indexing, v.indexing)
108+
end
109+
110+
subsumes(::Tuple{}, ::Tuple{}) = true # x subsumes x
111+
subsumes(::Tuple{}, ::Tuple) = true # x subsumes x[1]
112+
subsumes(::Tuple, ::Tuple{}) = false # x[1] does not subsume x
113+
function subsumes(t::Tuple, u::Tuple) # does x[i]... subsume x[j]...?
114+
return _issubindex(first(t), first(u)) && subsumes(Base.tail(t), Base.tail(u))
115+
end
116+
117+
const AnyIndex = Union{Int, AbstractVector{Int}, Colon}
118+
_issubindex_(::Tuple{Vararg{AnyIndex}}, ::Tuple{Vararg{AnyIndex}}) = false
119+
function _issubindex(t::NTuple{N, AnyIndex}, u::NTuple{N, AnyIndex}) where {N}
120+
return all(_issubrange(j, i) for (i, j) in zip(t, u))
94121
end
95122

123+
const ConcreteIndex = Union{Int, AbstractVector{Int}} # this include all kinds of ranges
124+
"""Determine whether indices `i` are contained in `j`, treating `:` as universal set."""
125+
_issubrange(i::ConcreteIndex, j::ConcreteIndex) = issubset(i, j)
126+
_issubrange(i::Union{ConcreteIndex, Colon}, j::Colon) = true
127+
_issubrange(i::Colon, j::ConcreteIndex) = true
96128

97-
# inspace(::Union{VarName, Symbol, Expr}, ::Tuple{}) = true
98-
# inspace(vn::Union{VarName, Symbol, Expr}, space::Tuple) = any(_ismatch(vn, s) for s in space)
99-
100-
# _ismatch(vn, s) = (_name(vn) == _name(s)) && _isprefix(_indexing(s), _indexing(vn))
101-
102-
# _isprefix(::Tuple{}, ::Tuple{}) = true
103-
# _isprefix(::Tuple{}, ::Tuple) = true
104-
# _isprefix(::Tuple, ::Tuple{}) = false
105-
# _isprefix(t::Tuple, u::Tuple) = _subsumes(first(t), first(u)) && _isprefix(Base.tail(t), Base.tail(u))
106-
107-
# const ConcreteIndex = Union{Int, AbstractVector{Int}} # this include all kinds of ranges
108-
# """Determine whether `i` is a valid if `j` is."""
109-
# _subsumes(i::ConcreteIndex, j::ConcreteIndex) = issubset(i, j)
110-
# _subsumes(i::Union{ConcreteIndex, Colon}, j::Colon) = true
111-
# _subsumes(i::Colon, j::ConcreteIndex) = false
112-
113-
# _name(vn::Symbol) = vn
114-
# _name(vn::VarName) = getsym(vn)
115-
# function _name(vn::Expr)
116-
# if Meta.isexpr(vn, :ref)
117-
# _name(vn.args[1])
118-
# else
119-
# throw("VarName: Mis-formed variable name $(vn)!")
120-
# end
121-
# end
122-
123-
# _indexing(vn::Symbol) = ()
124-
# _indexing(vn::VarName) = getindexing(vn)
125-
# function _indexing(vn::Expr)
126-
# if Meta.isexpr(vn, :ref)
127-
# init = _indexing(vn.args[1])
128-
# last = Tuple(vn.args[2:end])
129-
# return (init..., last)
130-
# else
131-
# throw("VarName: Mis-formed variable name $(vn)!")
132-
# end
133-
# end
134129

135130

136131
"""

test/varinfo.jl

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -95,23 +95,22 @@ include(dir*"/test/test_utils/AllUtils.jl")
9595
@test isempty(vi)
9696
push!(vi, vn, r, dist, gid)
9797

98-
function test_in()
99-
space = (:x, :y, :(z[1]))
100-
vn1 = @varname x
101-
vn2 = @varname y
102-
vn3 = @varname x[1]
103-
vn4 = @varname z[1][1]
104-
vn5 = @varname z[2]
105-
vn6 = @varname z
106-
107-
@test inspace(vn1, space)
108-
@test inspace(vn2, space)
109-
@test inspace(vn3, space)
110-
@test inspace(vn4, space)
111-
@test ~inspace(vn5, space)
112-
@test ~inspace(vn6, space)
98+
function test_inspace()
99+
space = (:x, :y, @varname(z[1]), @varname(M[1:10, :]))
100+
101+
@test inspace(@varname(x), space)
102+
@test inspace(@varname(y), space)
103+
@test inspace(@varname(x[1]), space)
104+
@test inspace(@varname(z[1][1]), space)
105+
@test inspace(@varname(z[1][:]), space)
106+
@test inspace(@varname(z[1][2:3:10]), space)
107+
@test inspace(@varname(M[[2,3], 1]), space)
108+
@test inspace(@varname(M[:, 1:4]), space)
109+
@test inspace(@varname(M[1, [2, 4, 6]]), space)
110+
@test !inspace(@varname(z[2]), space)
111+
@test !inspace(@varname(z), space)
113112
end
114-
test_in()
113+
test_inspace()
115114
end
116115
vi = VarInfo()
117116
test_base!(vi)

0 commit comments

Comments
 (0)