Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DerivableInterfaces"
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.5"
version = "0.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"

[compat]
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
Documenter = "1"
Literate = "2"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

[compat]
ArrayLayouts = "1"
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
20 changes: 12 additions & 8 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
# TODO: Add `ndims` type parameter.
abstract type AbstractArrayInterface <: AbstractInterface end
abstract type AbstractArrayInterface{N} <: AbstractInterface end

function interface(::Type{<:Broadcast.AbstractArrayStyle{N}}) where {N}
return DefaultArrayInterface{N}()
end
function interface(::Type{<:Broadcast.AbstractArrayStyle})
return DefaultArrayInterface()
end

function interface(::Type{<:Broadcast.Broadcasted{Nothing}})
return DefaultArrayInterface()
function interface(BC::Type{<:Broadcast.Broadcasted{Nothing}})
return DefaultArrayInterface{ndims(BC)}()
end

function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style}
return interface(Style)
end

# TODO: Define as `Array{T}`.
arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.")
# TODO: Define as `similar(Array{T}, ax)`.
function Base.similar(interface::AbstractArrayInterface, T::Type, ax::Tuple)
return error("Not implemented.")
end

using ArrayLayouts: ArrayLayouts

Expand Down Expand Up @@ -85,7 +90,7 @@ end
@interface interface::AbstractArrayInterface function Base.similar(
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
)
return similar(arraytype(interface, T), size)
return similar(interface, T, size)
end

@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray)
Expand All @@ -105,8 +110,7 @@ end
@interface interface::AbstractArrayInterface function Base.similar(
bc::Broadcast.Broadcasted, T::Type, axes::Tuple
)
# `arraytype(::AbstractInterface)` determines the default array type associated with the interface.
return similar(arraytype(interface, T), axes)
return similar(interface, T, axes)
end

using MapBroadcast: Mapped
Expand Down
27 changes: 19 additions & 8 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ export concatenate
@compat public Concatenated, cat, cat!, concatenated

using Base: promote_eltypeof
using ..DerivableInterfaces:
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
using ..DerivableInterfaces: DerivableInterfaces, AbstractArrayInterface, interface, zero!

unval(x) = x
unval(::Val{x}) where {x} = x
Expand All @@ -53,13 +52,17 @@ struct Concatenated{Interface,Dims,Args<:Tuple}
end
end

function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple)
function Concatenated(
interface::Union{AbstractArrayInterface,Nothing}, dims::Val, args::Tuple
)
return _Concatenated(interface, dims, args)
end
function Concatenated(dims::Val, args::Tuple)
return Concatenated(interface(args...), dims, args)
return Concatenated(cat_interface(dims, args...), dims, args)
end
function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface}
function Concatenated{Interface}(
dims::Val, args::Tuple
) where {Interface<:Union{AbstractArrayInterface,Nothing}}
return Concatenated(Interface(), dims, args)
end

Expand All @@ -81,8 +84,11 @@ end
# ------------------------------------
Base.similar(concat::Concatenated) = similar(concat, eltype(concat))
Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat))
function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
return similar(arraytype(interface(concat), T), ax)
function Base.similar(concat::Concatenated, ax::Tuple)
return similar(interface(concat), eltype(concat), ax)
end
function Base.similar(concat::Concatenated, ::Type{T}, ax::Tuple) where {T}
return similar(interface(concat), T, ax)
end

function cat_axis(
Expand All @@ -108,10 +114,15 @@ function cat_axes(dims::Val, as::AbstractArray...)
return cat_axes(unval(dims), as...)
end

function cat_interface(dims, as::AbstractArray...)
N = cat_ndims(dims, as...)
return typeof(interface(as...))(Val(N))
end

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...)
Base.size(concat::Concatenated) = length.(axes(concat))
Base.ndims(concat::Concatenated) = length(axes(concat))
Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...)

# Main logic
# ----------
Expand Down
26 changes: 22 additions & 4 deletions src/defaultarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
# TODO: Add `ndims` type parameter.
struct DefaultArrayInterface <: AbstractArrayInterface end
struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end

DefaultArrayInterface() = DefaultArrayInterface{Any}()
DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}()
DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}()

using TypeParameterAccessors: parenttype
function interface(a::Type{<:AbstractArray})
parenttype(a) === a && return DefaultArrayInterface()
return interface(parenttype(a))
end
function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N}
parenttype(a) === a && return DefaultArrayInterface{N}()
return interface(parenttype(a))
end

