From e630c1d0d5acba2b184413fb181ff98fe78f4e65 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 12 Feb 2025 13:26:50 -0500 Subject: [PATCH 1/4] More `default_type_parameters` definitions --- Project.toml | 2 +- ext/TypeParameterAccessorsAMDGPUExt.jl | 5 ++++- ext/TypeParameterAccessorsCUDAExt.jl | 5 ++++- ext/TypeParameterAccessorsJLArraysExt.jl | 1 + ext/TypeParameterAccessorsMetalExt.jl | 5 ++++- ext/TypeParameterAccessorsStridedViewsExt.jl | 3 +++ ext/TypeParameterAccessorsoneAPIExt.jl | 5 ++++- src/type_parameters.jl | 9 ++++++++- 8 files changed, 29 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index cd5c2ce..2009836 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TypeParameterAccessors" uuid = "7e5a90cf-f82e-492e-a09b-e3e26432c138" authors = ["ITensor developers and contributors"] -version = "0.3.2" +version = "0.3.3" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/ext/TypeParameterAccessorsAMDGPUExt.jl b/ext/TypeParameterAccessorsAMDGPUExt.jl index 2c998c7..39af86c 100644 --- a/ext/TypeParameterAccessorsAMDGPUExt.jl +++ b/ext/TypeParameterAccessorsAMDGPUExt.jl @@ -1,9 +1,12 @@ module TypeParameterAccessorsAMDGPUExt -using AMDGPU: ROCArray +using AMDGPU: AMDGPU, ROCArray using TypeParameterAccessors: TypeParameterAccessors, Position TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(eltype)) = Position(1) TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(ndims)) = Position(2) +function TypeParameterAccessors.default_type_parameters(::Type{<:ROCArray}) + return (Float64, 1, AMDGPU.Mem.HIPBuffer) +end end diff --git a/ext/TypeParameterAccessorsCUDAExt.jl b/ext/TypeParameterAccessorsCUDAExt.jl index 3c8de3e..3bc2b8c 100644 --- a/ext/TypeParameterAccessorsCUDAExt.jl +++ b/ext/TypeParameterAccessorsCUDAExt.jl @@ -1,9 +1,12 @@ module TypeParameterAccessorsCUDAExt -using CUDA: CuArray +using CUDA: CUDA, CuArray using TypeParameterAccessors: TypeParameterAccessors, Position TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(eltype)) = Position(1) TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(ndims)) = Position(2) +function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray}) + return (Float64, 1, CUDA.default_memory) +end end diff --git a/ext/TypeParameterAccessorsJLArraysExt.jl b/ext/TypeParameterAccessorsJLArraysExt.jl index 90264a5..f42c8e6 100644 --- a/ext/TypeParameterAccessorsJLArraysExt.jl +++ b/ext/TypeParameterAccessorsJLArraysExt.jl @@ -5,5 +5,6 @@ using TypeParameterAccessors: TypeParameterAccessors, Position TypeParameterAccessors.position(::Type{<:JLArray}, ::typeof(eltype)) = Position(1) TypeParameterAccessors.position(::Type{<:JLArray}, ::typeof(ndims)) = Position(2) +TypeParameterAccessors.default_type_parameters(::Type{<:JLArray}) = (Float64, 1) end diff --git a/ext/TypeParameterAccessorsMetalExt.jl b/ext/TypeParameterAccessorsMetalExt.jl index 5fbbe38..fae82f5 100644 --- a/ext/TypeParameterAccessorsMetalExt.jl +++ b/ext/TypeParameterAccessorsMetalExt.jl @@ -1,9 +1,12 @@ module TypeParameterAccessorsMetalExt -using Metal: MtlArray +using Metal: Metal, MtlArray using TypeParameterAccessors: TypeParameterAccessors, Position TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(eltype)) = Position(1) TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(ndims)) = Position(2) +function TypeParameterAccessors.default_type_parameters(::Type{<:MtlArray}) + return (Float64, 1, Metal.DefaultStorageMode) +end end diff --git a/ext/TypeParameterAccessorsStridedViewsExt.jl b/ext/TypeParameterAccessorsStridedViewsExt.jl index de8a42a..de1da3e 100644 --- a/ext/TypeParameterAccessorsStridedViewsExt.jl +++ b/ext/TypeParameterAccessorsStridedViewsExt.jl @@ -8,5 +8,8 @@ TypeParameterAccessors.position(::Type{<:StridedView}, ::typeof(ndims)) = Positi function TypeParameterAccessors.position(::Type{<:StridedView}, ::typeof(parenttype)) return Position(3) end +function TypeParameterAccessors.default_type_parameters(::Type{<:StridedView}) + return (Float64, 1, Vector{Float64}, typeof(identity)) +end end diff --git a/ext/TypeParameterAccessorsoneAPIExt.jl b/ext/TypeParameterAccessorsoneAPIExt.jl index 66b8f48..b841fa7 100644 --- a/ext/TypeParameterAccessorsoneAPIExt.jl +++ b/ext/TypeParameterAccessorsoneAPIExt.jl @@ -1,9 +1,12 @@ module TypeParameterAccessorsoneAPIExt -using oneAPI: oneArray +using oneAPI: oneAPI, oneArray using TypeParameterAccessors: TypeParameterAccessors, Position TypeParameterAccessors.position(::Type{<:oneArray}, ::typeof(eltype)) = Position(1) TypeParameterAccessors.position(::Type{<:oneArray}, ::typeof(ndims)) = Position(2) +function TypeParameterAccessors.default_type_parameters(::Type{<:oneAPI}) + return (Float64, 1, oneAPI.oneL0.DeviceBuffer) +end end diff --git a/src/type_parameters.jl b/src/type_parameters.jl index b0bc681..b8b8be0 100644 --- a/src/type_parameters.jl +++ b/src/type_parameters.jl @@ -259,8 +259,15 @@ function default_type_parameters(::Type{T}, ::Position{pos}) where {T,pos} return default_type_parameters(T)[pos] end default_type_parameters(::Type{T}, pos::Tuple) where {T} = default_type_parameters.(T, pos) -default_type_parameters(t) = default_type_parameters(typeof(t)) default_type_parameters(t, pos) = default_type_parameters(typeof(t), pos) +default_type_parameters(t) = default_type_parameters(typeof(t)) +function default_type_parameters(type::Type) + type′ = unspecify_type_parameters(type) + if type === type′ + error("Default type parameters have not been defined for `$(type′)`.") + end + return default_type_parameters(type′) +end """ set_default_type_parameters(type::Type, [positions::Tuple]) From 6c4ac56f1b9bc4e0cd800d5a48ac0de28b05e436 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 12 Feb 2025 17:58:04 -0500 Subject: [PATCH 2/4] Add tests --- test/test_defaults.jl | 80 ++++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/test/test_defaults.jl b/test/test_defaults.jl index 280e9a0..14f6de0 100644 --- a/test/test_defaults.jl +++ b/test/test_defaults.jl @@ -1,55 +1,63 @@ -using Test: @test_throws, @testset +using JLArrays: JLArray +using Test: @test, @testset +using TestExtras: @constinferred using TypeParameterAccessors: TypeParameterAccessors, Position, default_type_parameters, set_default_type_parameters, specify_default_type_parameters -using TestExtras: @constinferred -@testset "TypeParameterAccessors defaults" begin - @testset "Erroneously requires wrapping to infer" begin end + +const arrayts = (Array, JLArray) +@testset "TypeParameterAccessors defaults $arrayt" for arrayt in arrayts + vectort = arrayt{<:Any,1} @testset "Get defaults" begin - @test @constinferred(default_type_parameters($Array, 1)) == Float64 - @test @constinferred(default_type_parameters($Array, 2)) == 1 - @test @constinferred(default_type_parameters($Array)) == (Float64, 1) - @test @constinferred(default_type_parameters($Array, $((2, 1)))) == (1, Float64) - @test @constinferred(broadcast($default_type_parameters, $Array, (ndims, eltype))) == + @test @constinferred(default_type_parameters($arrayt, 1)) == Float64 + @test @constinferred(default_type_parameters($arrayt, 2)) == 1 + @test @constinferred(default_type_parameters($arrayt)) == (Float64, 1) + @test @constinferred(default_type_parameters($arrayt, $((2, 1)))) == (1, Float64) + @test @constinferred(broadcast($default_type_parameters, $arrayt, (ndims, eltype))) == (1, Float64) - @test @constinferred(broadcast($default_type_parameters, $Array, $((2, 1)))) == + @test @constinferred(broadcast($default_type_parameters, $arrayt, $((2, 1)))) == (1, Float64) - @test @constinferred(broadcast($default_type_parameters, $Array, (ndims, eltype))) == + @test @constinferred(broadcast($default_type_parameters, $arrayt, (ndims, eltype))) == (1, Float64) end @testset "Set defaults" begin - @test @constinferred(set_default_type_parameters($(Array{Float32}), 1)) == - Array{Float64} - @test @constinferred(set_default_type_parameters($(Array{Float32}), eltype)) == - Array{Float64} - @test @constinferred(set_default_type_parameters($(Array{Float32}))) == Vector{Float64} - @test @constinferred(set_default_type_parameters($(Array{Float32}), $((1, 2)))) == - Vector{Float64} - @test @constinferred(set_default_type_parameters($(Array{Float32}), (eltype, ndims))) == - Vector{Float64} - @test @constinferred(set_default_type_parameters($Array)) == Vector{Float64} - @test @constinferred(set_default_type_parameters($Array, 1)) == Array{Float64} - @test @constinferred(set_default_type_parameters($Array, $((1, 2)))) == Vector{Float64} + @test @constinferred(set_default_type_parameters($(arrayt{Float32}), 1)) == + arrayt{Float64} + @test @constinferred(set_default_type_parameters($(arrayt{Float32}), eltype)) == + arrayt{Float64} + @test @constinferred(set_default_type_parameters($(arrayt{Float32}))) == + vectort{Float64} + @test @constinferred(set_default_type_parameters($(arrayt{Float32}), $((1, 2)))) == + vectort{Float64} + @test @constinferred( + set_default_type_parameters($(arrayt{Float32}), (eltype, ndims)) + ) == vectort{Float64} + @test @constinferred(set_default_type_parameters($arrayt)) == vectort{Float64} + @test @constinferred(set_default_type_parameters($arrayt, 1)) == arrayt{Float64} + @test @constinferred(set_default_type_parameters($arrayt, $((1, 2)))) == + vectort{Float64} end @testset "Specify defaults" begin - @test @constinferred(specify_default_type_parameters($Array, 1)) == Array{Float64} - @test @constinferred(specify_default_type_parameters($Array, eltype)) == Array{Float64} - @test @constinferred(specify_default_type_parameters($Array, 2)) == Vector - @test @constinferred(specify_default_type_parameters($Array, ndims)) == Vector - @test @constinferred(specify_default_type_parameters($Array)) == Vector{Float64} - @test @constinferred(specify_default_type_parameters($Array, 1)) == Array{Float64} - @test @constinferred(specify_default_type_parameters($Array, eltype)) == Array{Float64} - @test @constinferred(specify_default_type_parameters($Array, 2)) == Vector - @test @constinferred(specify_default_type_parameters($Array, ndims)) == Vector - @test @constinferred(specify_default_type_parameters($Array, $((1, 2)))) == - Vector{Float64} - @test @constinferred(specify_default_type_parameters($Array, (eltype, ndims))) == - Vector{Float64} + @test @constinferred(specify_default_type_parameters($arrayt, 1)) == arrayt{Float64} + @test @constinferred(specify_default_type_parameters($arrayt, eltype)) == + arrayt{Float64} + @test @constinferred(specify_default_type_parameters($arrayt, 2)) == vectort + @test @constinferred(specify_default_type_parameters($arrayt, ndims)) == vectort + @test @constinferred(specify_default_type_parameters($arrayt)) == vectort{Float64} + @test @constinferred(specify_default_type_parameters($arrayt, 1)) == arrayt{Float64} + @test @constinferred(specify_default_type_parameters($arrayt, eltype)) == + arrayt{Float64} + @test @constinferred(specify_default_type_parameters($arrayt, 2)) == vectort + @test @constinferred(specify_default_type_parameters($arrayt, ndims)) == vectort + @test @constinferred(specify_default_type_parameters($arrayt, $((1, 2)))) == + vectort{Float64} + @test @constinferred(specify_default_type_parameters($arrayt, (eltype, ndims))) == + vectort{Float64} end @testset "On objects" begin From 0e24a20510d319765ee06a16cd432cea2c65cae5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 12 Feb 2025 20:28:50 -0500 Subject: [PATCH 3/4] Add fallback definition for defaults, add tests. --- src/type_parameters.jl | 32 +++++++++++++++++++++++++++++--- test/test_defaults.jl | 22 +++++++++++++++++++++- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/src/type_parameters.jl b/src/type_parameters.jl index b8b8be0..383f15e 100644 --- a/src/type_parameters.jl +++ b/src/type_parameters.jl @@ -70,8 +70,11 @@ end # julia> Base.unwrap_unionall(supertype(MyArray)).parameters # svec(A, B) # ``` - supertype_params = Base.unwrap_unionall(supertype(T)).parameters + supertype_params = Base.unwrap_unionall(supertype(T′)).parameters supertype_param = supertype_params[Int(supertype_pos)] + if !(supertype_param isa TypeVar) + error("Position not found.") + end pos = findfirst(param -> (param.name == supertype_param.name), type_params) if isnothing(pos) return error("Position not found.") @@ -256,7 +259,11 @@ function default_type_parameters(::Type{T}, pos) where {T} return default_type_parameters(T, position(T, pos)) end function default_type_parameters(::Type{T}, ::Position{pos}) where {T,pos} - return default_type_parameters(T)[pos] + param = default_type_parameters(T)[pos] + if param isa UndefinedDefaultTypeParameter + return error("No default parameter defined at this position.") + end + return param end default_type_parameters(::Type{T}, pos::Tuple) where {T} = default_type_parameters.(T, pos) default_type_parameters(t, pos) = default_type_parameters(typeof(t), pos) @@ -264,11 +271,30 @@ default_type_parameters(t) = default_type_parameters(typeof(t)) function default_type_parameters(type::Type) type′ = unspecify_type_parameters(type) if type === type′ - error("Default type parameters have not been defined for `$(type′)`.") + return default_type_parameters_from_supertype(type′) end return default_type_parameters(type′) end +struct UndefinedDefaultTypeParameter end + +@generated function default_type_parameters_from_supertype(::Type{T}) where {T} + T′ = unspecify_type_parameters(T) + supertype_default_type_params = default_type_parameters(supertype(T′)) + type_params = Base.unwrap_unionall(T′).parameters + supertype_params = Base.unwrap_unionall(supertype(T′)).parameters + defaults = Any[UndefinedDefaultTypeParameter() for _ in 1:nparameters(T′)] + for (supertype_param, supertype_default_type_param) in + zip(supertype_params, supertype_default_type_params) + if !(supertype_param isa TypeVar) + continue + end + param_position = findfirst(param -> (param.name == supertype_param.name), type_params) + defaults[param_position] = supertype_default_type_param + end + return :(@inline; $(Tuple(defaults))) +end + """ set_default_type_parameters(type::Type, [positions::Tuple]) set_default_type_parameters(type::Type, position) diff --git a/test/test_defaults.jl b/test/test_defaults.jl index 14f6de0..5a12901 100644 --- a/test/test_defaults.jl +++ b/test/test_defaults.jl @@ -1,5 +1,5 @@ using JLArrays: JLArray -using Test: @test, @testset +using Test: @test, @test_throws, @testset using TestExtras: @constinferred using TypeParameterAccessors: TypeParameterAccessors, @@ -68,4 +68,24 @@ const arrayts = (Array, JLArray) @test @constinferred(default_type_parameters(a, ndims)) == 1 @test @constinferred(default_type_parameters(a)) == (Float64, 1) end + + @testset "Automatic fallback for defaults" begin + struct MyArray{B,A} <: AbstractArray{A,B} end + @test @constinferred(default_type_parameters(MyArray)) === (1, Float64) + @test @constinferred(default_type_parameters(MyArray{2,Float32})) === (1, Float64) + @test @constinferred(default_type_parameters(MyArray, eltype)) === Float64 + @test @constinferred(default_type_parameters(MyArray, ndims)) === 1 + + und = TypeParameterAccessors.UndefinedDefaultTypeParameter() + + struct MyVector{X,Y,A<:Real} <: AbstractArray{A,1} end + @test @constinferred(default_type_parameters(MyVector)) === (und, und, Float64) + @test @constinferred(default_type_parameters(MyVector, eltype)) === Float64 + @test_throws ErrorException default_type_parameters(MyVector, ndims) + + struct MyBoolArray{X,Y,Z,B} <: AbstractArray{Bool,B} end + @test @constinferred(default_type_parameters(MyBoolArray)) === (und, und, und, 1) + @test_throws ErrorException default_type_parameters(MyBoolArray, eltype) + @test @constinferred(default_type_parameters(MyBoolArray, ndims)) === 1 + end end From 0b20716742a23be701901ccca1b1f058c4b53dc7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 12 Feb 2025 21:16:17 -0500 Subject: [PATCH 4/4] More tests and comments --- src/type_parameters.jl | 13 +++++++------ test/test_defaults.jl | 9 +++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/type_parameters.jl b/src/type_parameters.jl index 383f15e..3034038 100644 --- a/src/type_parameters.jl +++ b/src/type_parameters.jl @@ -82,12 +82,6 @@ end return :(@inline; $(Position(pos))) end -# Automatically determine the position of a type parameter of a type given -# a supertype and the name of the parameter. -function position_from_supertype(type::Type, supertype_target::Type, name) - return position_from_supertype(type, supertype_target, position(supertype_target, name)) -end - function positions(::Type{T}, pos::Tuple) where {T} return ntuple(length(pos)) do i return position(T, pos[i]) @@ -278,6 +272,13 @@ end struct UndefinedDefaultTypeParameter end +# Determine the default type parameters of a type from the default type +# parameters of the supertype of the type. Uses similar logic as +# `position_from_supertype_position` for matching TypeVar names +# between the type and the supertype. Type parameters that exist +# in the type but not the supertype will have a default type parameter +# `UndefinedDefaultTypeParameter()`. Accessing those type parameters +# by name/position will throw an error. @generated function default_type_parameters_from_supertype(::Type{T}) where {T} T′ = unspecify_type_parameters(T) supertype_default_type_params = default_type_parameters(supertype(T′)) diff --git a/test/test_defaults.jl b/test/test_defaults.jl index 5a12901..1afb6f7 100644 --- a/test/test_defaults.jl +++ b/test/test_defaults.jl @@ -73,6 +73,8 @@ const arrayts = (Array, JLArray) struct MyArray{B,A} <: AbstractArray{A,B} end @test @constinferred(default_type_parameters(MyArray)) === (1, Float64) @test @constinferred(default_type_parameters(MyArray{2,Float32})) === (1, Float64) + @test @constinferred(default_type_parameters($MyArray, 1)) === 1 + @test @constinferred(default_type_parameters($MyArray, 2)) === Float64 @test @constinferred(default_type_parameters(MyArray, eltype)) === Float64 @test @constinferred(default_type_parameters(MyArray, ndims)) === 1 @@ -80,11 +82,18 @@ const arrayts = (Array, JLArray) struct MyVector{X,Y,A<:Real} <: AbstractArray{A,1} end @test @constinferred(default_type_parameters(MyVector)) === (und, und, Float64) + @test_throws ErrorException default_type_parameters(MyVector, 1) + @test_throws ErrorException default_type_parameters(MyVector, 2) + @test @constinferred(default_type_parameters($MyVector, 3)) === Float64 @test @constinferred(default_type_parameters(MyVector, eltype)) === Float64 @test_throws ErrorException default_type_parameters(MyVector, ndims) struct MyBoolArray{X,Y,Z,B} <: AbstractArray{Bool,B} end @test @constinferred(default_type_parameters(MyBoolArray)) === (und, und, und, 1) + @test_throws ErrorException default_type_parameters(MyBoolArray, 1) + @test_throws ErrorException default_type_parameters(MyBoolArray, 2) + @test_throws ErrorException default_type_parameters(MyBoolArray, 3) + @test @constinferred(default_type_parameters($MyBoolArray, 4)) === 1 @test_throws ErrorException default_type_parameters(MyBoolArray, eltype) @test @constinferred(default_type_parameters(MyBoolArray, ndims)) === 1 end