Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TypeParameterAccessors"
uuid = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.2"
version = "0.3.3"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
5 changes: 4 additions & 1 deletion ext/TypeParameterAccessorsAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
module TypeParameterAccessorsAMDGPUExt

using AMDGPU: ROCArray
using AMDGPU: AMDGPU, ROCArray
using TypeParameterAccessors: TypeParameterAccessors, Position

TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(eltype)) = Position(1)
TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(ndims)) = Position(2)
function TypeParameterAccessors.default_type_parameters(::Type{<:ROCArray})
return (Float64, 1, AMDGPU.Mem.HIPBuffer)

Check warning on line 9 in ext/TypeParameterAccessorsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TypeParameterAccessorsAMDGPUExt.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
end

end
5 changes: 4 additions & 1 deletion ext/TypeParameterAccessorsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
module TypeParameterAccessorsCUDAExt

using CUDA: CuArray
using CUDA: CUDA, CuArray
using TypeParameterAccessors: TypeParameterAccessors, Position

TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(eltype)) = Position(1)
TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(ndims)) = Position(2)
function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray})
return (Float64, 1, CUDA.default_memory)

Check warning on line 9 in ext/TypeParameterAccessorsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TypeParameterAccessorsCUDAExt.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
end

end
1 change: 1 addition & 0 deletions ext/TypeParameterAccessorsJLArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ using TypeParameterAccessors: TypeParameterAccessors, Position

TypeParameterAccessors.position(::Type{<:JLArray}, ::typeof(eltype)) = Position(1)
TypeParameterAccessors.position(::Type{<:JLArray}, ::typeof(ndims)) = Position(2)
TypeParameterAccessors.default_type_parameters(::Type{<:JLArray}) = (Float64, 1)

end
5 changes: 4 additions & 1 deletion ext/TypeParameterAccessorsMetalExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
module TypeParameterAccessorsMetalExt

using Metal: MtlArray
using Metal: Metal, MtlArray
using TypeParameterAccessors: TypeParameterAccessors, Position

TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(eltype)) = Position(1)
TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(ndims)) = Position(2)
function TypeParameterAccessors.default_type_parameters(::Type{<:MtlArray})
return (Float64, 1, Metal.DefaultStorageMode)

Check warning on line 9 in ext/TypeParameterAccessorsMetalExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TypeParameterAccessorsMetalExt.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
end

end
3 changes: 3 additions & 0 deletions ext/TypeParameterAccessorsStridedViewsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,8 @@
function TypeParameterAccessors.position(::Type{<:StridedView}, ::typeof(parenttype))
return Position(3)
end
function TypeParameterAccessors.default_type_parameters(::Type{<:StridedView})
return (Float64, 1, Vector{Float64}, typeof(identity))

Check warning on line 12 in ext/TypeParameterAccessorsStridedViewsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TypeParameterAccessorsStridedViewsExt.jl#L11-L12

Added lines #L11 - L12 were not covered by tests
end

end
5 changes: 4 additions & 1 deletion ext/TypeParameterAccessorsoneAPIExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
module TypeParameterAccessorsoneAPIExt

using oneAPI: oneArray
using oneAPI: oneAPI, oneArray
using TypeParameterAccessors: TypeParameterAccessors, Position

TypeParameterAccessors.position(::Type{<:oneArray}, ::typeof(eltype)) = Position(1)
TypeParameterAccessors.position(::Type{<:oneArray}, ::typeof(ndims)) = Position(2)
function TypeParameterAccessors.default_type_parameters(::Type{<:oneAPI})
return (Float64, 1, oneAPI.oneL0.DeviceBuffer)

Check warning on line 9 in ext/TypeParameterAccessorsoneAPIExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TypeParameterAccessorsoneAPIExt.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
end

end
52 changes: 43 additions & 9 deletions src/type_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,18 @@ end
# julia> Base.unwrap_unionall(supertype(MyArray)).parameters
# svec(A, B)
# ```
supertype_params = Base.unwrap_unionall(supertype(T)).parameters
supertype_params = Base.unwrap_unionall(supertype(T)).parameters
supertype_param = supertype_params[Int(supertype_pos)]
if !(supertype_param isa TypeVar)
error("Position not found.")
end
pos = findfirst(param -> (param.name == supertype_param.name), type_params)
if isnothing(pos)
return error("Position not found.")
end
return :(@inline; $(Position(pos)))
end

# Automatically determine the position of a type parameter of a type given
# a supertype and the name of the parameter.
function position_from_supertype(type::Type, supertype_target::Type, name)
return position_from_supertype(type, supertype_target, position(supertype_target, name))
end