function combine_interface_rule(
interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N}
) where {N}
return DefaultArrayInterface{N}()
end
function combine_interface_rule(
interface1::DefaultArrayInterface, interface2::DefaultArrayInterface
)
return DefaultArrayInterface{Any}()
end

@interface ::DefaultArrayInterface function Base.getindex(
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
Expand All @@ -31,6 +49,6 @@ end
return Base.mapreduce(f, op, as...; kwargs...)
end

function arraytype(::DefaultArrayInterface, T::Type)
return Array{T}
function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple)
return similar(Array{T}, ax)
end
6 changes: 4 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Aqua = "0.8"
ArrayLayouts = "1"
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
LinearAlgebra = "1"
SafeTestsets = "0.1"
Suppressor = "0.2"
LinearAlgebra = "1"
Test = "1"
TestExtras = "0.3"
16 changes: 12 additions & 4 deletions test/SparseArrayDOKs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ using DerivableInterfaces:
using LinearAlgebra: LinearAlgebra

# Define an interface.
struct SparseArrayInterface <: AbstractArrayInterface end
struct SparseArrayInterface{N} <: AbstractArrayInterface{N} end
SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}()
SparseArrayInterface{M}(::Val{N}) where {M,N} = SparseArrayInterface{N}()

# Define interface functions.
@interface ::SparseArrayInterface function Base.getindex(
Expand All @@ -66,11 +68,15 @@ end
struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()

DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface()
function DerivableInterfaces.interface(::Type{<:SparseArrayStyle{N}}) where {N}
return SparseArrayInterface{N}()
end

@derive SparseArrayStyle AbstractArrayStyleOps

DerivableInterfaces.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T}
function Base.similar(::SparseArrayInterface, T::Type, ax::Tuple)
return similar(SparseArrayDOK{T}, ax)
end

# Interface functions.
@interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
Expand Down Expand Up @@ -260,7 +266,9 @@ function DerivableInterfaces.zero!(a::SparseArrayDOK)
end

# Specify the interface the type adheres to.
DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()
function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK})
SparseArrayInterface{ndims(arrayt)}()
end

# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc.
@array_aliases SparseArrayDOK
Expand Down
35 changes: 29 additions & 6 deletions test/test_defaultarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test: @inferred, @testset, @test
using DerivableInterfaces: @interface, DefaultArrayInterface, interface
using Test: @testset, @test
using TestExtras: @constinferred

# function wrappers to test type-stability
_getindex(A, i...) = @interface DefaultArrayInterface() A[i...]
Expand All @@ -11,28 +12,50 @@ end

@testset "indexing" begin
for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3)))
a = @inferred _getindex(A, i...)
a = @constinferred _getindex(A, i...)
@test a == A[i...]
v = 1.1
A′ = @inferred _setindex!(A, v, i...)
A′ = @constinferred _setindex!(A, v, i...)
@test A′ == (A[i...] = v)
end
end

@testset "map!" begin
A = zeros(3)
a = @inferred _map!(Returns(2), copy(A), A)
a = @constinferred _map!(Returns(2), copy(A), A)
@test a == map!(Returns(2), copy(A), A)
end

@testset "mapreduce" begin
A = zeros(3)
a = @inferred _mapreduce(Returns(2), +, A)
a = @constinferred _mapreduce(Returns(2), +, A)
@test a == mapreduce(Returns(2), +, A)
end

@testset "DefaultArrayInterface" begin
@test interface(Array) === DefaultArrayInterface{Any}()
@test interface(Array{Float32}) === DefaultArrayInterface{Any}()
@test interface(Matrix) === DefaultArrayInterface{2}()
@test interface(Matrix{Float32}) === DefaultArrayInterface{2}()
@test DefaultArrayInterface() === DefaultArrayInterface{Any}()
@test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}()
@test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}()
@test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}()
end

@testset "similar(::DefaultArrayInterface, ...)" begin
a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2))
@test typeof(a) === Matrix{Float32}
@test size(a) == (2, 2)

a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2))
@test typeof(a) === Matrix{Float32}
@test size(a) == (2, 2)
end

@testset "Broadcast.DefaultArrayStyle" begin
@test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface()
@test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}()
@test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) ==
DefaultArrayInterface()
DefaultArrayInterface{1}()
end
Loading