Skip to content

Commit 2402cc8

Browse files
authored
Determine eltype and ndims type parameter positions automatically (#31)
1 parent df3040a commit 2402cc8

File tree

4 files changed

+96
-13
lines changed

4 files changed

+96
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TypeParameterAccessors"
22
uuid = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/base/abstractarray.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
struct Self end
2+
position(a, ::Self) = Position(0)
3+
position(::Type, ::Self) = Position(0)
4+
function set_type_parameters(type::Type, ::Self, param)
5+
return error("Can't set the parent type of an unwrapped array type.")
6+
end
7+
18
position(::Type{AbstractArray}, ::typeof(eltype)) = Position(1)
29
position(::Type{AbstractArray}, ::typeof(ndims)) = Position(2)
310
default_type_parameters(::Type{AbstractArray}) = (Float64, 1)
@@ -9,14 +16,6 @@ default_type_parameters(::Type{<:Array}) = (Float64, 1)
916
position(::Type{<:BitArray}, ::typeof(ndims)) = Position(1)
1017
default_type_parameters(::Type{<:BitArray}) = (1,)
1118

12-
struct Self end
13-
position(a, ::Self) = Position(0)
14-
position(::Type{T}, ::Self) where {T} = Position(0)
15-
16-
function set_type_parameters(type::Type, ::Self, param)
17-
return error("Can't set the parent type of an unwrapped array type.")
18-
end
19-
2019
function set_eltype(array::AbstractArray, param)
2120
return convert(set_eltype(typeof(array), param), array)
2221
end

src/type_parameters.jl

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,62 @@ function position end
2727
position(object, name) = position(typeof(object), name)
2828
position(::Type, pos::Int) = Position(pos)
2929
position(::Type, pos::Position) = pos
30+
3031
function position(type::Type, name)
31-
base_type = unspecify_type_parameters(type)
32-
base_type === type && error("`position` not defined for $type and $name.")
33-
return position(base_type, name)
32+
type′ = unspecify_type_parameters(type)
33+
if type === type′
34+
# Fallback definition that determines the
35+
# position automatically from the supertype of
36+
# the type.
37+
return position_from_supertype(type′, name)
38+
end
39+
return position(type′, name)
40+
end
41+
42+
# Automatically determine the position of a type parameter of a type given
43+
# a supertype and the name of the parameter.
44+
function position_from_supertype(type::Type, name)
45+
type′ = unspecify_type_parameters(type)
46+
supertype_pos = position(supertype(type′), name)
47+
return position_from_supertype_position(type′, supertype_pos)
48+
end
49+
50+
# Automatically determine the position of a type parameter of a type given
51+
# the supertype and the position of the corresponding parameter in the supertype.
52+
@generated function position_from_supertype_position(
53+
::Type{T}, supertype_pos::Position
54+
) where {T}
55+
T′ = unspecify_type_parameters(T)
56+
# The type parameters of the type as TypeVars.
57+
# TODO: Ideally we would use `get_type_parameters`
58+
# but that sometimes loses TypeVar names:
59+
# https://github.com/ITensor/TypeParameterAccessors.jl/issues/30
60+
type_params = Base.unwrap_unionall(T′).parameters
61+
# The type parameters of the immediate supertype as TypeVars.
62+
# This has TypeVars with names that correspond to the names of
63+
# the TypeVars of the type parameters of `T`, for example:
64+
# ```julia
65+
# julia> struct MyArray{B,A} <: AbstractArray{A,B} end
66+
#
67+
# julia> Base.unwrap_unionall(MyArray).parameters
68+
# svec(B, A)
69+
#
70+
# julia> Base.unwrap_unionall(supertype(MyArray)).parameters
71+
# svec(A, B)
72+
# ```
73+
supertype_params = Base.unwrap_unionall(supertype(T)).parameters
74+
supertype_param = supertype_params[Int(supertype_pos)]
75+
pos = findfirst(param -> (param.name == supertype_param.name), type_params)
76+
if isnothing(pos)
77+
return error("Position not found.")
78+
end
79+
return :(@inline; $(Position(pos)))
80+
end
81+
82+
# Automatically determine the position of a type parameter of a type given
83+
# a supertype and the name of the parameter.
84+
function position_from_supertype(type::Type, supertype_target::Type, name)
85+
return position_from_supertype(type, supertype_target, position(supertype_target, name))
3486
end
3587

3688
function positions(::Type{T}, pos::Tuple) where {T}

test/test_basics.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ using JLArrays: JLArray, JLMatrix, JLVector
22
using Test: @test, @test_throws, @test_broken, @testset
33
using TestExtras: @constinferred
44
using TypeParameterAccessors:
5-
set_type_parameters, specify_type_parameters, type_parameters, unspecify_type_parameters
5+
TypeParameterAccessors,
6+
Position,
7+
set_type_parameters,
8+
specify_type_parameters,
9+
type_parameters,
10+
unspecify_type_parameters
611

712
const anyarrayts = (
813
(arrayt=Array, matrixt=Matrix, vectort=Vector),
@@ -78,3 +83,30 @@ const anyarrayts = (
7883
(3, Float32)
7984
end
8085
end
86+
87+
@testset "Automatic fallback for position" begin
88+
struct MyArray{B,A} <: AbstractArray{A,B} end
89+
@test @constinferred(TypeParameterAccessors.position(MyArray, eltype)) == Position(2)
90+
@test @constinferred(TypeParameterAccessors.position(MyArray{3,Float32}, eltype)) ==
91+
Position(2)
92+
@test @constinferred(TypeParameterAccessors.position(MyArray, ndims)) == Position(1)
93+
@test @constinferred(TypeParameterAccessors.position(MyArray{3,Float32}, ndims)) ==
94+
Position(1)
95+
96+
struct MyVector{X,Y,A<:Real} <: AbstractArray{A,1} end
97+
@test @constinferred(TypeParameterAccessors.position(MyVector, eltype)) == Position(3)
98+
@test @constinferred(
99+
TypeParameterAccessors.position(MyVector{Int,(1, 2),Float32}, eltype)
100+
) == Position(3)
101+
@test_throws ErrorException TypeParameterAccessors.position(MyVector, ndims)
102+
@test_throws ErrorException TypeParameterAccessors.position(
103+
MyVector{Int,(1, 2),Float32}, ndims
104+
)
105+
106+
struct MyBoolArray{X,Y,Z,B} <: AbstractArray{Bool,B} end
107+
@test_throws ErrorException TypeParameterAccessors.position(MyBoolArray, eltype)
108+
@test_throws ErrorException TypeParameterAccessors.position(MyBoolArray{1,2,3,4}, eltype)
109+
@test @constinferred(TypeParameterAccessors.position(MyBoolArray, ndims)) == Position(4)
110+
@test @constinferred(TypeParameterAccessors.position(MyBoolArray{1,2,3,4}, ndims)) ==
111+
Position(4)
112+
end

0 commit comments

Comments
 (0)