Skip to content

Commit ac1da7e

Browse files
committed
Improvements of the HORK_EXP
1 parent 4637ac2 commit ac1da7e

File tree

3 files changed

+230
-101
lines changed

3 files changed

+230
-101
lines changed

devito/types/multistage.py

Lines changed: 89 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from devito.symbolics import uxreplace
44
from numpy import number
55
from devito.types.array import Array
6+
from devito.types.dense import Function
7+
from devito.types.constant import Constant
68
from types import MappingProxyType
79

810
method_registry = {}
@@ -51,7 +53,7 @@ class MultiStage(Eq):
5153
of update expressions for each stage in the integration process.
5254
"""
5355

54-
def __new__(cls, lhs, rhs, **kwargs):
56+
def __new__(cls, lhs, rhs, source=None, degree=6, **kwargs):
5557
if not isinstance(lhs, list):
5658
lhs=[lhs]
5759
rhs=[rhs]
@@ -61,6 +63,8 @@ def __new__(cls, lhs, rhs, **kwargs):
6163
obj._eq = [Eq(lhs[i], rhs[i]) for i in range(len(lhs))]
6264
obj._lhs = lhs
6365
obj._rhs = rhs
66+
obj._deg = degree
67+
obj._src = source
6468

6569
return obj
6670

@@ -79,6 +83,16 @@ def rhs(self):
7983
"""Return list of right-hand sides."""
8084
return self._rhs
8185

86+
@property
87+
def deg(self):
88+
"""Return list of right-hand sides."""
89+
return self._deg
90+
91+
@property
92+
def src(self):
93+
"""Return list of right-hand sides."""
94+
return self._src
95+
8296
def _evaluate(self, **kwargs):
8397
raise NotImplementedError(
8498
f"_evaluate() must be implemented in the subclass {self.__class__.__name__}")
@@ -115,7 +129,9 @@ class RK(MultiStage):
115129
Number of stages in the RK method, inferred from `b`.
116130
"""
117131

118-
def __init__(self, a: list[list[float | number]], b: list[float | number], c: list[float | number], lhs, rhs, **kwargs) -> None:
132+
CoeffsBC = list[float | number]
133+
CoeffsA = list[CoeffsBC]
134+
def __init__(self, a: CoeffsA, b: CoeffsBC, c: CoeffsBC, lhs, rhs, **kwargs) -> None:
119135
self.a, self.b, self.c = a, b, c
120136

