Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DerivableInterfaces"
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.5"
version = "0.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"

[compat]
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
Documenter = "1"
Literate = "2"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

[compat]
ArrayLayouts = "1"
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
20 changes: 12 additions & 8 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
# TODO: Add `ndims` type parameter.
abstract type AbstractArrayInterface <: AbstractInterface end
abstract type AbstractArrayInterface{N} <: AbstractInterface end

function interface(::Type{<:Broadcast.AbstractArrayStyle{N}}) where {N}
return DefaultArrayInterface{N}()
end
function interface(::Type{<:Broadcast.AbstractArrayStyle})
return DefaultArrayInterface()
end

function interface(::Type{<:Broadcast.Broadcasted{Nothing}})
return DefaultArrayInterface()
function interface(BC::Type{<:Broadcast.Broadcasted{Nothing}})
return DefaultArrayInterface{ndims(BC)}()
end

function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style}
return interface(Style)
end

# TODO: Define as `Array{T}`.
arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.")
# TODO: Define as `similar(Array{T}, ax)`.
function Base.similar(interface::AbstractArrayInterface, T::Type, ax::Tuple)
return error("Not implemented.")
end

using ArrayLayouts: ArrayLayouts

Expand Down Expand Up @@ -85,7 +90,7 @@ end
@interface interface::AbstractArrayInterface function Base.similar(
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
)
return similar(arraytype(interface, T), size)
return similar(interface, T, size)
end

@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray)
Expand All @@ -105,8 +110,7 @@ end
@interface interface::AbstractArrayInterface function Base.similar(
bc::Broadcast.Broadcasted, T::Type, axes::Tuple
)
# `arraytype(::AbstractInterface)` determines the default array type associated with the interface.
return similar(arraytype(interface, T), axes)
return similar(interface, T, axes)
end

using MapBroadcast: Mapped
Expand Down
27 changes: 20 additions & 7 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ export concatenate
@compat public Concatenated, cat, cat!, concatenated

using Base: promote_eltypeof
using ..DerivableInterfaces:
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
using ..DerivableInterfaces: DerivableInterfaces, AbstractArrayInterface, interface, zero!

unval(x) = x
unval(::Val{x}) where {x} = x

set_interface_ndims(::Type{Nothing}, ::Val{N}) where {N} = nothing
function set_interface_ndims(Interface::Type{<:AbstractArrayInterface}, ::Val{N}) where {N}
return Interface(Val(N))
end

function _Concatenated end