function positions(::Type{T}, pos::Tuple) where {T}
return ntuple(length(pos)) do i
return position(T, pos[i])
Expand Down Expand Up @@ -256,11 +253,48 @@ function default_type_parameters(::Type{T}, pos) where {T}
return default_type_parameters(T, position(T, pos))
end
function default_type_parameters(::Type{T}, ::Position{pos}) where {T,pos}
return default_type_parameters(T)[pos]
param = default_type_parameters(T)[pos]
if param isa UndefinedDefaultTypeParameter
return error("No default parameter defined at this position.")
end
return param
end
default_type_parameters(::Type{T}, pos::Tuple) where {T} = default_type_parameters.(T, pos)
default_type_parameters(t) = default_type_parameters(typeof(t))
default_type_parameters(t, pos) = default_type_parameters(typeof(t), pos)
default_type_parameters(t) = default_type_parameters(typeof(t))
function default_type_parameters(type::Type)
type′ = unspecify_type_parameters(type)
if type === type′
return default_type_parameters_from_supertype(type′)
end
return default_type_parameters(type′)
end

struct UndefinedDefaultTypeParameter end

# Determine the default type parameters of a type from the default type
# parameters of the supertype of the type. Uses similar logic as
# `position_from_supertype_position` for matching TypeVar names
# between the type and the supertype. Type parameters that exist
# in the type but not the supertype will have a default type parameter
# `UndefinedDefaultTypeParameter()`. Accessing those type parameters
# by name/position will throw an error.
@generated function default_type_parameters_from_supertype(::Type{T}) where {T}
T′ = unspecify_type_parameters(T)
supertype_default_type_params = default_type_parameters(supertype(T′))
type_params = Base.unwrap_unionall(T′).parameters
supertype_params = Base.unwrap_unionall(supertype(T′)).parameters
defaults = Any[UndefinedDefaultTypeParameter() for _ in 1:nparameters(T′)]
for (supertype_param, supertype_default_type_param) in
zip(supertype_params, supertype_default_type_params)
if !(supertype_param isa TypeVar)
continue
end
param_position = findfirst(param -> (param.name == supertype_param.name), type_params)
defaults[param_position] = supertype_default_type_param
end
return :(@inline; $(Tuple(defaults)))
end

