diff --git a/Project.toml b/Project.toml index 9127905..cd5c2ce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TypeParameterAccessors" uuid = "7e5a90cf-f82e-492e-a09b-e3e26432c138" authors = ["ITensor developers and contributors"] -version = "0.3.1" +version = "0.3.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/base/abstractarray.jl b/src/base/abstractarray.jl index 0316e3c..a5a2f91 100644 --- a/src/base/abstractarray.jl +++ b/src/base/abstractarray.jl @@ -1,3 +1,10 @@ +struct Self end +position(a, ::Self) = Position(0) +position(::Type, ::Self) = Position(0) +function set_type_parameters(type::Type, ::Self, param) + return error("Can't set the parent type of an unwrapped array type.") +end + position(::Type{AbstractArray}, ::typeof(eltype)) = Position(1) position(::Type{AbstractArray}, ::typeof(ndims)) = Position(2) default_type_parameters(::Type{AbstractArray}) = (Float64, 1) @@ -9,14 +16,6 @@ default_type_parameters(::Type{<:Array}) = (Float64, 1) position(::Type{<:BitArray}, ::typeof(ndims)) = Position(1) default_type_parameters(::Type{<:BitArray}) = (1,) -struct Self end -position(a, ::Self) = Position(0) -position(::Type{T}, ::Self) where {T} = Position(0) - -function set_type_parameters(type::Type, ::Self, param) - return error("Can't set the parent type of an unwrapped array type.") -end - function set_eltype(array::AbstractArray, param) return convert(set_eltype(typeof(array), param), array) end diff --git a/src/type_parameters.jl b/src/type_parameters.jl index 31e6ca0..b0bc681 100644 --- a/src/type_parameters.jl +++ b/src/type_parameters.jl @@ -27,10 +27,62 @@ function position end position(object, name) = position(typeof(object), name) position(::Type, pos::Int) = Position(pos) position(::Type, pos::Position) = pos + function position(type::Type, name) - base_type = unspecify_type_parameters(type) - base_type === type && error("`position` not defined for $type and $name.") - return position(base_type, name) + type′ = unspecify_type_parameters(type) + if type === type′ + # Fallback definition that determines the + # position automatically from the supertype of + # the type. + return position_from_supertype(type′, name) + end + return position(type′, name) +end + +# Automatically determine the position of a type parameter of a type given +# a supertype and the name of the parameter. +function position_from_supertype(type::Type, name) + type′ = unspecify_type_parameters(type) + supertype_pos = position(supertype(type′), name) + return position_from_supertype_position(type′, supertype_pos) +end + +# Automatically determine the position of a type parameter of a type given +# the supertype and the position of the corresponding parameter in the supertype. +@generated function position_from_supertype_position( + ::Type{T}, supertype_pos::Position +) where {T} + T′ = unspecify_type_parameters(T) + # The type parameters of the type as TypeVars. + # TODO: Ideally we would use `get_type_parameters` + # but that sometimes loses TypeVar names: + # https://github.com/ITensor/TypeParameterAccessors.jl/issues/30 + type_params = Base.unwrap_unionall(T′).parameters + # The type parameters of the immediate supertype as TypeVars. + # This has TypeVars with names that correspond to the names of + # the TypeVars of the type parameters of `T`, for example: + # ```julia + # julia> struct MyArray{B,A} <: AbstractArray{A,B} end + # + # julia> Base.unwrap_unionall(MyArray).parameters + # svec(B, A) + # + # julia> Base.unwrap_unionall(supertype(MyArray)).parameters + # svec(A, B) + # ``` + supertype_params = Base.unwrap_unionall(supertype(T)).parameters + supertype_param = supertype_params[Int(supertype_pos)] + pos = findfirst(param -> (param.name == supertype_param.name), type_params) + if isnothing(pos) + return error("Position not found.") + end + return :(@inline; $(Position(pos))) +end + +# Automatically determine the position of a type parameter of a type given +# a supertype and the name of the parameter. +function position_from_supertype(type::Type, supertype_target::Type, name) + return position_from_supertype(type, supertype_target, position(supertype_target, name)) end function positions(::Type{T}, pos::Tuple) where {T} diff --git a/test/test_basics.jl b/test/test_basics.jl index 29e5b7f..a668232 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -2,7 +2,12 @@ using JLArrays: JLArray, JLMatrix, JLVector using Test: @test, @test_throws, @test_broken, @testset using TestExtras: @constinferred using TypeParameterAccessors: - set_type_parameters, specify_type_parameters, type_parameters, unspecify_type_parameters + TypeParameterAccessors, + Position, + set_type_parameters, + specify_type_parameters, + type_parameters, + unspecify_type_parameters const anyarrayts = ( (arrayt=Array, matrixt=Matrix, vectort=Vector), @@ -78,3 +83,30 @@ const anyarrayts = ( (3, Float32) end end + +@testset "Automatic fallback for position" begin + struct MyArray{B,A} <: AbstractArray{A,B} end + @test @constinferred(TypeParameterAccessors.position(MyArray, eltype)) == Position(2) + @test @constinferred(TypeParameterAccessors.position(MyArray{3,Float32}, eltype)) == + Position(2) + @test @constinferred(TypeParameterAccessors.position(MyArray, ndims)) == Position(1) + @test @constinferred(TypeParameterAccessors.position(MyArray{3,Float32}, ndims)) == + Position(1) + + struct MyVector{X,Y,A<:Real} <: AbstractArray{A,1} end + @test @constinferred(TypeParameterAccessors.position(MyVector, eltype)) == Position(3) + @test @constinferred( + TypeParameterAccessors.position(MyVector{Int,(1, 2),Float32}, eltype) + ) == Position(3) + @test_throws ErrorException TypeParameterAccessors.position(MyVector, ndims) + @test_throws ErrorException TypeParameterAccessors.position( + MyVector{Int,(1, 2),Float32}, ndims + ) + + struct MyBoolArray{X,Y,Z,B} <: AbstractArray{Bool,B} end + @test_throws ErrorException TypeParameterAccessors.position(MyBoolArray, eltype) + @test_throws ErrorException TypeParameterAccessors.position(MyBoolArray{1,2,3,4}, eltype) + @test @constinferred(TypeParameterAccessors.position(MyBoolArray, ndims)) == Position(4) + @test @constinferred(TypeParameterAccessors.position(MyBoolArray{1,2,3,4}, ndims)) == + Position(4) +end