Skip to content

Commit 30b18b2

Browse files
committed
Add get_storage_index for the wranglers
1 parent 8ee947e commit 30b18b2

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

sumpy/expansion/__init__.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import sumpy.symbolic as sym
3131
from sumpy.kernel import Kernel
3232
from sumpy.tools import add_mi
33+
import pymbolic.primitives as prim
3334

3435
import logging
3536
logger = logging.getLogger(__name__)
@@ -377,6 +378,18 @@ def _split_coeffs_into_hyperplanes(
377378

378379

379380
class FullExpansionTermsWrangler(ExpansionTermsWrangler):
381+
382+
def get_storage_index(self, mi, order=None):
383+
if not order:
384+
order = sum(mi)
385+
if self.dim == 3:
386+
return (order*(order + 1)*(order + 2))//6 + \
387+
(order + 2)*mi[2] - (mi[2]*(mi[2] + 1))//2 + mi[1]
388+
elif self.dim == 2:
389+
return (order*(order + 1))//2 + mi[1]
390+
else:
391+
raise NotImplementedError
392+
380393
def get_coefficient_identifiers(self):
381394
return super().get_full_coefficient_identifiers()
382395

@@ -584,6 +597,51 @@ def get_full_coefficient_identifiers(self):
584597
key, _ = self._get_mi_ordering_key_and_axis_permutation()
585598
return sorted(identifiers, key=key)
586599

600+
def get_storage_index(self, mi, order=None):
601+
if not order:
602+
order = sum(mi)
603+
604+
ordering_key, axis_permutation = \
605+
self._get_mi_ordering_key_and_axis_permutation()
606+
deriv_id_to_coeff, = self.knl.get_pde_as_diff_op().eqs
607+
max_mi = max(deriv_id_to_coeff, key=ordering_key).mi
608+
609+
if all(m != 0 for m in max_mi):
610+
raise NotImplementedError("non-elliptic PDEs")
611+
612+
c = max_mi[axis_permutation[0]]
613+
614+
mi = list(mi)
615+
mi[axis_permutation[0]], mi[0] = mi[0], mi[axis_permutation[0]]
616+
617+
if self.dim == 3:
618+
if all(isinstance(axis, int) for axis in mi):
619+
if order < c - 1:
620+
return (order*(order + 1)*(order + 2))/6 + \
621+
(order + 2)*mi[0] - (mi[0]*(mi[0] + 1))/2 + mi[1]
622+
else:
623+
return (c*(c-1)*(c-2))/6 + (c * order * (2 + order - c)
624+
+ mi[0]*(3 - mi[0]+2*order))/2 + mi[1]
625+
else:
626+
return prim.If(prim.Comparison(order, "<", c - 1),
627+
(order*(order + 1)*(order + 2))/6
628+
+ (order + 2)*mi[0] - (mi[0]*(mi[0] + 1))/2 + mi[1],
629+
(c*(c-1)*(c-2))/6 + (c * order * (2 + order - c)
630+
+ mi[0]*(3 - mi[0]+2*order))/2 + mi[1]
631+
)
632+
elif self.dim == 2:
633+
if all(isinstance(axis, int) for axis in mi):
634+
if order < c - 1:
635+
return (order*(order + 1))//2 + mi[0]
636+
else:
637+
return (c*(c-1))//2 + c*(order - c + 1) + mi[0]
638+
else:
639+
return prim.If(prim.Comparison(order, "<", c - 1),
640+
(order*(order + 1))//2 + mi[0],
641+
(c*(c-1))//2 + c*(order - c + 1) + mi[0])
642+
else:
643+
raise NotImplementedError
644+
587645
@memoize_method
588646
def get_stored_ids_and_unscaled_projection_matrix(self):
589647
from pytools import ProcessLogger

test/test_misc.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
make_identity_diff_op, concat, as_scalar_pde, diff,
5050
gradient, divergence, laplacian, curl)
5151

52+
from sumpy.expansion import (FullExpansionTermsWrangler,
53+
LinearPDEBasedExpansionTermsWrangler)
54+
5255
import logging
5356
logger = logging.getLogger(__name__)
5457

@@ -516,6 +519,45 @@ def get_pde_as_diff_op(self):
516519
# }}}
517520

518521

522+
# {{{ test_get_storage_index
523+
524+
class TestKernel(ExpressionKernel):
525+
def __init__(self, dim, max_mi):
526+
super().__init__(dim=dim, expression=1, global_scaling_const=1,
527+
is_complex_valued=False)
528+
self._max_mi = max_mi
529+
530+
def get_pde_as_diff_op(self):
531+
w = make_identity_diff_op(self.dim)
532+
pde = diff(w, tuple(self._max_mi))
533+
return pde
534+
535+
536+
@pytest.mark.parametrize("order", [6])
537+
@pytest.mark.parametrize("knl", [
538+
LaplaceKernel(2),
539+
LaplaceKernel(3),
540+
TestKernel(2, (3, 0)),
541+
TestKernel(2, (0, 3)),
542+
TestKernel(3, (3, 0, 0)),
543+
TestKernel(3, (0, 3, 0)),
544+
TestKernel(3, (0, 0, 3)),
545+
BiharmonicKernel(2),
546+
BiharmonicKernel(3),
547+
])
548+
@pytest.mark.parametrize("compressed", (True, False))
549+
def test_get_storage_index(order, knl, compressed):
550+
dim = knl.dim
551+
if compressed:
552+
wrangler = LinearPDEBasedExpansionTermsWrangler(order, dim, knl=knl)
553+
else:
554+
wrangler = FullExpansionTermsWrangler(order, dim)
555+
for i, mi in enumerate(wrangler.get_coefficient_identifiers()):
556+
assert i == wrangler.get_storage_index(mi)
557+
558+
# }}}
559+
560+
519561
# You can test individual routines by typing
520562
# $ python test_misc.py 'test_pde_check_kernels(_acf,
521563
# KernelInfo(HelmholtzKernel(2), k=5), order=5)'

0 commit comments

Comments
 (0)