Skip to content

Commit 8fe78ec

Browse files
committed
Respect generated function invariants
Generated functions cannot call any function which might have a method added after the generated function is defined. This means we need to hoist such functions outside the generator, either by calling them beforehand (as we do with axistrait), or ensuring that all methods are defined before the generator (axisdim). This fixes #161 (test failures with precompiled-modules=no).
1 parent 48ec735 commit 8fe78ec

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

src/core.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,10 @@ Given an AxisArray and an Axis, return the integer dimension of
271271
the Axis within the array.
272272
"""
273273
axisdim(A::AxisArray, ax::Axis) = axisdim(A, typeof(ax))
274-
@generated function axisdim(A::AxisArray, ax::Type{Ax}) where Ax<:Axis
275-
dim = axisdim(A, Ax)
276-
:($dim)
277-
end
274+
axisdim(A::AxisArray, ax::Type{Ax}) where Ax<:Axis = axisdim(typeof(A), Ax)
278275
# The actual computation is done in the type domain, which is a little tricky
279276
# due to type invariance.
280-
function axisdim(::Type{AxisArray{T,N,D,Ax}}, ::Type{<:Axis{name,S} where S}) where {T,N,D,Ax,name}
277+
@generated function axisdim(::Type{AxisArray{T,N,D,Ax}}, ::Type{<:Axis{name,S} where S}) where {T,N,D,Ax,name}
281278
isa(name, Int) && return name <= N ? name : error("axis $name greater than array dimensionality $N")
282279
names = axisnames(Ax)
283280
idx = Compat.findfirst(isequal(name), names)

src/indexing.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,28 @@ function axisindexes(::Type{Categorical}, ax::AbstractVector, idx::AbstractVecto
303303
res
304304
end
305305

306+
# Creates *instances* of axis traits for a set of axes.
307+
# TODO: Transition axistrait() to return trait instances in line with common
308+
# practice in Base and other packages.
309+
#
310+
# This function is a utility tool to ensure that `axistrait` is only called
311+
# from outside the generated function below. (If not, we can get world age
312+
# errors.)
313+
_axistraits(ax1, rest...) = (axistrait(ax1)(), _axistraits(rest...)...)
314+
_axistraits() = ()
315+
306316
# This catch-all method attempts to convert any axis-specific non-standard
307317
# indexing types to their integer or integer range equivalents using axisindexes
308318
# It is separate from the `Base.getindex` function to allow reuse between
309319
# set- and get- index.
310-
@generated function to_index(A::AxisArray{T,N,D,Ax}, I...) where {T,N,D,Ax}
320+
to_index(A::AxisArray, I...) = _to_index(A, _axistraits(I...), I...)
321+
322+
@generated function _to_index(A::AxisArray{T,N,D,Ax}, axtraits, I...) where {T,N,D,Ax}
311323
ex = Expr(:tuple)
312324
n = 0
325+
axtrait_types = axtraits.parameters
313326
for i=1:length(I)
314-
if axistrait(I[i]) <: Categorical && i <= length(Ax.parameters)
327+
if axtrait_types[i] <: Categorical && i <= length(Ax.parameters)
315328
if I[i] <: Axis
316329
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
317330
else

test/indexing.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ B = AxisArray(reshape(1:15, 5,3), .1:.1:0.5, [:a, :b, :c])
7070
@test @view(B[ClosedInterval(0.2, 0.6), :]) == @view(B[ClosedInterval(0.2, 0.6)]) == B[2:end,:]
7171

7272
# Test Categorical indexing
73-
@test B[:, :a] == @view(B[:, :a]) == B[:,1]
74-
@test B[:, :c] == @view(B[:, :c]) == B[:,3]
75-
@test B[:, [:a]] == @view(B[:, [:a]]) == B[:,[1]]
76-
@test B[:, [:c]] == @view(B[:, [:c]]) == B[:,[3]]
77-
@test B[:, [:a,:c]] == @view(B[:, [:a,:c]]) == B[:,[1,3]]
73+
@test @inferred(B[:, :a]) == @view(B[:, :a]) == B[:,1]
74+
@test @inferred(B[:, :c]) == @view(B[:, :c]) == B[:,3]
75+
@test @inferred(B[:, [:a]]) == @view(B[:, [:a]]) == B[:,[1]]
76+
@test @inferred(B[:, [:c]]) == @view(B[:, [:c]]) == B[:,[3]]
77+
@test @inferred(B[:, [:a,:c]]) == @view(B[:, [:a,:c]]) == B[:,[1,3]]
7878

7979
@test B[Axis{:row}(ClosedInterval(0.15, 0.3))] == @view(B[Axis{:row}(ClosedInterval(0.15, 0.3))]) == B[2:3,:]
8080

0 commit comments

Comments
 (0)