Skip to content

Commit a5e8275

Browse files
Merge pull request #353 from oscardssmith/patch-1
Add `undefmatrix` (similar to `zeromatrix`)
2 parents aa1fcd2 + 2b2e40c commit a5e8275

File tree

5 files changed

+42
-3
lines changed

5 files changed

+42
-3
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ ArrayInterfaceCore.promote_eltype
3939
ArrayInterfaceCore.restructure
4040
ArrayInterfaceCore.safevec
4141
ArrayInterfaceCore.zeromatrix
42+
ArrayInterfaceCore.undefmatrix
4243
```
4344

4445
### Types

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,21 @@ function zeromatrix(u::Array{T}) where {T}
512512
fill!(out, false)
513513
end
514514

515+
"""
516+
undefmatrix(u::AbstractVector)
517+
518+
Creates the matrix version of `u` with possibly undefined values. Note that this is unique because
519+
`similar(u,length(u),length(u))` returns a mutable type, so it is not type-matching,
520+
while `fill(zero(eltype(u)),length(u),length(u))` doesn't match the array type,
521+
i.e., you'll get a CPU array from a GPU array. The generic fallback is
522+
`u .* u'`, which works on a surprising number of types, but can be broken
523+
with weird (recursive) broadcast overloads. For higher-order tensors, this
524+
returns the matrix linear operator type which acts on the `vec` of the array.
525+
"""
526+
function undefmatrix(u)
527+
similar(u, length(u), length(u))
528+
end
529+
515530
"""
516531
restructure(x,y)
517532

lib/ArrayInterfaceCore/test/runtests.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ArrayInterfaceCore
2-
using ArrayInterfaceCore: zeromatrix
2+
using ArrayInterfaceCore: zeromatrix, undefmatrix
33
import ArrayInterfaceCore: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance,
44
parent_type, zeromatrix, IndicesInfo
55
using Base: setindex
@@ -11,7 +11,22 @@ using Test
1111
using Aqua
1212
Aqua.test_all(ArrayInterfaceCore)
1313

14-
@test zeromatrix(rand(4,4,4)) == zeros(4*4*4,4*4*4)
14+
@testset "zeromatrix and unsafematrix" begin
15+
for T in (Int, Float32, Float64)
16+
for (vectype, mattype) in ((Vector{T}, Matrix{T}), (SparseVector{T}, SparseMatrixCSC{T, Int}))
17+
v = vectype(rand(T, 4))
18+
um = undefmatrix(v)
19+
@test size(um) == (length(v),length(v))
20+
@test typeof(um) == mattype
21+
@test zeromatrix(v) == zeros(T,length(v),length(v))
22+
end
23+
v = rand(T,4,4,4)
24+
um = undefmatrix(v)
25+
@test size(um) == (length(v),length(v))
26+
@test typeof(um) == Matrix{T}
27+
@test zeromatrix(v) == zeros(T,4*4*4,4*4*4)
28+
end
29+
end
1530

1631
@testset "matrix colors" begin
1732
@test ArrayInterfaceCore.fast_matrix_colors(1) == false

lib/ArrayInterfaceStaticArrays/src/ArrayInterfaceStaticArrays.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ import ArrayInterfaceStaticArraysCore
99

1010
const CanonicalInt = Union{Int,StaticInt}
1111

12+
function ArrayInterface.undefmatrix(::MArray{S, T, N, L}) where {S, T, N, L}
13+
return MMatrix{L, L, T, L*L}(undef)
14+
end
15+
# SArray doesn't have an undef constructor and is going to be small enough that this is fine.
16+
function ArrayInterface.undefmatrix(s::SArray)
17+
v = vec(s)
18+
return v.*v'
19+
end
1220
ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1
1321
ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
1422
ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N

src/ArrayInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ArrayInterfaceCore
44
import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buffer,
55
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
7-
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
7+
safevec, zeromatrix, undefmatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
88
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims,
99
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, defines_strides,
1010
stride_preserving_index

0 commit comments

Comments
 (0)