Skip to content

Commit 166a65c

Browse files
committed
define composite
1 parent 6c81169 commit 166a65c

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

src/differential_arithmetic.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,75 @@ for T in (:Any,)
8787
@eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
8888
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
8989
end
90+
91+
################## Composite ##############################################################
92+
93+
# We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful
94+
# In general one doesn't have to represent multiplications of 2 differentials
95+
# Only of a differential and a scaling factor (generally `Real`)
96+
Base.*(s::Any, comp::Composite) = map(x->s*x, comp)
97+
Base.*(comp::Composite, s::Any) = s*comp
98+
99+
function Base.:+(a::Composite{Primal, NamedTuple{an}}, b::Composite{Primal, NamedTuple{bn}}) where Primal
100+
# Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base.
101+
# https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231
102+
if @generated
103+
names = Base.merge_names(an, bn)
104+
types = Base.merge_types(names, a, b)
105+
106+
vals = map(names) do field
107+
a_field = :(getproperty(:a, $(QuoteNode(field))))
108+
b_field = :(getproperty(:b, $(QuoteNode(field))))
109+
val_expr = if Base.sym_in(field, an)
110+
if Base.sym_in(field, bn)
111+
# in both
112+
:($a_field + $b_field)
113+
else
114+
# only in `an`
115+
a_field
116+
end
117+
else # must be in `b` only
118+
b_field
119+
end
120+
end
121+
return :(NamedTuple{$names, $types}(($(vals...),)))
122+
else
123+
names = Base.merge_names(an, bn)
124+
types = Base.merge_types(names, typeof(a), typeof(b))
125+
vals = map(names) do field
126+
val_expr = if Base.sym_in(field, an)
127+
a_field = getproperty(a, field)
128+
if Base.sym_in(field, bn)
129+
# in both
130+
b_field = getproperty(a, field)
131+
:($a_field + $b_field)
132+
else
133+
# only in `an`
134+
a_field
135+
end
136+
else # must be in `b` only
137+
b_field = getproperty(a, field)
138+
b_field
139+
end
140+
end
141+
NamedTuple{names,types}(map(n->getfield(sym_in(n, bn) ? b : a, n), names))
142+
end
143+
end
144+
end
145+
146+
# this should not need to be generated, # TODO test that
147+
function Base.:+(a::Composite{Primal, <:Tuple}, b::Composite{Primal, <:Tuple}) where Primal
148+
# TODO: should we even allow it on different lengths?
149+
short, long = length(a) < length(b) ? (a.backing, b.backing) : (b.backing, a.backing)
150+
backing = ntuple(length(long)) do ii
151+
long_val = getfield(long, ii)
152+
if ii <= length(short)
153+
short_val = getfield(short, ii)
154+
return short_val + long_val
155+
else
156+
return long_val
157+
end
158+
end
159+
160+
return Composite{Primal, typeof(backing)}(backing)
161+
end

src/differentials.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,66 @@ function itself, when that function is not a closure.
284284
"""
285285
const NO_FIELDS = DoesNotExist()
286286

287+
288+
"""
289+
Composite{Primal, T} <: AbstractDifferential
290+
291+
This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`.
292+
`Primal` is the the corresponding primal type that this is a differential for.
293+
294+
`Composite{Primal}` should have fields (technically properties), that match to a subset of the
295+
fields of the primal type; and each should be a differential type matching to the primal
296+
type of that field.
297+
Fields of the Primal that are not present in the Composite are treated as `Zero`.
298+
299+
`T` is an implementation detail representing the backing datastructure.
300+
For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
301+
It should not be passed in by user.
302+
"""
303+
struct Composite{Primal, T} <: AbstractDifferential
304+
backing::T
305+
end
306+
307+
308+
function Composite{Primal}(;kwargs...) where Primal
309+
backing = (; kwargs...)
310+
return Composite{Primal, typeof(backing)}(backing)
311+
end
312+
313+
function Composite{Primal}(args...) where Primal
314+
return Composite{Primal, typeof(args)}(args)
315+
end
316+
317+
function Base.show(io::IO, comp::Composite{Primal})
318+
print(io, "Composite{")
319+
show(io, Primal)
320+
print(io, "}")
321+
# allow Tuple or NamedTuple `show` to do the rendering of brackets etc
322+
show(io, comp.backing)
323+
end
324+
325+
#TODO think about this, for if we are missing fields
326+
#Base.convert(::Type{Primal}, comp::Composite{Primal})
327+
Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple}) = comp.backing
328+
Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple}) = comp.backing
329+
330+
Base.getindex(comp::Composite, idx) = getindex(comp.backing)
331+
Base.getproperty(comp::Composite, idx) = getproperty(comp.backing, idx)
332+
Base.propertynames(comp::Composite) = propertynames(comp.backing)
333+
Base.iterate(comp::Compositem, args...) = iterate(comp.backing, args...)
334+
Base.length(comp::Composite) = length(comp.backing)
335+
336+
map(f, comp::Composite{Primal, <:Tuple}) where Primal = Composite{Primal}(map(f, comp.backing))
337+
function map(f, comp::Composite{Primal, <:NamedTuple{L}}) where{Primal, L}
338+
vals = map(f, Tuple(comp.backing))
339+
named_vals = NamedTuple{L, typeof(vals)}(vals)
340+
return Composite{Primal}(named_vals)
341+
end
342+
343+
Base.conj(comp::Composite{Primal}) = map(conj, comp)
344+
345+
#==============================================================================#
346+
287347
"""
288348
refine_differential(𝒟::Type, der)
289349

0 commit comments

Comments
 (0)