Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DerivableInterfaces"
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.5"
version = "0.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"

[compat]
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
Documenter = "1"
Literate = "2"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

[compat]
ArrayLayouts = "1"
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
11 changes: 7 additions & 4 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# TODO: Add `ndims` type parameter.
abstract type AbstractArrayInterface <: AbstractInterface end
abstract type AbstractArrayInterface{N} <: AbstractInterface end

function interface(::Type{<:Broadcast.AbstractArrayStyle{N}}) where {N}
return DefaultArrayInterface{N}()
end
function interface(::Type{<:Broadcast.AbstractArrayStyle})
return DefaultArrayInterface()
end

function interface(::Type{<:Broadcast.Broadcasted{Nothing}})
return DefaultArrayInterface()
function interface(BC::Type{<:Broadcast.Broadcasted{Nothing}})
return DefaultArrayInterface{ndims(BC)}()
end

function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style}
return interface(Style)
end

# TODO: Define as `Array{T}`.
# TODO: Define as `Array{T,N}`.
arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.")

using ArrayLayouts: ArrayLayouts
Expand Down
23 changes: 17 additions & 6 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,16 @@ export concatenate

using Base: promote_eltypeof
using ..DerivableInterfaces:
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
DerivableInterfaces, AbstractArrayInterface, interface, zero!, arraytype

unval(x) = x
unval(::Val{x}) where {x} = x

set_interface_ndims(::Type{Nothing}, ::Val{N}) where {N} = nothing
function set_interface_ndims(Interface::Type{<:AbstractArrayInterface}, ::Val{N}) where {N}
return Interface(Val(N))
end

function _Concatenated end

"""
Expand All @@ -42,25 +47,31 @@ function _Concatenated end
Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide
hooks to customize the implementation.
"""
struct Concatenated{Interface,Dims,Args<:Tuple}
struct Concatenated{Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple}
interface::Interface
dims::Val{Dims}
args::Args
global @inline function _Concatenated(
interface::Interface, dims::Val{Dims}, args::Args
) where {Interface,Dims,Args<:Tuple}
) where {Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple}
return new{Interface,Dims,Args}(interface, dims, args)
end
end

function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple)
function Concatenated(interface::Nothing, dims::Val, args::Tuple)
return _Concatenated(interface, dims, args)
end
function Concatenated(interface::AbstractArrayInterface, dims::Val, args::Tuple)
N = cat_ndims(dims, args...)
return _Concatenated(typeof(interface)(Val(N)), dims, args)
end
function Concatenated(dims::Val, args::Tuple)
return Concatenated(interface(args...), dims, args)
N = cat_ndims(dims, args...)
return _Concatenated(typeof(interface(args...))(Val(N)), dims, args)
end
function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface}
return Concatenated(Interface(), dims, args)
N = cat_ndims(dims, args...)
return _Concatenated(set_interface_ndims(Interface, Val(N)), dims, args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this new call to similar, do you need to set the interface ndims already at this point? I guess this would be automatically determined from the ax later down the line, so I'm just wondering if there is any difference.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true, when you call similar, the ndims in the interface will get overridden by the axes input to similar. I kept it in case there is some other use case where storing the ndims in the interface is useful, but I admit I don't have a particular use case in mind for Concatenated.

I suppose what this does is catch cases where a user specifies an interface but it actually has the wrong ndims for the concatenation expression, I guess that could be checked.

But also thinking about this particular constructor, it should change the Interface type since that mean this constructor doesn't satisfy T(args...) isa T, so I'll change this one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the latest commits I propose a compromise. If the interface object is passed explicitly, it is taken "as is" and isn't modified, even if the ndims aren't specified or are incorrect, but when it is constructed from the arguments the correct ndims are computed explicitly.

That matches the behavior of Broadcasted:

julia> using Base.Broadcast: DefaultArrayStyle, Broadcasted

julia> bc = Broadcasted(DefaultArrayStyle{1}(), +, (randn(2, 2), randn(2, 2)))
Broadcasted(+, ([0.095430782012298 0.4338409876936397; -0.2285907590132556 2.0739880112106475], [-0.6626604337472756 0.5369654075753642; 0.2681815713554777 -0.7930248596990674]))

julia> bc.style
DefaultArrayStyle{1}()

julia> bc = Broadcasted(+, (randn(2, 2), randn(2, 2)))
Broadcasted(+, ([-0.08113666651442918 -0.11158415095613558; -0.5031937898445847 -1.1018574952687241], [-1.595659893181313 0.12746978522353117; 0.026558187695457296 -0.22363363427492086]))

julia> bc.style
DefaultArrayStyle{2}()

That seems reasonable, since the interface is meant to be something where you can specify it to decide on the dispatch, I think we shouldn't be too opinionated about modifying what the user input since it may have been deliberate.

end

dims(::Concatenated{<:Any,D}) where {D} = D
Expand Down
25 changes: 21 additions & 4 deletions src/defaultarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
# TODO: Add `ndims` type parameter.
struct DefaultArrayInterface <: AbstractArrayInterface end
struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end

DefaultArrayInterface() = DefaultArrayInterface{Any}()
DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}()
DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}()

