1- from pyadjoint import ReducedFunctional
1+ from pyadjoint . reduced_functional import AbstractReducedFunctional , ReducedFunctional
22from pyadjoint .enlisting import Enlist
33from pyop2 .mpi import MPI
44
5- import firedrake
5+ from firedrake .function import Function
6+ from firedrake .cofunction import Cofunction
67
78
8- class EnsembleReducedFunctional (ReducedFunctional ):
9+ class EnsembleReducedFunctional (AbstractReducedFunctional ):
910 """Enable simultaneously solving reduced functionals in parallel.
1011
1112 Consider a functional :math:`J` and its gradient :math:`\\dfrac{dJ}{dm}`,
@@ -34,7 +35,7 @@ class EnsembleReducedFunctional(ReducedFunctional):
3435
3536 Parameters
3637 ----------
37- J : pyadjoint.OverloadedType
38+ functional : pyadjoint.OverloadedType
3839 An instance of an OverloadedType, usually :class:`pyadjoint.AdjFloat`.
3940 This should be the functional that we want to reduce.
4041 control : pyadjoint.Control or list of pyadjoint.Control
@@ -86,28 +87,40 @@ class EnsembleReducedFunctional(ReducedFunctional):
8687 works, please refer to the `Firedrake manual
8788 <https://www.firedrakeproject.org/parallelism.html#ensemble-parallelism>`_.
8889 """
89- def __init__ (self , J , control , ensemble , scatter_control = True ,
90- gather_functional = None , derivative_components = None ,
91- scale = 1.0 , tape = None , eval_cb_pre = lambda * args : None ,
90+ def __init__ (self , functional , control , ensemble , scatter_control = True ,
91+ gather_functional = None ,
92+ derivative_components = None ,
93+ scale = 1.0 , tape = None ,
94+ eval_cb_pre = lambda * args : None ,
9295 eval_cb_post = lambda * args : None ,
9396 derivative_cb_pre = lambda controls : controls ,
9497 derivative_cb_post = lambda checkpoint , derivative_components , controls : derivative_components ,
95- hessian_cb_pre = lambda * args : None , hessian_cb_post = lambda * args : None ):
96- super (EnsembleReducedFunctional , self ).__init__ (
97- J , control , derivative_components = derivative_components ,
98- scale = scale , tape = tape , eval_cb_pre = eval_cb_pre ,
99- eval_cb_post = eval_cb_post , derivative_cb_pre = derivative_cb_pre ,
98+ hessian_cb_pre = lambda * args : None ,
99+ hessian_cb_post = lambda * args : None ):
100+ self .local_reduced_functional = ReducedFunctional (
101+ functional , control ,
102+ derivative_components = derivative_components ,
103+ scale = scale , tape = tape ,
104+ eval_cb_pre = eval_cb_pre ,
105+ eval_cb_post = eval_cb_post ,
106+ derivative_cb_pre = derivative_cb_pre ,
100107 derivative_cb_post = derivative_cb_post ,
101- hessian_cb_pre = hessian_cb_pre , hessian_cb_post = hessian_cb_post )
108+ hessian_cb_pre = hessian_cb_pre ,
109+ hessian_cb_post = hessian_cb_post
110+ )
102111
103112 self .ensemble = ensemble
104113 self .scatter_control = scatter_control
105114 self .gather_functional = gather_functional
106115
116+ @property
117+ def controls (self ):
118+ return self .local_reduced_functional .controls
119+
107120 def _allgather_J (self , J ):
108121 if isinstance (J , float ):
109122 vals = self .ensemble .ensemble_comm .allgather (J )
110- elif isinstance (J , firedrake . Function ):
123+ elif isinstance (J , Function ):
111124 # allgather not implemented in ensemble.py
112125 vals = []
113126 for i in range (self .ensemble .ensemble_comm .size ):
@@ -134,30 +147,31 @@ def __call__(self, values):
134147 The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`.
135148
136149 """
137- local_functional = super ( EnsembleReducedFunctional , self ). __call__ (values )
150+ local_functional = self . local_reduced_functional (values )
138151 ensemble_comm = self .ensemble .ensemble_comm
139152 if self .gather_functional :
140153 controls_g = self ._allgather_J (local_functional )
141154 total_functional = self .gather_functional (controls_g )
142155 # if gather_functional is None then we do a sum
143156 elif isinstance (local_functional , float ):
144157 total_functional = ensemble_comm .allreduce (sendobj = local_functional , op = MPI .SUM )
145- elif isinstance (local_functional , firedrake . Function ):
158+ elif isinstance (local_functional , Function ):
146159 total_functional = type (local_functional )(local_functional .function_space ())
147160 total_functional = self .ensemble .allreduce (local_functional , total_functional )
148161 else :
149162 raise NotImplementedError ("This type of functional is not supported." )
150163 return total_functional
151164
152- def derivative (self , adj_input = 1.0 , options = None ):
165+ def derivative (self , adj_input = 1.0 , apply_riesz = False ):
153166 """Compute derivatives of a functional with respect to the control parameters.
154167
155168 Parameters
156169 ----------
157170 adj_input : float
158171 The adjoint input.
159- options : dict
160- Additional options for the derivative computation.
172+ apply_riesz: bool
173+ If True, apply the Riesz map of each control in order to return
174+ a primal gradient rather than a derivative in the dual space.
161175
162176 Returns
163177 -------
@@ -171,29 +185,62 @@ def derivative(self, adj_input=1.0, options=None):
171185
172186 if self .gather_functional :
173187 dJg_dmg = self .gather_functional .derivative (adj_input = adj_input ,
174- options = options )
188+ apply_riesz = False )
175189 i = self .ensemble .ensemble_comm .rank
176190 adj_input = dJg_dmg [i ]
177191
178- dJdm_local = super (EnsembleReducedFunctional , self ).derivative (adj_input = adj_input , options = options )
192+ dJdm_local = self .local_reduced_functional .derivative (adj_input = adj_input ,
193+ apply_riesz = apply_riesz )
179194
180195 if self .scatter_control :
181196 dJdm_local = Enlist (dJdm_local )
182197 dJdm_total = []
183198
184199 for dJdm in dJdm_local :
185- if not isinstance (dJdm , (firedrake .Function , float )):
186- raise NotImplementedError ("This type of gradient is not supported." )
200+ if not isinstance (dJdm , (Cofunction , Function , float )):
201+ raise NotImplementedError (
202+ f"Gradients of type { type (dJdm ).__name__ } are not supported." )
187203
188204 dJdm_total .append (
189205 self .ensemble .allreduce (dJdm , type (dJdm )(dJdm .function_space ()))
190- if isinstance (dJdm , firedrake . Function )
206+ if isinstance (dJdm , ( Cofunction , Function ) )
191207 else self .ensemble .ensemble_comm .allreduce (sendobj = dJdm , op = MPI .SUM )
192208 )
193209 return dJdm_local .delist (dJdm_total )
194210 return dJdm_local
195211
196- def hessian (self , m_dot , options = None ):
212+ def tlm (self , m_dot ):
213+ """Return the action of the tangent linear model of the functional.
214+
215+ The tangent linear model is evaluated w.r.t. the control on a vector
216+ m_dot, around the last supplied value of the control.
217+
218+ Parameters
219+ ----------
220+ m_dot : pyadjoint.OverloadedType
221+ The direction in which to compute the action of the tangent linear model.
222+
223+ Returns
224+ -------
225+ pyadjoint.OverloadedType: The action of the tangent linear model in the
226+ direction m_dot. Should be an instance of the same type as the functional.
227+ """
228+ local_tlm = self .local_reduced_functional .tlm (m_dot )
229+ ensemble_comm = self .ensemble .ensemble_comm
230+ if self .gather_functional :
231+ mdot_g = self ._allgather_J (local_tlm )
232+ total_tlm = self .gather_functional .tlm (mdot_g )
233+ # if gather_functional is None then we do a sum
234+ elif isinstance (local_tlm , float ):
235+ total_tlm = ensemble_comm .allreduce (sendobj = local_tlm , op = MPI .SUM )
236+ elif isinstance (local_tlm , Function ):
237+ total_tlm = type (local_tlm )(local_tlm .function_space ())
238+ total_tlm = self .ensemble .allreduce (local_tlm , total_tlm )
239+ else :
240+ raise NotImplementedError ("This type of functional is not supported." )
241+ return total_tlm
242+
243+ def hessian (self , m_dot , hessian_input = None , evaluate_tlm = True , apply_riesz = False ):
197244 """The Hessian is not yet implemented for ensemble reduced functional.
198245
199246 Raises:
(page residue — "0 commit comments" is GitHub UI text, not part of the diff)