Skip to content

Commit a57eb87

Browse files
committed
Fix toys
1 parent 57dc042 commit a57eb87

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

sumpy/toys.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def _p2e(actx, psource, center, rscale, order: int, p2e, expn_class, expn_kwargs
289289
centers = actx.from_numpy(
290290
np.array(center, dtype=np.float64).reshape(toy_ctx.kernel.dim, 1))
291291

292-
coeffs, = p2e(
292+
coeffs = p2e(
293293
actx,
294294
source_boxes=source_boxes,
295295
box_source_starts=box_source_starts,
@@ -379,7 +379,7 @@ def _e2e(actx: PyOpenCLArrayContext,
379379
**toy_ctx.extra_kernel_kwargs,
380380
}
381381

382-
to_coeffs, = e2e(**args)
382+
to_coeffs = e2e(**args)
383383
return expn_class(
384384
toy_ctx, to_center, to_rscale, to_order,
385385
actx.to_numpy(to_coeffs[1]),
@@ -414,12 +414,12 @@ def _m2l(actx: PyOpenCLArrayContext,
414414

415415
if toy_ctx.use_fft:
416416

417-
fft_app = get_opencl_fft_app(actx, (expn_size,),
417+
fft_app = get_opencl_fft_app(actx, (1, expn_size,),
418418
dtype=preprocessed_src_expansions.dtype, inverse=False)
419-
ifft_app = get_opencl_fft_app(actx, (expn_size,),
419+
ifft_app = get_opencl_fft_app(actx, (1, expn_size,),
420420
dtype=preprocessed_src_expansions.dtype, inverse=True)
421421

422-
preprocessed_src_expansions = run_opencl_fft(actx, fft_app,
422+
_evt, preprocessed_src_expansions = run_opencl_fft(actx, fft_app,
423423
preprocessed_src_expansions, inverse=False)
424424

425425
# Compute translation classes data
@@ -443,7 +443,7 @@ def _m2l(actx: PyOpenCLArrayContext,
443443
**toy_ctx.extra_kernel_kwargs)
444444

445445
if toy_ctx.use_fft:
446-
m2l_translation_classes_dependent_data = run_opencl_fft(
446+
_evt, m2l_translation_classes_dependent_data = run_opencl_fft(
447447
actx, fft_app,
448448
m2l_translation_classes_dependent_data,
449449
inverse=False)
@@ -739,7 +739,7 @@ class Sum(PotentialExpressionNode):
739739
def eval(self, actx: PyOpenCLArrayContext, targets: np.ndarray) -> np.ndarray:
740740
result = np.zeros(targets.shape[1])
741741
for psource in self.psources:
742-
result += psource.eval(actx, targets)
742+
result = result + psource.eval(actx, targets)
743743

744744
return result
745745

@@ -753,7 +753,7 @@ class Product(PotentialExpressionNode):
753753
def eval(self, actx: PyOpenCLArrayContext, targets: np.ndarray) -> np.ndarray:
754754
result = np.ones(targets.shape[1])
755755
for psource in self.psources:
756-
result *= psource.eval(actx, targets)
756+
result = result * psource.eval(actx, targets)
757757

758758
return result
759759
# }}}

0 commit comments

Comments
 (0)