Skip to content

Commit 3fa7bb0

Browse files
committed
bugfix for controller_nonMPI, started parallel FFT
1 parent bb56e95 commit 3fa7bb0

File tree

5 files changed

+399
-33
lines changed

5 files changed

+399
-33
lines changed

pySDC/implementations/controller_classes/controller_nonMPI.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,11 @@ def pfasst(self, MS):
359359

360360
self.logger.debug(stage)
361361

362+
MS_active = [S for S in MS if S.status.stage is not 'DONE']
363+
362364
if stage == 'SPREAD':
363365
# (potentially) serial spreading phase
364-
for S in MS:
366+
for S in MS_active:
365367

366368
# first stage: spread values
367369
self.hooks.pre_step(step=S, level_number=0)
@@ -380,15 +382,15 @@ def pfasst(self, MS):
380382
elif stage == 'PREDICT':
381383
# call predictor (serial)
382384

383-
for S in MS:
385+
for S in MS_active:
384386
self.hooks.pre_predict(step=S, level_number=0)
385387

386388
MS = self.predictor(MS)
387389

388-
for S in MS:
390+
for S in MS_active:
389391
self.hooks.post_predict(step=S, level_number=0)
390392

391-
for S in MS:
393+
for S in MS_active:
392394
# update stage
393395
S.status.stage = 'IT_CHECK'
394396

@@ -398,7 +400,7 @@ def pfasst(self, MS):
398400

399401
# check whether to stop iterating (parallel)
400402

401-
for S in MS:
403+
for S in MS_active:
402404

403405
# send updated values forward
404406
self.hooks.pre_comm(step=S, level_number=0)
@@ -420,7 +422,7 @@ def pfasst(self, MS):
420422
if S.status.iter > 0:
421423
self.hooks.post_iteration(step=S, level_number=0)
422424

423-
for S in MS:
425+
for S in MS_active:
424426
if not S.status.first:
425427
self.hooks.pre_comm(step=S, level_number=0)
426428
S.status.prev_done = S.prev.status.done # "communicate"
@@ -453,15 +455,15 @@ def pfasst(self, MS):
453455
elif stage == 'IT_FINE':
454456
# do fine sweep for all steps (virtually parallel)
455457

456-
for S in MS:
458+
for S in MS_active:
457459
S.levels[0].status.sweep = 0
458460

459461
for k in range(self.nsweeps[0]):
460462

461-
for S in MS:
463+
for S in MS_active:
462464
S.levels[0].status.sweep += 1
463465

464-
for S in MS:
466+
for S in MS_active:
465467
# send updated values forward
466468
self.hooks.pre_comm(step=S, level_number=0)
467469
if self.params.fine_comm and not S.status.last:
@@ -476,14 +478,14 @@ def pfasst(self, MS):
476478
self.recv(S.levels[0], S.prev.levels[0], tag=(0, S.status.iter, S.prev.status.slot))
477479
self.hooks.post_comm(step=S, level_number=0, add_to_stats=(k == self.nsweeps[0] - 1))
478480

479-
for S in MS:
481+
for S in MS_active:
480482
# standard sweep workflow: update nodes, compute residual, log progress
481483
self.hooks.pre_sweep(step=S, level_number=0)
482484
S.levels[0].sweep.update_nodes()
483485
S.levels[0].sweep.compute_residual()
484486
self.hooks.post_sweep(step=S, level_number=0)
485487

486-
for S in MS:
488+
for S in MS_active:
487489
# update stage
488490
S.status.stage = 'IT_CHECK'
489491

@@ -492,7 +494,7 @@ def pfasst(self, MS):
492494
elif stage == 'IT_UP':
493495
# go up the hierarchy from finest to coarsest level (parallel)
494496

495-
for S in MS:
497+
for S in MS_active:
496498

497499
S.transfer(source=S.levels[0], target=S.levels[1])
498500

@@ -502,7 +504,7 @@ def pfasst(self, MS):
502504

503505
for k in range(self.nsweeps[l]):
504506

505-
for S in MS:
507+
for S in MS_active:
506508

507509
# send updated values forward
508510
self.hooks.pre_comm(step=S, level_number=l)
@@ -518,18 +520,18 @@ def pfasst(self, MS):
518520
self.recv(S.levels[l], S.prev.levels[l], tag=(l, S.status.iter, S.prev.status.slot))
519521
self.hooks.post_comm(step=S, level_number=l)
520522

521-
for S in MS:
523+
for S in MS_active:
522524

