Skip to content

Commit fcd37dc

Browse files
authored
Merge branch 'master' into jc/NamedArrayPartition
2 parents f673e37 + def2cbd commit fcd37dc

21 files changed

+437
-189
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ jobs:
2020
- Downstream
2121
version:
2222
- '1'
23-
- '1.6'
2423
steps:
2524
- uses: actions/checkout@v4
2625
- uses: julia-actions/setup-julia@v1

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
julia-version: [1,1.6]
17+
julia-version: [1]
1818
os: [ubuntu-latest]
1919
package:
2020
- {user: SciML, repo: SciMLBase.jl, group: Core}

.github/workflows/SpellCheck.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
name: Spell Check
2+
3+
on: [pull_request]
4+
5+
jobs:
6+
typos-check:
7+
name: Spell Check with Typos
8+
runs-on: ubuntu-latest
9+
steps:
10+
- name: Checkout Actions Repository
11+
uses: actions/checkout@v4
12+
- name: Check spelling
13+
uses: crate-ci/[email protected]

.typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[default.extend-words]

Project.toml

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.39.0"
4+
version = "3.3.3"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -12,40 +12,62 @@ IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1414
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
15+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1516
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1819
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1920

2021
[weakdeps]
22+
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
2123
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2224
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2325
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2426
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2527

2628
[extensions]
29+
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
2730
RecursiveArrayToolsMeasurementsExt = "Measurements"
2831
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
2932
RecursiveArrayToolsTrackerExt = "Tracker"
3033
RecursiveArrayToolsZygoteExt = "Zygote"
3134

3235
[compat]
33-
Adapt = "3"
36+
Adapt = "3, 4"
37+
Aqua = "0.8"
3438
ArrayInterface = "7"
3539
DocStringExtensions = "0.8, 0.9"
40+
FastBroadcast = "0.2.8"
41+
ForwardDiff = "0.10"
3642
GPUArraysCore = "0.1"
3743
IteratorInterfaceExtensions = "1"
44+
LabelledArrays = "1"
45+
LinearAlgebra = "1"
46+
Measurements = "2.3"
47+
MonteCarloMeasurements = "1.1"
48+
NLsolve = "4"
49+
OrdinaryDiffEq = "6"
50+
Pkg = "1"
51+
Random = "1"
3852
RecipesBase = "0.7, 0.8, 1.0"
3953
Requires = "1.0"
54+
SafeTestsets = "0.1"
55+
SparseArrays = "1"
56+
StaticArrays = "1.6"
4057
StaticArraysCore = "1.1"
4158
Statistics = "1"
42-
SymbolicIndexingInterface = "0.3"
59+
StructArrays = "0.6"
60+
SymbolicIndexingInterface = "0.3.1"
4361
Tables = "1"
62+
Test = "1"
63+
Tracker = "0.2"
64+
Unitful = "1"
4465
Zygote = "0.6.56"
45-
julia = "1.6"
66+
julia = "1.9"
4667

4768
[extras]
4869
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
70+
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
4971
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5072
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
5173
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
@@ -63,4 +85,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6385
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6486

6587
[targets]
66-
test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
88+
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
44

55
[compat]
66
Documenter = "1"
7-
RecursiveArrayTools = "2.32"
7+
RecursiveArrayTools = "3"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module RecursiveArrayToolsFastBroadcastExt
2+
3+
using RecursiveArrayTools
4+
using FastBroadcast
5+
using StaticArraysCore
6+
7+
const AbstractVectorOfSArray = AbstractVectorOfArray{T,N,<:AbstractVector{<:StaticArraysCore.SArray}} where {T,N}
8+
9+
@inline function FastBroadcast.fast_materialize!(::FastBroadcast.Static.False, ::DB, dst::AbstractVectorOfSArray, bc::Broadcast.Broadcasted{S}) where {S,DB}
10+
if FastBroadcast.use_fast_broadcast(S)
11+
for i in 1:length(dst.u)
12+
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
13+
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(unpacked[j] for j in eachindex(unpacked))
14+
end
15+
else
16+
Broadcast.materialize!(dst, bc)
17+
end
18+
return dst
19+
end
20+
21+
end # module

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ else
1111
end
1212

