@@ -30,6 +30,13 @@ To be able to work with multi-threading, it should also implement:
30
30
- `split(acc::T)`
31
31
- `combine(acc::T, acc2::T)`
32
32
33
+ If two accumulators of the same type should be merged in some non-trivial way, other than
34
+ always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
35
+
36
+ If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should
37
+ do something other than copy the original accumulator, then
38
+ `subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.`
39
+
33
40
See the documentation for each of these functions for more details.
34
41
"""
35
42
abstract type AbstractAccumulator end
@@ -113,6 +120,24 @@ used by various AD backends, should implement a method for this function.
113
120
"""
114
121
convert_eltype (:: Type , acc:: AbstractAccumulator ) = acc
115
122
123
+ """
124
+ subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName})
125
+
126
+ Return a new accumulator that only contains the information for the `VarName`s in `vns`.
127
+
128
+ By default returns a copy of `acc`. Subtypes should override this behaviour as needed.
129
+ """
130
+ subset (acc:: AbstractAccumulator , :: AbstractVector{<:VarName} ) = copy (acc)
131
+
132
+ """
133
+ merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator)
134
+
135
+ Merge two accumulators of the same type. Returns a new accumulator of the same type.
136
+
137
+ By default returns a copy of `acc2`. Subtypes should override this behaviour as needed.
138
+ """
139
+ Base. merge (acc1:: AbstractAccumulator , acc2:: AbstractAccumulator ) = copy (acc2)
140
+
116
141
"""
117
142
AccumulatorTuple{N,T<:NamedTuple}
118
143
@@ -158,6 +183,50 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N})
158
183
return AccumulatorTuple (convert (T, accs. nt))
159
184
end
160
185
186
+ """
187
+ subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
188
+
189
+ Replace each accumulator `acc` in `at` with `subset(acc, vns)`.
190
+ """
191
+ function subset (at:: AccumulatorTuple , vns:: AbstractVector{<:VarName} )
192
+ return AccumulatorTuple (map (Base. Fix2 (subset, vns), at. nt))
193
+ end
194
+
195
+ """
196
+ _joint_keys(nt1::NamedTuple, nt2::NamedTuple)
197
+
198
+ A helper function that returns three tuples of keys given two `NamedTuple`s:
199
+ The keys only in `nt1`, only in `nt2`, and in both, and in that order.
200
+
201
+ Implemented as a generated function to enable constant propagation of the result in `merge`.
202
+ """
203
+ @generated function _joint_keys (
204
+ nt1:: NamedTuple{names1} , nt2:: NamedTuple{names2}
205
+ ) where {names1,names2}
206
+ only_in_nt1 = tuple (setdiff (names1, names2)... )
207
+ only_in_nt2 = tuple (setdiff (names2, names1)... )
208
+ in_both = tuple (intersect (names1, names2)... )
209
+ return :($ only_in_nt1, $ only_in_nt2, $ in_both)
210
+ end
211
+
212
+ """
213
+ merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
214
+
215
+ Merge two `AccumulatorTuple`s.
216
+
217
+ For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
218
+ accumulators themselves. Other accumulators are copied.
219
+ """
220
+ function Base. merge (at1:: AccumulatorTuple , at2:: AccumulatorTuple )
221
+ keys_in_at1, keys_in_at2, keys_in_both = _joint_keys (at1. nt, at2. nt)
222
+ accs_in_at1 = (getfield (at1. nt, key) for key in keys_in_at1)
223
+ accs_in_at2 = (getfield (at2. nt, key) for key in keys_in_at2)
224
+ accs_in_both = (
225
+ merge (getfield (at1. nt, key), getfield (at2. nt, key)) for key in keys_in_both
226
+ )
227
+ return AccumulatorTuple (accs_in_at1... , accs_in_both... , accs_in_at2... )
228
+ end
229
+
161
230
"""
162
231
setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)
163
232
0 commit comments