@@ -592,6 +592,39 @@ def map_sum(self, expr, *args):
592592# }}}
593593
594594
595+ # {{{ helmholtz rewrite
596+ class HelmholtzRewriter (CSECachingIdentityMapper , CallExternalRecMapper ):
597+ def __init__ (self , k , ik ):
598+ self .k = k
599+ self .ik = ik
600+
601+ def map_variable (self , expr , * args ):
602+ if expr .name == self .ik .name :
603+ return 1j * self .k
604+ else :
605+ return expr
606+
607+ def map_call (self , expr , * args ):
608+ if isinstance (expr .function , prim .Variable ) \
609+ and expr .function .name == "exp" :
610+ params = expr .parameters
611+ assert len (params ) == 1
612+ param = self .rec (params [0 ])
613+ if isinstance (param , prim .Product ) and 1j in param .children :
614+ children = list (param .children )
615+ del children [children .index (1j )]
616+ params = (prim .Product (tuple (children )),)
617+ return prim .Call (prim .Variable ("cos" ), params ) + \
618+ 1j * prim .Call (prim .Variable ("sin" ), params )
619+
620+ return super ().map_call (expr , * args )
621+
622+ map_common_subexpression_uncached = IdentityMapper .map_common_subexpression
623+
624+
625+ # }}}
626+
627+
595628class MathConstantRewriter (CSECachingIdentityMapper , CallExternalRecMapper ):
596629 def map_variable (self , expr , * args ):
597630 if expr .name == "pi" :
0 commit comments