Skip to content

Commit d0476b5

Browse files
committed
compiler: make printer namespace more flexible for expression depenedency
1 parent 08ea454 commit d0476b5

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

devito/ir/cgen/printer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def prec_literal(self, expr):
9292
def func_literal(self, expr):
9393
return self._func_literals.get(self._prec(expr), '')
9494

95+
def ns(self, expr):
96+
return self._ns
97+
9598
def func_prefix(self, expr, mfunc=False):
9699
prefix = self._func_prefix.get(self._prec(expr), '')
97100
if mfunc:
@@ -193,7 +196,7 @@ def _print_math_func(self, expr, nest=False, known=None):
193196
else:
194197
args = ', '.join([self._print(arg) for arg in expr.args])
195198

196-
return f'{self._ns}{cname}({args})'
199+
return f'{self.ns(expr)}{cname}({args})'
197200

198201
def _print_Pow(self, expr):
199202
# Completely reimplement `_print_Pow` from sympy, since it doesn't
@@ -207,11 +210,11 @@ def _print_Pow(self, expr):
207210
return self._print_Float(Float(1.0)) + '/' + \
208211
self.parenthesize(expr.base, PREC)
209212
elif equal_valued(expr.exp, 0.5):
210-
return f'{self._ns}sqrt{suffix}({base})'
213+
return f'{self.ns(expr)}sqrt{suffix}({base})'
211214
elif expr.exp == S.One/3 and self.standard != 'C89':
212-
return f'{self._ns}cbrt{suffix}({base})'
215+
return f'{self.ns(expr)}cbrt{suffix}({base})'
213216
else:
214-
return f'{self._ns}pow{suffix}({base}, {self._print(expr.exp)})'
217+
return f'{self.ns(expr)}pow{suffix}({base}, {self._print(expr.exp)})'
215218

216219
def _print_SafeInv(self, expr):
217220
"""Print a SafeInv as a C-like division with a check for zero."""
@@ -241,7 +244,7 @@ def _print_Mul(self, expr):
241244
def _print_fmath_func(self, name, expr):
242245
args = ",".join([self._print(i) for i in expr.args])
243246
func = f'{self.func_prefix(expr, mfunc=True)}{name}{self.func_literal(expr)}'
244-
return f"{self._ns}{func}({args})"
247+
return f"{self.ns(expr)}{func}({args})"
245248

246249
def _print_Min(self, expr):
247250
if len(expr.args) > 2:
@@ -391,7 +394,7 @@ def _print_SizeOf(self, expr):
391394
return f'sizeof({self._print(expr.intype)}{self._print(expr.stars)})'
392395

393396
def _print_MathFunction(self, expr):
394-
return f"{self._ns}{self._print_DefFunction(expr)}"
397+
return f"{self.ns(expr)}{self._print_DefFunction(expr)}"
395398

396399
def _print_Fallback(self, expr):
397400
return expr.__str__()

0 commit comments

Comments
 (0)