|
30 | 30 | import sumpy.symbolic as sym |
31 | 31 | from sumpy.kernel import Kernel |
32 | 32 | from sumpy.tools import add_mi |
| 33 | +import pymbolic.primitives as prim |
33 | 34 |
|
34 | 35 | import logging |
35 | 36 | logger = logging.getLogger(__name__) |
@@ -377,6 +378,18 @@ def _split_coeffs_into_hyperplanes( |
377 | 378 |
|
378 | 379 |
|
379 | 380 | class FullExpansionTermsWrangler(ExpansionTermsWrangler): |
| 381 | + |
| 382 | + def get_storage_index(self, mi, order=None): |
| 383 | + if not order: |
| 384 | + order = sum(mi) |
| 385 | + if self.dim == 3: |
| 386 | + return (order*(order + 1)*(order + 2))//6 + \ |
| 387 | + (order + 2)*mi[2] - (mi[2]*(mi[2] + 1))//2 + mi[1] |
| 388 | + elif self.dim == 2: |
| 389 | + return (order*(order + 1))//2 + mi[1] |
| 390 | + else: |
| 391 | + raise NotImplementedError |
| 392 | + |
380 | 393 | def get_coefficient_identifiers(self): |
381 | 394 | return super().get_full_coefficient_identifiers() |
382 | 395 |
|
@@ -584,6 +597,51 @@ def get_full_coefficient_identifiers(self): |
584 | 597 | key, _ = self._get_mi_ordering_key_and_axis_permutation() |
585 | 598 | return sorted(identifiers, key=key) |
586 | 599 |
|
| 600 | + def get_storage_index(self, mi, order=None): |
| 601 | + if not order: |
| 602 | + order = sum(mi) |
| 603 | + |
| 604 | + ordering_key, axis_permutation = \ |
| 605 | + self._get_mi_ordering_key_and_axis_permutation() |
| 606 | + deriv_id_to_coeff, = self.knl.get_pde_as_diff_op().eqs |
| 607 | + max_mi = max(deriv_id_to_coeff, key=ordering_key).mi |
| 608 | + |
| 609 | + if all(m != 0 for m in max_mi): |
| 610 | + raise NotImplementedError("non-elliptic PDEs") |
| 611 | + |
| 612 | + c = max_mi[axis_permutation[0]] |
| 613 | + |
| 614 | + mi = list(mi) |
| 615 | + mi[axis_permutation[0]], mi[0] = mi[0], mi[axis_permutation[0]] |
| 616 | + |
| 617 | + if self.dim == 3: |
| 618 | + if all(isinstance(axis, int) for axis in mi): |
| 619 | + if order < c - 1: |
| 620 | + return (order*(order + 1)*(order + 2))/6 + \ |
| 621 | + (order + 2)*mi[0] - (mi[0]*(mi[0] + 1))/2 + mi[1] |
| 622 | + else: |
| 623 | + return (c*(c-1)*(c-2))/6 + (c * order * (2 + order - c) |
| 624 | + + mi[0]*(3 - mi[0]+2*order))/2 + mi[1] |
| 625 | + else: |
| 626 | + return prim.If(prim.Comparison(order, "<", c - 1), |
| 627 | + (order*(order + 1)*(order + 2))/6 |
| 628 | + + (order + 2)*mi[0] - (mi[0]*(mi[0] + 1))/2 + mi[1], |
| 629 | + (c*(c-1)*(c-2))/6 + (c * order * (2 + order - c) |
| 630 | + + mi[0]*(3 - mi[0]+2*order))/2 + mi[1] |
| 631 | + ) |
| 632 | + elif self.dim == 2: |
| 633 | + if all(isinstance(axis, int) for axis in mi): |
| 634 | + if order < c - 1: |
| 635 | + return (order*(order + 1))//2 + mi[0] |
| 636 | + else: |
| 637 | + return (c*(c-1))//2 + c*(order - c + 1) + mi[0] |
| 638 | + else: |
| 639 | + return prim.If(prim.Comparison(order, "<", c - 1), |
| 640 | + (order*(order + 1))//2 + mi[0], |
| 641 | + (c*(c-1))//2 + c*(order - c + 1) + mi[0]) |
| 642 | + else: |
| 643 | + raise NotImplementedError |
| 644 | + |
587 | 645 | @memoize_method |
588 | 646 | def get_stored_ids_and_unscaled_projection_matrix(self): |
589 | 647 | from pytools import ProcessLogger |
|
0 commit comments