22
33import numpy as np
44
5- from pytensor import config
65from pytensor .compile .ops import ViewOp
76from pytensor .graph .basic import Variable
87from pytensor .link .numba .dispatch import basic as numba_basic
@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
137136
138137 return numba_basic .numba_njit (
139138 signature ,
140- fastmath = config .numba__fastmath ,
141139 # Functions that call a function pointer can't be cached
142140 cache = False ,
143141 )(scalar_op_fn )
@@ -177,19 +175,15 @@ def numba_funcify_Add(op, node, **kwargs):
177175 signature = create_numba_signature (node , force_scalar = True )
178176 nary_add_fn = binary_to_nary_func (node .inputs , "add" , "+" )
179177
180- return numba_basic .numba_njit (signature , fastmath = config .numba__fastmath )(
181- nary_add_fn
182- )
178+ return numba_basic .numba_njit (signature )(nary_add_fn )
183179
184180
185181@numba_funcify .register (Mul )
186182def numba_funcify_Mul (op , node , ** kwargs ):
187183 signature = create_numba_signature (node , force_scalar = True )
188184 nary_add_fn = binary_to_nary_func (node .inputs , "mul" , "*" )
189185
190- return numba_basic .numba_njit (signature , fastmath = config .numba__fastmath )(
191- nary_add_fn
192- )
186+ return numba_basic .numba_njit (signature )(nary_add_fn )
193187
194188
195189@numba_funcify .register (Cast )
@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):
239233
240234 _ = kwargs .pop ("storage_map" , None )
241235
242- composite_fn = numba_basic .numba_njit (signature , fastmath = config . numba__fastmath )(
236+ composite_fn = numba_basic .numba_njit (signature )(
243237 numba_funcify (op .fgraph , squeeze_output = True , ** kwargs )
244238 )
245239 return composite_fn
@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
267261 return numba_basic .global_numba_func (reciprocal )
268262
269263
270- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
264+ @numba_basic .numba_njit
271265def sigmoid (x ):
272266 return 1 / (1 + np .exp (- x ))
273267
@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
277271 return numba_basic .global_numba_func (sigmoid )
278272
279273
280- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
274+ @numba_basic .numba_njit
281275def gammaln (x ):
282276 return math .lgamma (x )
283277
@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
287281 return numba_basic .global_numba_func (gammaln )
288282
289283
290- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
284+ @numba_basic .numba_njit
291285def logp1mexp (x ):
292286 if x < np .log (0.5 ):
293287 return np .log1p (- np .exp (x ))
@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
300294 return numba_basic .global_numba_func (logp1mexp )
301295
302296
303- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
297+ @numba_basic .numba_njit
304298def erf (x ):
305299 return math .erf (x )
306300
@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
310304 return numba_basic .global_numba_func (erf )
311305
312306
313- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
307+ @numba_basic .numba_njit
314308def erfc (x ):
315309 return math .erfc (x )
316310
0 commit comments