Skip to content

Commit 9433d34

Browse files
authored
Merge pull request #312 from DedalusProject/eval_copy
Refined task processing
2 parents 7925795 + 1475515 commit 9433d34

File tree

6 files changed

+38
-23
lines changed

6 files changed

+38
-23
lines changed

dedalus/core/distributor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections import OrderedDict
1010
from math import prod
1111
import numbers
12+
from weakref import WeakSet
1213

1314
from .coords import CoordinateSystem, DirectProduct
1415
from ..tools.array import reshape_vector
@@ -112,6 +113,8 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None):
112113
self.comm_coords = np.array(self.comm_cart.coords, dtype=int)
113114
# Build layout objects
114115
self._build_layouts()
116+
# Keep set of weak field references
117+
self.fields = WeakSet()
115118

116119
@CachedAttribute
117120
def cs_by_axis(self):

dedalus/core/evaluator.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from .future import FutureField, FutureLockedField
1515
from .field import Field, LockedField
16+
from .operators import Copy
1617
from ..tools.cache import CachedAttribute
1718
from ..tools.general import OrderedSet
1819
from ..tools.general import oscillate
@@ -127,22 +128,19 @@ def evaluate_handlers(self, handlers, id=None, **kw):
127128
# Attempt evaluation
128129
tasks = self.attempt_tasks(tasks, id=id)
129130

130-
# # Transform all outputs to coefficient layout to dealias
131-
## D3 note: need to worry about this for redundent tasks?
132-
# outputs = OrderedSet([t['out'] for h in handlers for t in h.tasks])
133-
# self.require_coeff_space(outputs)
134-
135-
# # Copy redundant outputs so processing is independent
136-
# outputs = set()
137-
# for handler in handlers:
138-
# for task in handler.tasks:
139-
# if task['out'] in outputs:
140-
# task['out'] = task['out'].copy()
141-
# else:
142-
# outputs.add(task['out'])
131+
# Transform all outputs to coefficient layout to dealias
143132
outputs = OrderedSet([t['out'] for h in handlers for t in h.tasks if not isinstance(t['out'], LockedField)])
144133
self.require_coeff_space(outputs)
145134

135+
# Copy redundant outputs so processing is independent
136+
outputs = set()
137+
for handler in handlers:
138+
for task in handler.tasks:
139+
if task['out'] in outputs:
140+
task['out'] = task['out'].copy()
141+
else:
142+
outputs.add(task['out'])
143+
146144
# Process
147145
for handler in handlers:
148146
handler.process(**kw)
@@ -285,10 +283,9 @@ def add_task(self, task, layout='g', name=None, scales=None):
285283
# Create operator
286284
if isinstance(task, str):
287285
op = FutureField.parse(task, self.vars, self.dist)
286+
elif isinstance(task, Field):
287+
op = Copy(task)
288288
else:
289-
# op = FutureField.cast(task, self.domain)
290-
# op = Cast(task)
291-
# TODO: figure out if we need to copying here
292289
op = task
293290
# Check scales
294291
if isinstance(op, (LockedField, FutureLockedField)):

dedalus/core/field.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,8 @@ def __init__(self, dist, bases=None, name=None, tensorsig=None, dtype=None):
572572
self.layout = self.dist.get_layout_object('c')
573573
# Change scales to build buffer and data
574574
self.preset_scales((1,) * self.dist.dim)
575+
# Add weak reference to distributor
576+
dist.fields.add(self)
575577

576578
def __getitem__(self, layout):
577579
"""Return data viewed in specified layout."""
@@ -1022,3 +1024,10 @@ def lock_to_layouts(self, *layouts):
10221024
def lock_axis_to_grid(self, axis):
10231025
self.allowed_layouts = tuple(l for l in self.dist.layouts if l.grid_space[axis])
10241026

1027+
def unlock(self):
1028+
"""Return regular Field object with same data and no layout locking."""
1029+
field = Field(self.dist, bases=self.domain.bases, name=self.name, tensorsig=self.tensorsig, dtype=self.dtype)
1030+
field.preset_scales(self.scales)
1031+
field[self.layout] = self.data
1032+
return field
1033+

dedalus/core/operators.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from .domain import Domain
2121
from . import coords
22-
from .field import Operand, Field
22+
from .field import Operand, Field, LockedField
2323
from .future import Future, FutureField, FutureLockedField
2424
from ..tools.array import reshape_vector, apply_matrix, add_sparse, axindex, axslice, perm_matrix, copyto, sparse_block_diag, interleave_matrices
2525
from ..tools.cache import CachedAttribute, CachedMethod
@@ -1492,6 +1492,8 @@ class Copy(LinearOperator):
14921492

14931493
def __init__(self, operand, out=None):
14941494
super().__init__(operand, out=out)
1495+
if isinstance(operand, (LockedField, FutureLockedField)):
1496+
raise ValueError("Not yet implemented for locked fields.")
14951497
# LinearOperator requirements
14961498
self.operand = operand
14971499
# FutureField requirements

dedalus/core/problems.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from collections import ChainMap
55

6-
from .field import Operand, Field
6+
from .field import Operand, Field, LockedField
77
from . import arithmetic
88
from . import operators
99
from . import solvers
@@ -244,6 +244,10 @@ def _build_matrix_expressions(self, eqn):
244244
# Extract matrix expressions
245245
F = eqn['LHS'] - eqn['RHS']
246246
dF = F.frechet_differential(vars, perts)
247+
# Remove any field locks
248+
dF = dF.replace(operators.Lock, lambda x: x)
249+
for field in dF.atoms(LockedField):
250+
dF = dF.replace(field, field.unlock())
247251
# Reinitialize and prep NCCs
248252
dF = dF.reinitialize(ncc=True, ncc_vars=perts)
249253
dF.prep_nccs(vars=perts)

dedalus/core/solvers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -685,11 +685,6 @@ def step(self, dt):
685685
# Assert finite timestep
686686
if not np.isfinite(dt):
687687
raise ValueError("Invalid timestep")
688-
# Enforce Hermitian symmetry for real variables
689-
if self.enforce_real_cadence:
690-
# Enforce for as many iterations as timestepper uses internally
691-
if self.iteration % self.enforce_real_cadence < self.timestepper.steps:
692-
self.enforce_hermitian_symmetry(self.state)
693688
# Record times
694689
wall_time = self.wall_time
695690
if self.iteration == self.initial_iteration:
@@ -706,6 +701,11 @@ def step(self, dt):
706701
self.run_time_start = self.wall_time
707702
# Advance using timestepper
708703
self.timestepper.step(dt, wall_time)
704+
# Enforce Hermitian symmetry for real variables
705+
if self.enforce_real_cadence:
706+
# Enforce for as many iterations as timestepper uses internally
707+
if self.iteration % self.enforce_real_cadence < self.timestepper.steps:
708+
self.enforce_hermitian_symmetry(self.state)
709709
# Update iteration
710710
self.iteration += 1
711711
self.dt = dt

0 commit comments

Comments
 (0)