@@ -197,50 +197,71 @@ class SubfunctionBlock(Block):
197197 def __init__ (self , func , idx , ad_block_tag = None ):
198198 super ().__init__ (ad_block_tag = ad_block_tag )
199199 self .add_dependency (func )
200- self .idx = idx
200+ self .sub_idx = idx
201201
202202 def evaluate_adj_component (self , inputs , adj_inputs , block_variable , idx ,
203203 prepared = None ):
204204 eval_adj = firedrake .Cofunction (block_variable .output .function_space ().dual ())
205205 if type (adj_inputs [0 ]) is firedrake .Cofunction :
206- eval_adj .sub (self .idx ).assign (adj_inputs [0 ])
206+ eval_adj .sub (self .sub_idx ).assign (adj_inputs [0 ])
207207 else :
208- eval_adj .sub (self .idx ).assign (adj_inputs [0 ].function )
208+ eval_adj .sub (self .sub_idx ).assign (adj_inputs [0 ].function )
209209 return eval_adj
210210
211211 def evaluate_tlm_component (self , inputs , tlm_inputs , block_variable , idx ,
212212 prepared = None ):
213- return firedrake .Function .sub (tlm_inputs [0 ], self .idx )
213+ return firedrake .Function .sub (tlm_inputs [0 ], self .sub_idx )
214214
215215 def evaluate_hessian_component (self , inputs , hessian_inputs , adj_inputs ,
216216 block_variable , idx ,
217217 relevant_dependencies , prepared = None ):
218218 eval_hessian = firedrake .Cofunction (block_variable .output .function_space ().dual ())
219- eval_hessian .sub (self .idx ).assign (hessian_inputs [0 ])
219+ eval_hessian .sub (self .sub_idx ).assign (hessian_inputs [0 ])
220220 return eval_hessian
221221
222222 def recompute_component (self , inputs , block_variable , idx , prepared ):
223223 return maybe_disk_checkpoint (
224- firedrake .Function .sub (inputs [0 ], self .idx )
224+ firedrake .Function .sub (inputs [0 ], self .sub_idx )
225225 )
226226
227227 def __str__ (self ):
228- return f"{ self .get_dependencies ()[0 ]} [{ self .idx } ]"
228+ return f"{ self .get_dependencies ()[0 ]} [{ self .sub_idx } ]"
229229
230230
231231class FunctionMergeBlock (Block ):
232232 def __init__ (self , func , idx , ad_block_tag = None ):
233233 super ().__init__ (ad_block_tag = ad_block_tag )
234234 self .add_dependency (func )
235- self .idx = idx
235+ self .sub_idx = idx
236236 for output in func ._ad_outputs :
237237 self .add_dependency (output )
238238
239239 def evaluate_adj_component (self , inputs , adj_inputs , block_variable , idx ,
240240 prepared = None ):
241+ # The merge block appears whenever a subfunction is the output of a block.
242+ # This means that the subfunction has been modified, so we need to make
243+ # sure that this modification is accounted for when evaluating the adjoint.
244+ #
245+ # When recomputing the merge block, the indexed subfunction in the full
246+ # Function is completely overwritten, meaning that the pre-existing value
247+ # of the subfunction in the full function is ignored.
248+ # The equivalent adjoint operation is to:
249+ # 1. send the subfunction component of the adjoint value back up
250+ # the branch of the tape corresponding to the subfunction
251+ # dependency (idx=0).
252+ # 2. zero out the subfunction component of the adjoint value sent
253+ # back up the full Function branch of the tape (idx=1).
254+ # This means that when the adjoint values of each branch are combined
255+ # after the SubfunctionBlock only the adjoint value from the subfunction
256+ # branch is used.
257+ #
258+ # See https://github.com/firedrakeproject/firedrake/pull/4177 for more
259+ # detail and for diagrams of the tape produced when accessing subfunctions.
260+
241261 if idx == 0 :
242- return adj_inputs [0 ].subfunctions [self .idx ]
262+ return adj_inputs [0 ].subfunctions [self .sub_idx ]. copy ( deepcopy = True )
243263 else :
264+ adj_inputs [0 ].subfunctions [self .sub_idx ].zero ()
244265 return adj_inputs [0 ]
245266
246267 def evaluate_tlm (self , markings = False ):
@@ -253,7 +274,7 @@ def evaluate_tlm(self, markings=False):
253274 fs = output .output .function_space ()
254275 f = type (output .output )(fs )
255276 output .add_tlm_output (
256- type (output .output ).assign (f .sub (self .idx ), tlm_input )
277+ type (output .output ).assign (f .sub (self .sub_idx ), tlm_input )
257278 )
258279
259280 def evaluate_hessian_component (self , inputs , hessian_inputs , adj_inputs ,
@@ -265,12 +286,12 @@ def recompute_component(self, inputs, block_variable, idx, prepared):
265286 sub_func = inputs [0 ]
266287 parent_in = inputs [1 ]
267288 parent_out = type (parent_in )(parent_in )
268- parent_out .sub (self .idx ).assign (sub_func )
289+ parent_out .sub (self .sub_idx ).assign (sub_func )
269290 return maybe_disk_checkpoint (parent_out )
270291
271292 def __str__ (self ):
272293 deps = self .get_dependencies ()
273- return f"{ deps [1 ]} [{ self .idx } ].assign({ deps [0 ]} )"
294+ return f"{ deps [1 ]} [{ self .sub_idx } ].assign({ deps [0 ]} )"
274295
275296
276297class CofunctionAssignBlock (Block ):
0 commit comments