@@ -87,3 +87,75 @@ for T in (:Any,)
87
87
@eval Base.:* (a:: AbstractThunk , b:: $T ) = unthunk (a) * b
88
88
@eval Base.:* (a:: $T , b:: AbstractThunk ) = a * unthunk (b)
89
89
end
90
+
91
+ # ################# Composite ##############################################################
92
+
93
+ # We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful
94
+ # In general one doesn't have to represent multiplications of 2 differentials
95
+ # Only of a differential and a scaling factor (generally `Real`)
96
+ Base.* (s:: Any , comp:: Composite ) = map (x-> s* x, comp)
97
+ Base.* (comp:: Composite , s:: Any ) = s* comp
98
+
99
+ function Base.:+ (a:: Composite{Primal, NamedTuple{an}} , b:: Composite{Primal, NamedTuple{bn}} ) where Primal
100
+ # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base.
101
+ # https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231
102
+ if @generated
103
+ names = Base. merge_names (an, bn)
104
+ types = Base. merge_types (names, a, b)
105
+
106
+ vals = map (names) do field
107
+ a_field = :(getproperty (:a , $ (QuoteNode (field))))
108
+ b_field = :(getproperty (:b , $ (QuoteNode (field))))
109
+ val_expr = if Base. sym_in (field, an)
110
+ if Base. sym_in (field, bn)
111
+ # in both
112
+ :($ a_field + $ b_field)
113
+ else
114
+ # only in `an`
115
+ a_field
116
+ end
117
+ else # must be in `b` only
118
+ b_field
119
+ end
120
+ end
121
+ return :(NamedTuple {$names, $types} (($ (vals... ),)))
122
+ else
123
+ names = Base. merge_names (an, bn)
124
+ types = Base. merge_types (names, typeof (a), typeof (b))
125
+ vals = map (names) do field
126
+ val_expr = if Base. sym_in (field, an)
127
+ a_field = getproperty (a, field)
128
+ if Base. sym_in (field, bn)
129
+ # in both
130
+ b_field = getproperty (a, field)
131
+ :($ a_field + $ b_field)
132
+ else
133
+ # only in `an`
134
+ a_field
135
+ end
136
+ else # must be in `b` only
137
+ b_field = getproperty (a, field)
138
+ b_field
139
+ end
140
+ end
141
+ NamedTuple {names,types} (map (n-> getfield (sym_in (n, bn) ? b : a, n), names))
142
+ end
143
+ end
144
+ end
145
+
146
+ # this should not need to be generated, # TODO test that
147
+ function Base.:+ (a:: Composite{Primal, <:Tuple} , b:: Composite{Primal, <:Tuple} ) where Primal
148
+ # TODO : should we even allow it on different lengths?
149
+ short, long = length (a) < length (b) ? (a. backing, b. backing) : (b. backing, a. backing)
150
+ backing = ntuple (length (long)) do ii
151
+ long_val = getfield (long, ii)
152
+ if ii <= length (short)
153
+ short_val = getfield (short, ii)
154
+ return short_val + long_val
155
+ else
156
+ return long_val
157
+ end
158
+ end
159
+
160
+ return Composite {Primal, typeof(backing)} (backing)
161
+ end
0 commit comments