Skip to content

Commit 532bddd

Browse files
authored
doc strings type and add trait inheritance section (#186)
* doc strings type and add trait inheritance section * Add consistent type return annotation to doc strings (where appropriate) * Section on inheriting behavior using `parent_type`. * Put doc strings in manual after text * Add dimension section to docs
1 parent ae255a9 commit 532bddd

File tree

8 files changed

+110
-70
lines changed

8 files changed

+110
-70
lines changed

docs/src/index.md

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,32 @@ CurrentModule = ArrayInterface
44

55
# ArrayInterface
66

7-
```@index
8-
```
7+
Designs for new Base array interface primitives, used widely through scientific machine learning (SciML) and other organizations
98

10-
```@autodocs
11-
Modules = [ArrayInterface]
9+
## Inheriting Array Traits
10+
11+
Creating an array type with unique behavior in Julia is often accomplished by creating a lazy wrapper around previously defined array types.
12+
This allows the new array type to inherit functionality by redirecting methods to the parent array (e.g., `Base.size(x::Wrapper) = size(parent(x))`).
13+
Generic design limits the need to define an excessive number of methods like this.
14+
However, methods used to describe a type's traits often need to be explicitly defined for each trait method.
15+
`ArrayInterface` assists with this by providing information about the parent type using [`ArrayInterface.parent_type`](@ref).
16+
By default `ArrayInterface.parent_type(::Type{T})` returns `T` (analogous to `Base.parent(x) = x`).
17+
If any type other than `T` is returned we assume `T` wraps a parent structure, so methods know to unwrap instances of `T`.
18+
It is also assumed that if `T` has a parent type `Base.parent` is defined.
19+
20+
For those authoring new trait methods, this may change the default definition from `has_trait(::Type{T}) where {T} = false`, to:
21+
```julia
22+
function has_trait(::Type{T}) where {T}
23+
if parent_type(T) <:T
24+
return false
25+
else
26+
return has_trait(parent_type(T))
27+
end
28+
end
1229
```
1330

31+
Most traits in `ArrayInterface` are a variant on this pattern.
32+
1433
## Static Traits
1534

1635
The size along one or more dimensions of an array may be known at compile time.
@@ -50,3 +69,36 @@ Generic support for `ArrayInterface.known_size` relies on calling `known_length`
5069
Therefore, the recommended approach for supporting static sizing in newly defined array types is defining a new `axes_types` method.
5170

5271
Static information related to subtypes of `AbstractRange` include `known_length`, `known_first`, `known_step`, and `known_last`.
72+
73+
## Dimensions
74+
75+
Methods such as `size(x, dim)` need to map `dim` to the dimensions of `x`.
76+
Typically, `dim` is an `Int` with an invariant mapping to the dimensions of `x`.
77+
Some methods accept `:` or a tuple of dimensions as an argument.
78+
`ArrayInterface` also considers `StaticInt` a viable dimension argument.
79+
80+
[`ArrayInterface.to_dims`](@ref) helps ensure that `dim` is converted to a viable dimension mapping in a manner that helps with type stability.
81+
For example, all `Integers` passed to `to_dims` are converted to `Int` (unless `dim` is a `StaticInt`).
82+
This is also useful for arrays that uniquely label dimensions, in which case `to_dims` serves as a safe point of hooking into existing methods with dimension arguments.
83+
`ArrayInterface` also defines native `Symbol` to `Int` and `StaticSymbol` to `StaticInt` mapping for arrays defining [`ArrayInterface.dimnames`](@ref).
84+
85+
Methods accepting dimension specific arguments should use some variation of the following pattern.
86+
87+
```julia
88+
f(x, dim) = f(x, ArrayInterface.to_dims(x, dim))
89+
f(x, dim::Int) = ...
90+
f(x, dim::StaticInt) = ...
91+
```
92+
93+
If `x`'s first dimension is named `:dim_1` then calling `f(x, :dim_1)` would result in `f(x, 1)`.
94+
If users knew they always wanted to call `f(x, 2)` then they could define `h(x) = f(x, static(2))`, ensuring `f` passes along that information while compiling.
95+
96+
## API
97+
98+
```@index
99+
```
100+
101+
```@autodocs
102+
Modules = [ArrayInterface]
103+
```
104+

src/ArrayInterface.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
5252
include("array_index.jl")
5353

5454
"""
55-
parent_type(::Type{T})
55+
parent_type(::Type{T}) -> Type
5656
5757
Returns the parent array that type `T` wraps.
5858
"""
@@ -82,7 +82,7 @@ _has_parent(::Type{T}, ::Type{T}) where {T} = False()
8282
_has_parent(::Type{T1}, ::Type{T2}) where {T1,T2} = True()
8383

8484
"""
85-
known_length(::Type{T})
85+
known_length(::Type{T}) -> Union{Int,Nothing}
8686
8787
If `length` of an instance of type `T` is known at compile time, return it.
8888
Otherwise, return `nothing`.
@@ -119,7 +119,7 @@ can_change_size(::Type{<:Base.ImmutableDict}) = false
119119
function ismutable end
120120

121121
"""
122-
ismutable(x::DataType)
122+
ismutable(x::DataType) -> Bool
123123
124124
Query whether a type is mutable or not, see
125125
https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19.
@@ -168,7 +168,7 @@ function Base.setindex(x::AbstractMatrix, v, i::Int, j::Int)
168168
end
169169

170170
"""
171-
can_setindex(x::DataType)
171+
can_setindex(x::DataType) -> Bool
172172
173173
Query whether a type can use `setindex!`.
174174
"""
@@ -208,7 +208,7 @@ A scalar `setindex!` which is always allowed.
208208
allowed_setindex!(x, v, i...) = Base.setindex!(x, v, i...)
209209

210210
"""
211-
isstructured(x::DataType)
211+
isstructured(x::DataType) -> Bool
212212
213213
Query whether a type is a representation of a structured matrix.
214214
"""
@@ -224,7 +224,7 @@ isstructured(::Bidiagonal) = true
224224
isstructured(::Diagonal) = true
225225

226226
"""
227-
has_sparsestruct(x::AbstractArray)
227+
has_sparsestruct(x::AbstractArray) -> Bool
228228
229229
Determine whether `findstructralnz` accepts the parameter `x`.
230230
"""
@@ -238,7 +238,7 @@ has_sparsestruct(x::Type{<:Tridiagonal}) = true
238238
has_sparsestruct(x::Type{<:SymTridiagonal}) = true
239239

240240
"""
241-
issingular(A::AbstractMatrix)
241+
issingular(A::AbstractMatrix) -> Bool
242242
243243
Determine whether a given abstract matrix is singular.
244244
"""
@@ -354,7 +354,7 @@ Returns the number.
354354
lu_instance(a::Any) = lu(a, check = false)
355355

356356
"""
357-
safevec(v)
357+
safevec(v)
358358
359359
It is a form of `vec` which is safe for all values in vector spaces, i.e., if it
360360
is already a vector, like an AbstractVector or Number, it will return said
@@ -409,7 +409,7 @@ struct CPUIndex <: AbstractCPU end
409409
struct GPU <: AbstractDevice end
410410

411411
"""
412-
device(::Type{T})
412+
device(::Type{T}) -> AbstractDevice
413413
414414
Indicates the most efficient way to access elements from the collection in low-level code.
415415
For `GPUArrays`, will return `ArrayInterface.GPU()`.
@@ -617,7 +617,7 @@ end
617617

618618

619619
"""
620-
is_lazy_conjugate(::AbstractArray)
620+
is_lazy_conjugate(::AbstractArray) -> Bool
621621
622622
Determine if a given array will lazyily take complex conjugates, such as with `Adjoint`. This will work with
623623
nested wrappers, so long as there is no type in the chain of wrappers such that `parent_type(T) == T`

src/axes.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11

22
"""
3-
axes_types(::Type{T}, dim)
3+
axes_types(::Type{T}) -> Type{Tuple{Vararg{AbstractUnitRange{Int}}}}
4+
axes_types(::Type{T}, dim) -> Type{AbstractUnitRange{Int}}
45
5-
Returns the axis type along dimension `dim`.
6+
Returns the type of each axis for the `T`, or the type of of the axis along dimension `dim`.
67
"""
78
axes_types(x, dim) = axes_types(typeof(x), dim)
89
@inline axes_types(::Type{T}, dim) where {T} = axes_types(T, to_dims(T, dim))
@@ -21,11 +22,6 @@ end
2122
end
2223
end
2324

24-
"""
25-
axes_types(::Type{T}) -> Type
26-
27-
Returns the type of the axes for `T`
28-
"""
2925
axes_types(x) = axes_types(typeof(x))
3026
axes_types(::Type{T}) where {T<:Array} = Tuple{Vararg{OneTo{Int},ndims(T)}}
3127
function axes_types(::Type{T}) where {T}
@@ -115,11 +111,6 @@ similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}, ::Type{Int}, :
115111
similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,Int}}) where {N} = OptionallyStaticUnitRange{One,Int}
116112
similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N1}}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,StaticInt{N2}}}) where {N1,N2} = OptionallyStaticUnitRange{One,StaticInt{N2}}
117113

118-
"""
119-
axes(A, d)
120-
121-
Return a valid range that maps to each index along dimension `d` of `A`.
122-
"""
123114
@inline axes(a, dim) = axes(a, to_dims(a, dim))
124115
@inline axes(a, dims::Tuple{Vararg{Any,K}}) where {K} = (axes(a, first(dims)), axes(a, tail(dims))...)
125116
@inline axes(a, dims::Tuple{T}) where {T} = (axes(a, first(dims)), )
@@ -157,9 +148,10 @@ end
157148
@inline _axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
158149

159150
"""
160-
axes(A)
151+
axes(A) -> Tuple{Vararg{AbstractUnitRange{Int}}}
152+
axes(A, dim) -> AbstractUnitRange{Int}
161153
162-
Return a tuple of ranges where each range maps to each element along a dimension of `A`.
154+
Returns the axis associated with each dimension of `A` or dimension `dim`
163155
"""
164156
@inline function axes(a::A) where {A}
165157
if parent_type(A) <: A

src/dimensions.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ _ndims_index(::Type{I}, i::StaticInt) where {I} = ndims_index(_get_tuple(I, i))
3737
ndims_index(::Type{I}) where {N,I<:Tuple{Vararg{Any,N}}} = eachop(_ndims_index, nstatic(Val(N)), I)
3838

3939
"""
40-
from_parent_dims(::Type{T})::Tuple{Vararg{Union{Int,StaticInt}}}
40+
from_parent_dims(::Type{T}) -> Tuple{Vararg{Union{Int,StaticInt}}}
4141
4242
Returns the mapping from parent dimensions to child dimensions.
4343
"""
@@ -72,7 +72,7 @@ function from_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}
7272
end
7373

7474
"""
75-
from_parent_dims(::Type{T}, dim)::Union{Int,StaticInt}
75+
from_parent_dims(::Type{T}, dim) -> Union{Int,StaticInt}
7676
7777
Returns the mapping from child dimensions to parent dimensions.
7878
"""
@@ -98,7 +98,7 @@ function from_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim}
9898
end
9999

100100
"""
101-
to_parent_dims(::Type{T})::Tuple{Vararg{Union{Int,StaticInt}}}
101+
to_parent_dims(::Type{T}) -> Tuple{Vararg{Union{Int,StaticInt}}}
102102
103103
Returns the mapping from child dimensions to parent dimensions.
104104
"""
@@ -130,7 +130,7 @@ function to_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
130130
end
131131

132132
"""
133-
to_parent_dims(::Type{T}, dim)::Union{Int,StaticInt}
133+
to_parent_dims(::Type{T}, dim) -> Union{Int,StaticInt}
134134
135135
Returns the mapping from child dimensions to parent dimensions.
136136
"""
@@ -156,7 +156,7 @@ function to_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim}
156156
end
157157

158158
"""
159-
has_dimnames(::Type{T})::Bool
159+
has_dimnames(::Type{T}) -> Bool
160160
161161
Returns `true` if `x` has names for each dimension.
162162
"""
@@ -173,8 +173,8 @@ end
173173
const SUnderscore = StaticSymbol(:_)
174174

175175
"""
176-
dimnames(::Type{T})::Tuple{Vararg{StaticSymbol}}
177-
dimnames(::Type{T}, dim)::StaticSymbol
176+
dimnames(::Type{T}) -> Tuple{Vararg{StaticSymbol}}
177+
dimnames(::Type{T}, dim) -> StaticSymbol
178178
179179
Return the names of the dimensions for `x`.
180180
"""
@@ -204,7 +204,7 @@ function dimnames(::Type{T}) where {T<:SubArray}
204204
end
205205

206206
"""
207-
to_dims(::Type{T}, dim)::Union{Int,StaticInt}
207+
to_dims(::Type{T}, dim) -> Union{Int,StaticInt}
208208
209209
This returns the dimension(s) of `x` corresponding to `d`.
210210
"""

src/indexing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
"""
3-
is_canonical(::Type{I})::StaticBool
3+
is_canonical(::Type{I}) -> StaticBool
44
55
Returns `True` if instances of `I` can be used for indexing without any further change
66
(e.g., `Int`, `StaticInt`, `UnitRange{Int}`)
@@ -43,7 +43,7 @@ is_linear_indexing(A, args::Tuple{Arg}) where {Arg} = ndims_index(Arg) < 2
4343
is_linear_indexing(A, args::Tuple{Arg,Vararg{Any}}) where {Arg} = false
4444

4545
"""
46-
to_indices(A, inds::Tuple)::Tuple
46+
to_indices(A, inds::Tuple) -> Tuple
4747
4848
Maps indexing arguments `inds` to the axes of `A`, ensures they are converted to a native
4949
indexing form, and that they are inbounds. Unless all indices in `inds` return `static(true)`
@@ -255,7 +255,7 @@ function unsafe_reconstruct(A::AbstractUnitRange, data; kwargs...)
255255
end
256256

257257
"""
258-
to_axes(A, inds)
258+
to_axes(A, inds) -> Tuple
259259
260260
Construct new axes given the corresponding `inds` constructed after
261261
`to_indices(A, args) -> inds`. This method iterates through each pair of axes and

src/ranges.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
"""
3-
known_first(::Type{T})
3+
known_first(::Type{T}) -> Union{Int,Nothing}
44
55
If `first` of an instance of type `T` is known at compile time, return it.
66
Otherwise, return `nothing`.
@@ -24,7 +24,7 @@ end
2424
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
2525

2626
"""
27-
known_last(::Type{T})
27+
known_last(::Type{T}) -> Union{Int,Nothing}
2828
2929
If `last` of an instance of type `T` is known at compile time, return it.
3030
Otherwise, return `nothing`.
@@ -48,7 +48,7 @@ function known_last(::Type{T}) where {T}
4848
end
4949

5050
"""
51-
known_step(::Type{T})
51+
known_step(::Type{T}) -> Union{Int,Nothing}
5252
5353
If `step` of an instance of type `T` is known at compile time, return it.
5454
Otherwise, return `nothing`.
@@ -572,14 +572,14 @@ end
572572
end
573573

574574
"""
575-
indices(x, dim)
575+
indices(x, dim) -> AbstractUnitRange{Int}
576576
577577
Given an array `x`, this returns the indices along dimension `dim`.
578578
"""
579579
@inline indices(x, d) = indices(axes(x, d))
580580

581581
"""
582-
indices(x) -> AbstractUnitRange
582+
indices(x) -> AbstractUnitRange{Int}
583583
584584
Returns valid indices for the entire length of `x`.
585585
"""
@@ -594,7 +594,7 @@ end
594594
@inline indices(x::AbstractUnitRange{<:Integer}) = Base.Slice(OptionallyStaticUnitRange(x))
595595

596596
"""
597-
indices(x::Tuple) -> AbstractUnitRange
597+
indices(x::Tuple) -> AbstractUnitRange{Int}
598598
599599
Returns valid indices for the entire length of each array in `x`.
600600
"""
@@ -604,7 +604,7 @@ Returns valid indices for the entire length of each array in `x`.
604604
end
605605

606606
"""
607-
indices(x::Tuple, dim) -> AbstractUnitRange
607+
indices(x::Tuple, dim) -> AbstractUnitRange{Int}
608608
609609
Returns valid indices for each array in `x` along dimension `dim`
610610
"""
@@ -614,7 +614,7 @@ Returns valid indices for each array in `x` along dimension `dim`
614614
end
615615

616616
"""
617-
indices(x::Tuple, dim::Tuple) -> AbstractUnitRange
617+
indices(x::Tuple, dim::Tuple) -> AbstractUnitRange{Int}
618618
619619
Returns valid indices given a tuple of arrays `x` and tuple of dimesions for each
620620
respective array (`dim`).
@@ -625,7 +625,7 @@ respective array (`dim`).
625625
end
626626

627627
"""
628-
indices(x, dim::Tuple) -> Tuple{Vararg{AbstractUnitRange}}
628+
indices(x, dim::Tuple) -> Tuple{Vararg{AbstractUnitRange{Int}}}
629629
630630
Returns valid indices for array `x` along each dimension specified in `dim`.
631631
"""

0 commit comments

Comments
 (0)