Skip to content

Commit 2f81a7a

Browse files
authored
Merge pull request #163 from JuliaArrays/cjf/generated-function-fixes
Respect generated function invariants
2 parents a5556f1 + 8fe78ec commit 2f81a7a

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

src/core.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,10 @@ Given an AxisArray and an Axis, return the integer dimension of
274274
the Axis within the array.
275275
"""
276276
axisdim(A::AxisArray, ax::Axis) = axisdim(A, typeof(ax))
277-
@generated function axisdim(A::AxisArray, ax::Type{Ax}) where Ax<:Axis
278-
dim = axisdim(A, Ax)
279-
:($dim)
280-
end
277+
axisdim(A::AxisArray, ax::Type{Ax}) where Ax<:Axis = axisdim(typeof(A), Ax)
281278
# The actual computation is done in the type domain, which is a little tricky
282279
# due to type invariance.
283-
function axisdim(::Type{AxisArray{T,N,D,Ax}}, ::Type{<:Axis{name,S} where S}) where {T,N,D,Ax,name}
280+
@generated function axisdim(::Type{AxisArray{T,N,D,Ax}}, ::Type{<:Axis{name,S} where S}) where {T,N,D,Ax,name}
284281
isa(name, Int) && return name <= N ? name : error("axis $name greater than array dimensionality $N")
285282
names = axisnames(Ax)
286283
idx = findfirst(isequal(name), names)

src/indexing.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,28 @@ function axisindexes(::Type{Categorical}, ax::AbstractVector, idx::AbstractVecto
303303
res
304304
end
305305

306+
# Creates *instances* of axis traits for a set of axes.
307+
# TODO: Transition axistrait() to return trait instances in line with common
308+
# practice in Base and other packages.
309+
#
310+
# This function is a utility tool to ensure that `axistrait` is only called
311+
# from outside the generated function below. (If not, we can get world age
312+
# errors.)
313+
_axistraits(ax1, rest...) = (axistrait(ax1)(), _axistraits(rest...)...)
314+
_axistraits() = ()
315+
306316
# This catch-all method attempts to convert any axis-specific non-standard
307317
# indexing types to their integer or integer range equivalents using axisindexes
308318
# It is separate from the `Base.getindex` function to allow reuse between
309319
# set- and get- index.
310-
@generated function to_index(A::AxisArray{T,N,D,Ax}, I...) where {T,N,D,Ax}
320+
to_index(A::AxisArray, I...) = _to_index(A, _axistraits(I...), I...)
321+
322+
@generated function _to_index(A::AxisArray{T,N,D,Ax}, axtraits, I...) where {T,N,D,Ax}
311323
ex = Expr(:tuple)
312324
n = 0
325+
axtrait_types = axtraits.parameters
313326
for i=1:length(I)
314-
if axistrait(I[i]) <: Categorical && i <= length(Ax.parameters)
327+
if axtrait_types[i] <: Categorical && i <= length(Ax.parameters)
315328
if I[i] <: Axis
316329
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
317330
else

test/indexing.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ B = AxisArray(reshape(1:15, 5,3), .1:.1:0.5, [:a, :b, :c])
7070
@test @view(B[ClosedInterval(0.2, 0.6), :]) == @view(B[ClosedInterval(0.2, 0.6)]) == B[2:end,:]
7171

7272
# Test Categorical indexing
73-
@test B[:, :a] == @view(B[:, :a]) == B[:,1]
74-
@test B[:, :c] == @view(B[:, :c]) == B[:,3]
75-
@test B[:, [:a]] == @view(B[:, [:a]]) == B[:,[1]]
76-
@test B[:, [:c]] == @view(B[:, [:c]]) == B[:,[3]]
77-
@test B[:, [:a,:c]] == @view(B[:, [:a,:c]]) == B[:,[1,3]]
73+
@test @inferred(B[:, :a]) == @view(B[:, :a]) == B[:,1]
74+
@test @inferred(B[:, :c]) == @view(B[:, :c]) == B[:,3]
75+
@test @inferred(B[:, [:a]]) == @view(B[:, [:a]]) == B[:,[1]]
76+
@test @inferred(B[:, [:c]]) == @view(B[:, [:c]]) == B[:,[3]]
77+
@test @inferred(B[:, [:a,:c]]) == @view(B[:, [:a,:c]]) == B[:,[1,3]]
7878

7979
@test B[Axis{:row}(ClosedInterval(0.15, 0.3))] == @view(B[Axis{:row}(ClosedInterval(0.15, 0.3))]) == B[2:3,:]
8080

0 commit comments

Comments
 (0)