1313
# Define a new species of projection operator for this type:
14-
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
14+
# ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
1515

1616
function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
1717
xs::AbstractVectorOfArray)
@@ -95,17 +95,17 @@ end
9595
VectorOfArray(u),
9696
y -> begin
9797
y isa Ref && (y = VectorOfArray(y[].u))
98-
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
99-
for i in 1:size(y.u)[end]]),)
98+
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
99+
for i in 1:size(y)[end]]),)
100100
end
101101
end
102102

103103
@adjoint function DiffEqArray(u, t)
104104
DiffEqArray(u, t),
105105
y -> begin
106106
y isa Ref && (y = VectorOfArray(y[].u))
107-
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
108-
for i in 1:size(y.u)[end]],
107+
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
108+
for i in 1:size(y)[end]],
109109
t), nothing)
110110
end
111111
end
@@ -117,4 +117,53 @@ end
117117
A.x, literal_ArrayPartition_x_adjoint
118118
end
119119

120+
@adjoint function Array(VA::AbstractVectorOfArray)
121+
Array(VA),
122+
y -> (Array(y),)
120123
end
124+
125+
126+
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
127+
128+
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
129+
arr = reshape(x, p.sz)
130+
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
131+
end
132+
133+
function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
134+
N = ndims(x̄)
135+
if length(x) == length(x̄)
136+
Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
137+
else
138+
dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
139+
Zygote._project(x, Zygote.accum_sum(x̄; dims = dims))
140+
end
141+
end
142+
143+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b) where {F} = _broadcast_generic(__context__, f, a, b)
144+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b)
145+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b)
146+
147+
@inline function _broadcast_generic(__context__, f::F, args...) where {F}
148+
T = Broadcast.combine_eltypes(f, args)
149+
# Avoid generic broadcasting in two easy cases:
150+
if T == Bool
151+
return (f.(args...), _ -> nothing)
152+
elseif T <: Union{Real, Complex} && isconcretetype(T) && Zygote._dual_purefun(F) && all(Zygote._dual_safearg, args) && !Zygote.isderiving()
153+
return Zygote.broadcast_forward(f, args...)
154+
end
155+
len = Zygote.inclen(args)
156+
y∂b = Zygote._broadcast((x...) -> Zygote._pullback(__context__, f, x...), args...)
157+
y = broadcast(first, y∂b)
158+
function ∇broadcasted(ȳ)
159+
y∂b = y∂b isa AbstractVectorOfArray ? Iterators.flatten(y∂b.u) : y∂b
160+
ȳ = ȳ isa AbstractVectorOfArray ? Iterators.flatten.u) : ȳ
161+
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
162+
getters = ntuple(i -> Zygote.StaticGetter{i}(), len)
163+
dxs = map(g -> Zygote.collapse_nothings(map(g, dxs_zip)), getters)
164+
(nothing, Zygote.accum_sum(dxs[1]), map(Zygote.unbroadcast, args, Base.tail(dxs))...)
165+
end
166+
return y, ∇broadcasted
167+
end
168+
169+
end # module

src/RecursiveArrayTools.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using DocStringExtensions
88
using RecipesBase, StaticArraysCore, Statistics,
99
ArrayInterface, LinearAlgebra
1010
using SymbolicIndexingInterface
11+
using SparseArrays
1112

1213
import Adapt
1314

@@ -27,7 +28,8 @@ function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray})
2728
end
2829

2930
import GPUArraysCore
30-
Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA)
31+
Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u)
32+
(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray) = T(Array(VA))
3133

3234
import Requires
3335
@static if !isdefined(Base, :get_extension)

0 commit comments

Comments
 (0)