"""
set_default_type_parameters(type::Type, [positions::Tuple])
Expand Down
109 changes: 73 additions & 36 deletions test/test_defaults.jl
Original file line number Diff line number Diff line change
@@ -1,55 +1,63 @@
using Test: @test_throws, @testset
using JLArrays: JLArray
using Test: @test, @test_throws, @testset
using TestExtras: @constinferred
using TypeParameterAccessors:
TypeParameterAccessors,
Position,
default_type_parameters,
set_default_type_parameters,
specify_default_type_parameters
using TestExtras: @constinferred
@testset "TypeParameterAccessors defaults" begin
@testset "Erroneously requires wrapping to infer" begin end

const arrayts = (Array, JLArray)
@testset "TypeParameterAccessors defaults $arrayt" for arrayt in arrayts
vectort = arrayt{<:Any,1}
@testset "Get defaults" begin
@test @constinferred(default_type_parameters($Array, 1)) == Float64
@test @constinferred(default_type_parameters($Array, 2)) == 1
@test @constinferred(default_type_parameters($Array)) == (Float64, 1)
@test @constinferred(default_type_parameters($Array, $((2, 1)))) == (1, Float64)
@test @constinferred(broadcast($default_type_parameters, $Array, (ndims, eltype))) ==
@test @constinferred(default_type_parameters($arrayt, 1)) == Float64
@test @constinferred(default_type_parameters($arrayt, 2)) == 1
@test @constinferred(default_type_parameters($arrayt)) == (Float64, 1)
@test @constinferred(default_type_parameters($arrayt, $((2, 1)))) == (1, Float64)
@test @constinferred(broadcast($default_type_parameters, $arrayt, (ndims, eltype))) ==
(1, Float64)
@test @constinferred(broadcast($default_type_parameters, $Array, $((2, 1)))) ==
@test @constinferred(broadcast($default_type_parameters, $arrayt, $((2, 1)))) ==
(1, Float64)
@test @constinferred(broadcast($default_type_parameters, $Array, (ndims, eltype))) ==
@test @constinferred(broadcast($default_type_parameters, $arrayt, (ndims, eltype))) ==
(1, Float64)
end

@testset "Set defaults" begin
@test @constinferred(set_default_type_parameters($(Array{Float32}), 1)) ==
Array{Float64}
@test @constinferred(set_default_type_parameters($(Array{Float32}), eltype)) ==
Array{Float64}
@test @constinferred(set_default_type_parameters($(Array{Float32}))) == Vector{Float64}
@test @constinferred(set_default_type_parameters($(Array{Float32}), $((1, 2)))) ==
Vector{Float64}
@test @constinferred(set_default_type_parameters($(Array{Float32}), (eltype, ndims))) ==
Vector{Float64}
@test @constinferred(set_default_type_parameters($Array)) == Vector{Float64}
@test @constinferred(set_default_type_parameters($Array, 1)) == Array{Float64}
@test @constinferred(set_default_type_parameters($Array, $((1, 2)))) == Vector{Float64}
@test @constinferred(set_default_type_parameters($(arrayt{Float32}), 1)) ==
arrayt{Float64}
@test @constinferred(set_default_type_parameters($(arrayt{Float32}), eltype)) ==
arrayt{Float64}
@test @constinferred(set_default_type_parameters($(arrayt{Float32}))) ==
vectort{Float64}
@test @constinferred(set_default_type_parameters($(arrayt{Float32}), $((1, 2)))) ==
vectort{Float64}
@test @constinferred(
set_default_type_parameters($(arrayt{Float32}), (eltype, ndims))
) == vectort{Float64}
@test @constinferred(set_default_type_parameters($arrayt)) == vectort{Float64}
@test @constinferred(set_default_type_parameters($arrayt, 1)) == arrayt{Float64}
@test @constinferred(set_default_type_parameters($arrayt, $((1, 2)))) ==
vectort{Float64}
end

@testset "Specify defaults" begin
@test @constinferred(specify_default_type_parameters($Array, 1)) == Array{Float64}
@test @constinferred(specify_default_type_parameters($Array, eltype)) == Array{Float64}
@test @constinferred(specify_default_type_parameters($Array, 2)) == Vector
@test @constinferred(specify_default_type_parameters($Array, ndims)) == Vector
@test @constinferred(specify_default_type_parameters($Array)) == Vector{Float64}
@test @constinferred(specify_default_type_parameters($Array, 1)) == Array{Float64}
@test @constinferred(specify_default_type_parameters($Array, eltype)) == Array{Float64}
@test @constinferred(specify_default_type_parameters($Array, 2)) == Vector
@test @constinferred(specify_default_type_parameters($Array, ndims)) == Vector
@test @constinferred(specify_default_type_parameters($Array, $((1, 2)))) ==
Vector{Float64}
@test @constinferred(specify_default_type_parameters($Array, (eltype, ndims))) ==
Vector{Float64}
@test @constinferred(specify_default_type_parameters($arrayt, 1)) == arrayt{Float64}
@test @constinferred(specify_default_type_parameters($arrayt, eltype)) ==
arrayt{Float64}
@test @constinferred(specify_default_type_parameters($arrayt, 2)) == vectort
@test @constinferred(specify_default_type_parameters($arrayt, ndims)) == vectort
@test @constinferred(specify_default_type_parameters($arrayt)) == vectort{Float64}
@test @constinferred(specify_default_type_parameters($arrayt, 1)) == arrayt{Float64}
@test @constinferred(specify_default_type_parameters($arrayt, eltype)) ==
arrayt{Float64}
@test @constinferred(specify_default_type_parameters($arrayt, 2)) == vectort
@test @constinferred(specify_default_type_parameters($arrayt, ndims)) == vectort
@test @constinferred(specify_default_type_parameters($arrayt, $((1, 2)))) ==
vectort{Float64}
@test @constinferred(specify_default_type_parameters($arrayt, (eltype, ndims))) ==
vectort{Float64}
end

@testset "On objects" begin
Expand All @@ -60,4 +68,33 @@ using TestExtras: @constinferred
@test @constinferred(default_type_parameters(a, ndims)) == 1
@test @constinferred(default_type_parameters(a)) == (Float64, 1)
end

@testset "Automatic fallback for defaults" begin
struct MyArray{B,A} <: AbstractArray{A,B} end
@test @constinferred(default_type_parameters(MyArray)) === (1, Float64)
@test @constinferred(default_type_parameters(MyArray{2,Float32})) === (1, Float64)
@test @constinferred(default_type_parameters($MyArray, 1)) === 1
@test @constinferred(default_type_parameters($MyArray, 2)) === Float64
@test @constinferred(default_type_parameters(MyArray, eltype)) === Float64
@test @constinferred(default_type_parameters(MyArray, ndims)) === 1

und = TypeParameterAccessors.UndefinedDefaultTypeParameter()

struct MyVector{X,Y,A<:Real} <: AbstractArray{A,1} end
@test @constinferred(default_type_parameters(MyVector)) === (und, und, Float64)
@test_throws ErrorException default_type_parameters(MyVector, 1)
@test_throws ErrorException default_type_parameters(MyVector, 2)
@test @constinferred(default_type_parameters($MyVector, 3)) === Float64
@test @constinferred(default_type_parameters(MyVector, eltype)) === Float64
@test_throws ErrorException default_type_parameters(MyVector, ndims)

struct MyBoolArray{X,Y,Z,B} <: AbstractArray{Bool,B} end
@test @constinferred(default_type_parameters(MyBoolArray)) === (und, und, und, 1)
@test_throws ErrorException default_type_parameters(MyBoolArray, 1)
@test_throws ErrorException default_type_parameters(MyBoolArray, 2)
@test_throws ErrorException default_type_parameters(MyBoolArray, 3)
@test @constinferred(default_type_parameters($MyBoolArray, 4)) === 1
@test_throws ErrorException default_type_parameters(MyBoolArray, eltype)
@test @constinferred(default_type_parameters(MyBoolArray, ndims)) === 1
end
end
Loading