121137
@property
@@ -132,19 +148,18 @@ def _evaluate(self, **kwargs):
132148
133149
Returns
134150
-------
135-
list of Eq
151+
list of Devito Eq objects
136152
A list of SymPy Eq objects representing:
137153
- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
138154
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
139155
"""
140156
n_eq=len(self.eq)
141157
u = [i.function for i in self.lhs]
142-
grid = [u[i].grid for i in range(n_eq)]
143-
t = grid[0].time_dim
158+
t = u[0].grid.time_dim
144159
dt = t.spacing
145160

146161
# Create temporary Functions to hold each stage
147-
k = [[Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=grid[j].dimensions, grid=grid[j], dtype=u[j].dtype) for i in range(self.s)]
162+
k = [[Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=u[j].grid.dimensions, grid=u[j].grid, dtype=u[j].dtype) for i in range(self.s)]
148163
for j in range(n_eq)]
149164

150165
stage_eqs = []
@@ -214,8 +229,8 @@ class RK32(RK):
214229
b = [0, 0, 1]
215230
c = [0, 1/2, 1/2]
216231

217-
def __init__(self, *args, **kwargs):
218-
super().__init__(a=self.a, b=self.b, c=self.c, **kwargs)
232+
def __init__(self, lhs, rhs, **kwargs):
233+
super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs)
219234

220235

221236
@register_method
@@ -249,12 +264,12 @@ class RK97(RK):
249264
5963949/25894400, 50000000000/599799373173, 28487/712800]
250265
c = [0, 4/63, 2/21, 1/7, 7/17, 13/24, 7/9, 91/100, 1]
251266

252-
def __init__(self, *args, **kwargs):
253-
super().__init__(a=self.a, b=self.b, c=self.c, **kwargs)
267+
def __init__(self, lhs, rhs, **kwargs):
268+
super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs)
254269

255270

256271
@register_method
257-
class HORK(MultiStage):
272+
class HORK_EXP(MultiStage):
258273
# In construction
259274
"""
260275
n stages Runge-Kutta (HORK) time integration method.
@@ -271,8 +286,19 @@ class HORK(MultiStage):
271286
Time positions of intermediate stages.
272287
"""
273288

289+
def source_derivatives(self, src_index, t, **kwargs):
290+
291+
# Compute the base wavelet function
292+
f_deriv = [[self.src[i][1] for i in range(len(self.src))]]
293+
294+
# Compute derivatives up to order p
295+
for _ in range(self.deg - 1):
296+
f_deriv.append([f_deriv[-1][i].diff(t) for i in range(len(src_index))])
297+
298+
f_deriv.reverse()
299+
return f_deriv
274300

275-
def ssprk_alpha(mu=1, **kwargs):
301+
def ssprk_alpha(self, mu=1):
276302
"""
277303
Computes the coefficients for the Strong Stability Preserving Runge-Kutta (SSPRK) method.
278304
@@ -287,18 +313,33 @@ def ssprk_alpha(mu=1, **kwargs):
287313
numpy.ndarray
288314
Array of SSPRK coefficients.
289315
"""
290-
degree=kwargs.get('degree')
291316

292-
alpha = [0]*degree
317+
alpha = [0]*self.deg
293318
alpha[0] = 1.0 # Initial coefficient
294319

295-
for i in range(1, degree):
296-
alpha[i] = 1 / (mu * (i + 1)) * alpha[i - 1]
297-
alpha[1:i] = 1 / (mu * list(range(1, i))) * alpha[:i - 1]
320+
for i in range(1, self.deg):
321+
alpha[i] = 1/(mu*(i+1))*alpha[i-1]
322+
alpha[1:i] = [1/(mu*j)*alpha[j-1] for j in range(1,i)]
298323
alpha[0] = 1 - sum(alpha[1:i + 1])
299324

300325
return alpha
301326

327+
328+
def source_inclusion(self, u, k, src_index, src_deriv, e_p, t, dt, mu, n_eq):
329+
330+
src_lhs = [uxreplace(self.rhs[i], {u[m]: k[m] for m in range(n_eq)}) for i in range(n_eq)]
331+
332+
p = len(src_deriv)
333+
334+
for i in range(p):
335+
if e_p[i] != 0:
336+
for j in range(len(src_index)):
337+
src_lhs[src_index[j]] += self.src[j][0]*src_deriv[i][j].subs({t: t * dt})*e_p[i]
338+
e_p = [e_p[i]+mu*dt*e_p[i + 1] for i in range(p - 1)]+[e_p[-1]]
339+
340+
return src_lhs, e_p
341+
342+
302343
def _evaluate(self, **kwargs):
303344
"""
304345
Generate the stage-wise equations for a Runge-Kutta time integration method.
@@ -315,66 +356,52 @@ def _evaluate(self, **kwargs):
315356
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
316357
"""
317358

318-
u = self.lhs.function
319-
rhs = self.rhs
320-
grid = u.grid
321-
t = grid.time_dim
359+
n_eq=len(self.eq)
360+
u = [i.function for i in self.lhs]
361+
t = u[0].grid.time_dim
322362
dt = t.spacing
323363

324-
an_eq = range(len(U0))
364+
# Create a temporary Array for each variable to save the time stages
365+
# k = [Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=u[i].grid.dimensions, grid=u[i].grid, dtype=u[i].dtype) for i in range(n_eq)]
366+
k = [Function(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', grid=u[i].grid, space_order=2, time_order=1, dtype=u[i].dtype) for i in range(n_eq)]
367+
k_old = [Function(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', grid=u[i].grid, space_order=2, time_order=1, dtype=u[i].dtype) for i in range(n_eq)]
325368

326369
# Compute SSPRK coefficients
327-
alpha = np.array(ssprk_alpha(mu, degree), dtype=np.float64)
370+
mu = 1
371+
alpha = self.ssprk_alpha(mu=mu)
328372

329373
# Initialize symbolic differentiation for source terms
330-
t_var = sym.Symbol('t_var')
331-
src_deriv = aux_fun.derivates_f(degree, f0)
374+
src_index_map={val: i for i, val in enumerate(u)}
375+
src_index = [src_index_map[val] for val in [self.src[i][2] for i in range(len(self.src))]]
376+
src_deriv = self.source_derivatives(src_index, t, **kwargs)
332377

333378
# Expansion coefficients for stability control
334-
e_p = [0] * degree
379+
e_p = [0] * self.deg
380+
eta = 1
335381
e_p[-1] = 1 / eta
336382

337-
# Initialize approximation and auxiliary variable
338-
approx = [U0[i] * alpha[0] for i in n_eq]
339-
aux = U0
340-
341-
# Perform Runge-Kutta steps
342-
for i in range(1, degree - 1):
343-
system_op, e_p = sys_op_extended(aux, x, y, z, param_fun, system, fd_order, src_spat, src_deriv, t, dt, t_var, e_p)
344-
aux = [aux[j] + mu * dt * system_op[j] for j in n_eq]
345-
approx = [approx[j] + aux[j] * alpha[i] for j in n_eq]
346-
347-
# Final Runge-Kutta updates
348-
system_op, e_p = sys_op_extended(aux, x, y, z, param_fun, system, fd_order, src_spat, src_deriv, t, dt, t_var, e_p)
349-
aux = [aux[i] + mu * dt * system_op[i] for i in n_eq]
350-
system_op, e_p = sys_op_extended(aux, x, y, z, param_fun, system, fd_order, src_spat, src_deriv, t, dt, t_var, e_p)
351-
aux = [aux[i] + mu * dt * system_op[i] for i in n_eq]
352-
353-
# Compute final approximation
354-
approx = [approx[i] + aux[i] * alpha[degree - 1] for i in n_eq]
355-
356-
# Generate final PDE system
357-
return [dv.Eq(U0[i].forward, approx[i]) for i in n_eq]
358383

359-
# Create temporary Functions to hold each stage
360-
# k = [Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
361-
k = [Function(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', grid=grid, space_order=u.space_order, dtype=u.dtype)
362-
for i in range(self.s)]
363-
364-
stage_eqs = []
384+
stage_eqs = [Eq(k[i], u[i]) for i in range(n_eq)]
385+
[stage_eqs.append(Eq(u[i].forward, u[i]*alpha[0])) for i in range(n_eq)]
365386

366387
# Build each stage
367-
for i in range(self.s):
368-
u_temp = u + dt * sum(aij * kj for aij, kj in zip(self.a[i][:i], k[:i]))
369-
t_shift = t + self.c[i] * dt
388+
for i in range(1, self.deg-1):
389+
[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]
390+
src_lhs, e_p = self.source_inclusion(u, k_old, src_index, src_deriv, e_p, t, dt, mu, n_eq)
391+
[stage_eqs.append(Eq(k[j], k_old[j]+mu*dt*src_lhs[j])) for j in range(n_eq)]
392+
[stage_eqs.append(Eq(u[j].forward, u[j].forward+k[j]*alpha[i])) for j in range(n_eq)]
370393

371-
# Evaluate RHS at intermediate value
372-
stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift})
373-
stage_eqs.append(Eq(k[i], stage_rhs))
394+
# Final Runge-Kutta updates
395+
[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]
396+
src_lhs, e_p = self.source_inclusion(u, k_old, src_index, src_deriv, e_p, t, dt, mu, n_eq)
397+
[stage_eqs.append(Eq(k[j], k_old[j]+mu*dt*src_lhs[j])) for j in range(n_eq)]
374398

375-
# Final update: u.forward = u + dt * sum(b_i * k_i)
376-
u_next = u + dt * sum(bi * ki for bi, ki in zip(self.b, k))
377-
stage_eqs.append(Eq(u.forward, u_next))
399+
[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]
400+
src_lhs, _ = self.source_inclusion(u, k_old, src_index, src_deriv, e_p, t, dt, mu, n_eq)
401+
[stage_eqs.append(Eq(k[j], k_old[j]+mu*dt*src_lhs[j])) for j in range(n_eq)]
402+
403+
# Compute final approximation
404+
[stage_eqs.append(Eq(u[j].forward, u[j].forward+k[j]*alpha[self.deg-1])) for j in range(n_eq)]
378405

379406
return stage_eqs
380407

0 commit comments

Comments
 (0)