using TypeParameterAccessors: parenttype
function interface(a::Type{<:AbstractArray})
parenttype(a) === a && return DefaultArrayInterface()
parenttype(a) === a && return DefaultArrayInterface{ndims(a)}()
return interface(parenttype(a))
end

function combine_interface_rule(
interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N}
) where {N}
return DefaultArrayInterface{N}()
end
function combine_interface_rule(
interface1::DefaultArrayInterface, interface2::DefaultArrayInterface
)
return DefaultArrayInterface{Any}()
end

@interface ::DefaultArrayInterface function Base.getindex(
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
Expand All @@ -31,6 +45,9 @@ end
return Base.mapreduce(f, op, as...; kwargs...)
end

function arraytype(::DefaultArrayInterface, T::Type)
function arraytype(::DefaultArrayInterface{N}, T::Type) where {N}
return Array{T,N}
end
function arraytype(::DefaultArrayInterface{Any}, T::Type)
return Array{T}
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
Aqua = "0.8"
ArrayLayouts = "1"
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
SafeTestsets = "0.1"
Suppressor = "0.2"
LinearAlgebra = "1"
Expand Down
12 changes: 9 additions & 3 deletions test/SparseArrayDOKs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ using DerivableInterfaces:
using LinearAlgebra: LinearAlgebra

# Define an interface.
struct SparseArrayInterface <: AbstractArrayInterface end
struct SparseArrayInterface{N} <: AbstractArrayInterface{N} end
SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}()
SparseArrayInterface{M}(::Val{N}) where {M,N} = SparseArrayInterface{N}()

# Define interface functions.
@interface ::SparseArrayInterface function Base.getindex(
Expand All @@ -66,7 +68,9 @@ end
struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()

DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface()
function DerivableInterfaces.interface(::Type{<:SparseArrayStyle{N}}) where {N}
return SparseArrayInterface{N}()
end

@derive SparseArrayStyle AbstractArrayStyleOps

Expand Down Expand Up @@ -260,7 +264,9 @@ function DerivableInterfaces.zero!(a::SparseArrayDOK)
end

# Specify the interface the type adheres to.
DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()
function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK})
SparseArrayInterface{ndims(arrayt)}()
end

# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc.
@array_aliases SparseArrayDOK
Expand Down
17 changes: 15 additions & 2 deletions test/test_defaultarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test: @inferred, @testset, @test
using DerivableInterfaces: @interface, DefaultArrayInterface, interface
using DerivableInterfaces: @interface, DefaultArrayInterface, arraytype, interface

# function wrappers to test type-stability
_getindex(A, i...) = @interface DefaultArrayInterface() A[i...]
Expand Down Expand Up @@ -31,8 +31,21 @@ end
@test a == mapreduce(Returns(2), +, A)
end

@testset "DefaultArrayInterface" begin
@test DefaultArrayInterface() === DefaultArrayInterface{Any}()
@test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}()
@test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}()
@test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}()
end

@testset "arraytype" begin
@test arraytype(DefaultArrayInterface{2}(), Float32) == Matrix{Float32}
@test arraytype(DefaultArrayInterface(), Float32) == Array{Float32}
end

@testset "Broadcast.DefaultArrayStyle" begin
@test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface()
@test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}()
@test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) ==
DefaultArrayInterface()
DefaultArrayInterface{1}()
end
Loading