Skip to content

Commit 9a2ce84

Browse files
Merge pull request #46 from SciML/restructure
add restructure
2 parents f194a97 + 29de896 commit 9a2ce84

File tree

4 files changed

+44
-5
lines changed

4 files changed

+44
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.7.0"
3+
version = "2.8.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

README.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ Is a form of `vec` which is safe for all values in vector spaces, i.e. if
7575
is already a vector, like an AbstractVector or Number, it will return said
7676
AbstractVector or Number.
7777

78-
## zeromatrix(u::AbstractVector)
78+
## zeromatrix(u)
7979

8080
Creates the zero'd matrix version of `u`. Note that this is unique because
8181
`similar(u,length(u),length(u))` returns a mutable type, so is not type-matching,
@@ -85,15 +85,24 @@ i.e. you'll get a CPU array from a GPU array. The generic fallback is
8585
with weird (recursive) broadcast overloads. For higher order tensors, this
8686
returns the matrix linear operator type which acts on the `vec` of the array.
8787

88-
## List of things to add
88+
## restructure(x,y)
89+
90+
Restructures the object `y` into a shape of `x`, keeping its values intact. For
91+
simple objects like an `Array`, this simply amounts to a reshape. However, for
92+
more complex objects such as an `ArrayPartition`, not all of the structural
93+
information is adequately contained in the type for standard tools to work. In
94+
these cases, `restructure` gives a way to convert for example an `Array` into
95+
a matching `ArrayPartition`.
96+
97+
# List of things to add
8998

9099
- https://github.com/JuliaLang/julia/issues/22216
91100
- https://github.com/JuliaLang/julia/issues/22218
92101
- https://github.com/JuliaLang/julia/issues/22622
93102
- https://github.com/JuliaLang/julia/issues/25760
94103
- https://github.com/JuliaLang/julia/issues/25107
95104

96-
## Array Types to Handle
105+
# Array Types to Handle
97106

98107
The following common array types are being understood and tested as part of this
99108
development.

src/ArrayInterface.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,25 @@ function zeromatrix(u)
447447
x .* x' .* false
448448
end
449449

450+
"""
451+
restructure(x,y)
452+
453+
Restructures the object `y` into a shape of `x`, keeping its values intact. For
454+
simple objects like an `Array`, this simply amounts to a reshape. However, for
455+
more complex objects such as an `ArrayPartition`, not all of the structural
456+
information is adequately contained in the type for standard tools to work. In
457+
these cases, `restructure` gives a way to convert for example an `Array` into
458+
a matching `ArrayPartition`.
459+
"""
460+
function restructure(x,y)
461+
out = similar(x,eltype(y))
462+
out .= y
463+
end
464+
465+
function restructure(x::Array,y)
466+
reshape(convert(Array,y),size(x)...)
467+
end
468+
450469
function __init__()
451470

452471
@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
@@ -461,13 +480,18 @@ function __init__()
461480
ismutable(::Type{<:StaticArrays.StaticArray}) = false
462481
can_setindex(::Type{<:StaticArrays.StaticArray}) = false
463482
ismutable(::Type{<:StaticArrays.MArray}) = true
483+
464484
function lu_instance(_A::StaticArrays.StaticMatrix{N,N}) where {N}
465485
A = StaticArrays.SArray(_A)
466486
L = LowerTriangular(A)
467487
U = UpperTriangular(A)
468488
p = StaticArrays.SVector{N,Int}(1:N)
469489
return StaticArrays.LU(L, U, p)
470490
end
491+
492+
function restructure(x::StaticArrays.SArray,y)
493+
error("Currently not supported")
494+
end
471495
end
472496

473497
@require LabelledArrays="2ee39098-c373-598a-b85f-a56591580800" begin
@@ -483,7 +507,9 @@ function __init__()
483507
end
484508

485509
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
486-
include("cuarrays.jl")
510+
@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
511+
include("cuarrays.jl")
512+
end
487513
end
488514

489515
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin

src/cuarrays.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@ function Base.setindex(x::CuArrays.CuArray,v,i::Int)
77
allowed_setindex!(_x,v,i)
88
_x
99
end
10+
11+
function restructure(x::CuArrays,y)
12+
reshape(adapt(typeof(x),y),size(x)...)
13+
end

0 commit comments

Comments
 (0)