Skip to content

Commit 495dfc5

Browse files
committed
rewrite at codegen time
1 parent 154bbe6 commit 495dfc5

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

sumpy/codegen.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,39 @@ def map_sum(self, expr, *args):
592592
# }}}
593593

594594

595+
# {{{ helmholtz rewrite
596+
class HelmholtzRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
597+
def __init__(self, k, ik):
598+
self.k = k
599+
self.ik = ik
600+
601+
def map_variable(self, expr, *args):
602+
if expr.name == self.ik.name:
603+
return 1j*self.k
604+
else:
605+
return expr
606+
607+
def map_call(self, expr, *args):
608+
if isinstance(expr.function, prim.Variable) \
609+
and expr.function.name == "exp":
610+
params = expr.parameters
611+
assert len(params) == 1
612+
param = self.rec(params[0])
613+
if isinstance(param, prim.Product) and 1j in param.children:
614+
children = list(param.children)
615+
del children[children.index(1j)]
616+
params = (prim.Product(tuple(children)),)
617+
return prim.Call(prim.Variable("cos"), params) + \
618+
1j * prim.Call(prim.Variable("sin"), params)
619+
620+
return super().map_call(expr, *args)
621+
622+
map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
623+
624+
625+
# }}}
626+
627+
595628
class MathConstantRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
596629
def map_variable(self, expr, *args):
597630
if expr.name == "pi":

sumpy/kernel.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def __init__(self, dim, helmholtz_k_name="k",
522522
if allow_evanescent:
523523
expr = var("exp")(var("I")*k*r)/r
524524
else:
525-
expr = (var("cos")(k*r) + var("I")*var("sin")(k*r))/r
525+
expr = var("exp")(var("Ik")*r)/r
526526
scaling = 1/(4*var("pi"))
527527
else:
528528
raise RuntimeError("unsupported dimensionality")
@@ -579,6 +579,15 @@ def get_pde_as_diff_op(self):
579579
k = sym.Symbol(self.helmholtz_k_name)
580580
return (laplacian(w) + k**2 * w)
581581

582+
def get_code_transformer(self):
583+
k = SpatialConstant(self.helmholtz_k_name)
584+
585+
if self.allow_evanescent:
586+
return lambda expr: expr
587+
else:
588+
from sumpy.codegen import HelmholtzRewriter
589+
return HelmholtzRewriter(k, var("Ik"))
590+
582591

583592
class YukawaKernel(ExpressionKernel):
584593
init_arg_names = ("dim", "yukawa_lambda_name")

0 commit comments

Comments
 (0)