"""
Expand All @@ -53,14 +57,20 @@ struct Concatenated{Interface,Dims,Args<:Tuple}
end
end

function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple)
function Concatenated(interface::Nothing, dims::Val, args::Tuple)
return _Concatenated(interface, dims, args)
end
function Concatenated(interface::AbstractArrayInterface, dims::Val, args::Tuple)
N = cat_ndims(dims, args...)
return _Concatenated(typeof(interface)(Val(N)), dims, args)
end
function Concatenated(dims::Val, args::Tuple)
return Concatenated(interface(args...), dims, args)
N = cat_ndims(dims, args...)
return _Concatenated(typeof(interface(args...))(Val(N)), dims, args)
end
function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface}
return Concatenated(Interface(), dims, args)
N = cat_ndims(dims, args...)
return _Concatenated(set_interface_ndims(Interface, Val(N)), dims, args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this new call to similar, do you need to set the interface ndims already at this point? I guess this would be automatically determined from the ax later down the line, so I'm just wondering if there is any difference.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true, when you call similar, the ndims in the interface will get overridden by the axes input to similar. I kept it in case there is some other use case where storing the ndims in the interface is useful, but I admit I don't have a particular use case in mind for Concatenated.

I suppose what this does is catch cases where a user specifies an interface but it actually has the wrong ndims for the concatenation expression, I guess that could be checked.

But also thinking about this particular constructor, it should change the Interface type since that mean this constructor doesn't satisfy T(args...) isa T, so I'll change this one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the latest commits I propose a compromise. If the interface object is passed explicitly, it is taken "as is" and isn't modified, even if the ndims aren't specified or are incorrect, but when it is constructed from the arguments the correct ndims are computed explicitly.

That matches the behavior of Broadcasted:

julia> using Base.Broadcast: DefaultArrayStyle, Broadcasted

julia> bc = Broadcasted(DefaultArrayStyle{1}(), +, (randn(2, 2), randn(2, 2)))
Broadcasted(+, ([0.095430782012298 0.4338409876936397; -0.2285907590132556 2.0739880112106475], [-0.6626604337472756 0.5369654075753642; 0.2681815713554777 -0.7930248596990674]))

julia> bc.style
DefaultArrayStyle{1}()

julia> bc = Broadcasted(+, (randn(2, 2), randn(2, 2)))
Broadcasted(+, ([-0.08113666651442918 -0.11158415095613558; -0.5031937898445847 -1.1018574952687241], [-1.595659893181313 0.12746978522353117; 0.026558187695457296 -0.22363363427492086]))

julia> bc.style
DefaultArrayStyle{2}()

That seems reasonable, since the interface is meant to be something where you can specify it to decide on the dispatch, I think we shouldn't be too opinionated about modifying what the user input since it may have been deliberate.

end

dims(::Concatenated{<:Any,D}) where {D} = D
Expand All @@ -81,8 +91,11 @@ end
# ------------------------------------
Base.similar(concat::Concatenated) = similar(concat, eltype(concat))
Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat))
function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
return similar(arraytype(interface(concat), T), ax)
function Base.similar(concat::Concatenated, ax::Tuple)
return similar(interface(concat), eltype(concat), ax)
end
function Base.similar(concat::Concatenated, ::Type{T}, ax::Tuple) where {T}
return similar(interface(concat), T, ax)
end

function cat_axis(
Expand Down
26 changes: 22 additions & 4 deletions src/defaultarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
# TODO: Add `ndims` type parameter.
struct DefaultArrayInterface <: AbstractArrayInterface end
struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end

DefaultArrayInterface() = DefaultArrayInterface{Any}()
DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}()
DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}()

using TypeParameterAccessors: parenttype
function interface(a::Type{<:AbstractArray})
parenttype(a) === a && return DefaultArrayInterface()
return interface(parenttype(a))
end
function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N}
parenttype(a) === a && return DefaultArrayInterface{N}()
return interface(parenttype(a))
end

function combine_interface_rule(
interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N}
) where {N}
return DefaultArrayInterface{N}()
end
function combine_interface_rule(
interface1::DefaultArrayInterface, interface2::DefaultArrayInterface
)
return DefaultArrayInterface{Any}()
end

@interface ::DefaultArrayInterface function Base.getindex(
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
Expand All @@ -31,6 +49,6 @@ end
return Base.mapreduce(f, op, as...; kwargs...)
end

function arraytype(::DefaultArrayInterface, T::Type)
return Array{T}
function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple)
return similar(Array{T}, ax)
end
6 changes: 4 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Aqua = "0.8"
ArrayLayouts = "1"
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
LinearAlgebra = "1"
SafeTestsets = "0.1"
Suppressor = "0.2"
LinearAlgebra = "1"
Test = "1"
TestExtras = "0.3"
16 changes: 12 additions & 4 deletions test/SparseArrayDOKs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ using DerivableInterfaces:
using LinearAlgebra: LinearAlgebra

# Define an interface.
struct SparseArrayInterface <: AbstractArrayInterface end
struct SparseArrayInterface{N} <: AbstractArrayInterface{N} end
SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}()
SparseArrayInterface{M}(::Val{N}) where {M,N} = SparseArrayInterface{N}()

# Define interface functions.
@interface ::SparseArrayInterface function Base.getindex(
Expand All @@ -66,11 +68,15 @@ end
struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()

DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface()
function DerivableInterfaces.interface(::Type{<:SparseArrayStyle{N}}) where {N}
return SparseArrayInterface{N}()
end

@derive SparseArrayStyle AbstractArrayStyleOps

DerivableInterfaces.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T}
function Base.similar(::SparseArrayInterface, T::Type, ax::Tuple)
return similar(SparseArrayDOK{T}, ax)
end

# Interface functions.
@interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
Expand Down Expand Up @@ -260,7 +266,9 @@ function DerivableInterfaces.zero!(a::SparseArrayDOK)
end

# Specify the interface the type adheres to.
DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()
function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK})
SparseArrayInterface{ndims(arrayt)}()
end

# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc.
@array_aliases SparseArrayDOK
Expand Down
35 changes: 29 additions & 6 deletions test/test_defaultarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test: @inferred, @testset, @test
using DerivableInterfaces: @interface, DefaultArrayInterface, interface
using Test: @testset, @test
using TestExtras: @constinferred

# function wrappers to test type-stability
_getindex(A, i...) = @interface DefaultArrayInterface() A[i...]
Expand All @@ -11,28 +12,50 @@ end

@testset "indexing" begin
for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3)))
a = @inferred _getindex(A, i...)
a = @constinferred _getindex(A, i...)
@test a == A[i...]
v = 1.1
A′ = @inferred _setindex!(A, v, i...)
A′ = @constinferred _setindex!(A, v, i...)
@test A′ == (A[i...] = v)
end
end

@testset "map!" begin
A = zeros(3)
a = @inferred _map!(Returns(2), copy(A), A)
a = @constinferred _map!(Returns(2), copy(A), A)
@test a == map!(Returns(2), copy(A), A)
end

@testset "mapreduce" begin
A = zeros(3)
a = @inferred _mapreduce(Returns(2), +, A)
a = @constinferred _mapreduce(Returns(2), +, A)
@test a == mapreduce(Returns(2), +, A)
end

@testset "DefaultArrayInterface" begin
@test interface(Array) === DefaultArrayInterface{Any}()
@test interface(Array{Float32}) === DefaultArrayInterface{Any}()
@test interface(Matrix) === DefaultArrayInterface{2}()
@test interface(Matrix{Float32}) === DefaultArrayInterface{2}()
@test DefaultArrayInterface() === DefaultArrayInterface{Any}()
@test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}()
@test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}()
@test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}()
end

@testset "similar(::DefaultArrayInterface, ...)" begin
a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2))
@test typeof(a) === Matrix{Float32}
@test size(a) == (2, 2)

a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2))
@test typeof(a) === Matrix{Float32}
@test size(a) == (2, 2)
end

@testset "Broadcast.DefaultArrayStyle" begin
@test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface()
@test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}()
@test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) ==
DefaultArrayInterface()
DefaultArrayInterface{1}()
end
Loading