523525
self.hooks.pre_sweep(step=S, level_number=l)
524526
S.levels[l].sweep.update_nodes()
525527
S.levels[l].sweep.compute_residual()
526528
self.hooks.post_sweep(step=S, level_number=l)
527529

528-
for S in MS:
530+
for S in MS_active:
529531
# transfer further up the hierarchy
530532
S.transfer(source=S.levels[l], target=S.levels[l + 1])
531533

532-
for S in MS:
534+
for S in MS_active:
533535
# update stage
534536
S.status.stage = 'IT_COARSE'
535537

@@ -538,7 +540,7 @@ def pfasst(self, MS):
538540
elif stage == 'IT_COARSE':
539541
# sweeps on coarsest level (serial/blocking)
540542

541-
for S in MS:
543+
for S in MS_active:
542544

543545
# receive from previous step (if not first)
544546
self.hooks.pre_comm(step=S, level_number=len(S.levels) - 1)
@@ -575,7 +577,7 @@ def pfasst(self, MS):
575577

576578
for l in range(self.nlevels - 1, 0, -1):
577579

578-
for S in MS:
580+
for S in MS_active:
579581
# prolong values
580582
S.transfer(source=S.levels[l], target=S.levels[l - 1])
581583

@@ -584,7 +586,7 @@ def pfasst(self, MS):
584586

585587
for k in range(self.nsweeps[l - 1]):
586588

587-
for S in MS:
589+
for S in MS_active:
588590

589591
# send updated values forward
590592
self.hooks.pre_comm(step=S, level_number=l - 1)
@@ -602,14 +604,14 @@ def pfasst(self, MS):
602604
self.hooks.post_comm(step=S, level_number=l - 1,
603605
add_to_stats=(k == self.nsweeps[l - 1] - 1))
604606

605-
for S in MS:
607+
for S in MS_active:
606608

607609
self.hooks.pre_sweep(step=S, level_number=l - 1)
608610
S.levels[l - 1].sweep.update_nodes()
609611
S.levels[l - 1].sweep.compute_residual()
610612
self.hooks.post_sweep(step=S, level_number=l - 1)
611613

612-
for S in MS:
614+
for S in MS_active:
613615
# update stage
614616
S.status.stage = 'IT_FINE'
615617

pySDC/implementations/problem_classes/AllenCahn_2D_FFT.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def eval_f(self, u, t):
9292
f = self.dtype_f(self.init)
9393
v = u.values.flatten()
9494
tmp = self.lap * self.rfft_object(u.values)
95-
f.impl.values[:] = np.real(self.irfft_object(tmp))
95+
f.impl.values[:] = self.irfft_object(tmp)
9696
if self.params.eps > 0:
9797
f.expl.values = 1.0 / self.params.eps ** 2 * v * (1.0 - v ** self.params.nu)
9898
f.expl.values = f.expl.values.reshape(self.params.nvars)
@@ -115,7 +115,7 @@ def solve_system(self, rhs, factor, u0, t):
115115
me = self.dtype_u(self.init)
116116

117117
tmp = self.rfft_object(rhs.values) / (1.0 - factor * self.lap)
118-
me.values[:] = np.real(self.irfft_object(tmp))
118+
me.values[:] = self.irfft_object(tmp)
119119

120120
return me
121121

@@ -201,10 +201,10 @@ def eval_f(self, u, t):
201201
f = self.dtype_f(self.init)
202202
v = u.values.flatten()
203203
tmp = self.lap * self.rfft_object(u.values)
204-
f.impl.values[:] = np.real(self.irfft_object(tmp))
204+
f.impl.values[:] = self.irfft_object(tmp)
205205
if self.params.eps > 0:
206206
f.expl.values = 1.0 / self.params.eps ** 2 * v * (1.0 - v ** self.params.nu) + \
207-
2.0 / self.params.eps ** 2 * v
207+
2.0 / self.params.eps ** 2 * v
208208
f.expl.values = f.expl.values.reshape(self.params.nvars)
209209
return f
210210

@@ -225,6 +225,6 @@ def solve_system(self, rhs, factor, u0, t):
225225
me = self.dtype_u(self.init)
226226

227227
tmp = self.rfft_object(rhs.values) / (1.0 - factor * self.lap)
228-
me.values[:] = np.real(self.irfft_object(tmp))
228+
me.values[:] = self.irfft_object(tmp)
229229

230230
return me

0 commit comments

Comments
 (0)