7878
7979
8080if TYPE_CHECKING :
81+ from ast import _ConstantValue # pyright: ignore[reportPrivateUsage]
8182 from collections .abc import Callable , Iterable
8283
8384 from pytato .target .python import BoundPythonProgram , NumpyLikePythonTarget
@@ -127,6 +128,10 @@ def first_true(iterable: Iterable[T], default: T,
127128 return next (filter (pred , iterable ), default )
128129
129130
131+ def _constant (value : object ) -> ast .Constant :
132+ return ast .Constant (cast ("_ConstantValue" , value ))
133+
134+
130135def _is_slice_trivial (slice_ : NormalizedSlice ,
131136 dim : ShapeComponent ) -> bool :
132137 """
@@ -217,26 +222,26 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
217222 # generates code like: `np.float64("nan")`.
218223 return ast .Call (
219224 func = ast .Attribute (value = ast .Name (self .numpy ),
220- attr = e_np .dtype .name ),
221- args = [ast . Constant (value = "nan" )],
225+ attr = cast ( "str" , e_np .dtype .name ) ),
226+ args = [_constant (value = "nan" )],
222227 keywords = [])
223228 else :
224- return ast . Constant (e )
229+ return _constant (e )
225230
226231 if isinstance (hlo , FullOp ):
227232 if hlo .fill_value == 1 :
228233 if expr .dtype == np .dtype (float ):
229234 rhs = ast .Call (
230235 ast .Attribute (ast .Name (self .numpy_backend ),
231236 "ones" ),
232- args = [ast .Tuple (elts = [ast . Constant (d )
237+ args = [ast .Tuple (elts = [_constant (d )
233238 for d in expr .shape ])],
234239 keywords = [])
235240 else :
236241 rhs = ast .Call (
237242 ast .Attribute (ast .Name (self .numpy_backend ),
238243 "ones" ),
239- args = [ast .Tuple (elts = [ast . Constant (d )
244+ args = [ast .Tuple (elts = [_constant (d )
240245 for d in expr .shape ])],
241246 keywords = [ast .keyword (
242247 arg = "dtype" ,
@@ -248,14 +253,14 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
248253 rhs = ast .Call (
249254 ast .Attribute (ast .Name (self .numpy_backend ),
250255 "zeros" ),
251- args = [ast .Tuple (elts = [ast . Constant (d )
256+ args = [ast .Tuple (elts = [_constant (d )
252257 for d in expr .shape ])],
253258 keywords = [])
254259 else :
255260 rhs = ast .Call (
256261 ast .Attribute (ast .Name (self .numpy_backend ),
257262 "zeros" ),
258- args = [ast .Tuple (elts = [ast . Constant (d )
263+ args = [ast .Tuple (elts = [_constant (d )
259264 for d in expr .shape ])],
260265 keywords = [ast .keyword (
261266 arg = "dtype" ,
@@ -266,7 +271,7 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
266271 rhs = ast .Call (
267272 ast .Attribute (ast .Name (self .numpy_backend ),
268273 "full" ),
269- args = [ast .Tuple (elts = [ast . Constant (d )
274+ args = [ast .Tuple (elts = [_constant (d )
270275 for d in expr .shape ]),
271276 _rec_ary_or_constant (hlo .fill_value ),
272277 ],
@@ -324,7 +329,7 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
324329 rhs = ast .Call (ast .Attribute (ast .Name (self .numpy_backend ),
325330 "broadcast_to" ),
326331 args = [ast .Name (self .rec (hlo .x )),
327- ast .Tuple (elts = [ast . Constant (d )
332+ ast .Tuple (elts = [_constant (d )
328333 for d in expr .shape ])],
329334 keywords = [])
330335 elif isinstance (hlo , ReduceOp ):
@@ -339,9 +344,9 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
339344 else :
340345 if len (hlo .axes ) == 1 :
341346 axis , = hlo .axes .keys ()
342- axis_ast : ast .expr = ast . Constant (axis )
347+ axis_ast : ast .expr = _constant (axis )
343348 else :
344- axis_ast = ast .Tuple (elts = [ast . Constant (e )
349+ axis_ast = ast .Tuple (elts = [_constant (e )
345350 for e in sorted (hlo .axes .keys ())])
346351 rhs = ast .Call (ast .Attribute (ast .Name (self .numpy_backend ),
347352 np_fn_name ),
@@ -366,7 +371,7 @@ def map_stack(self, expr: Stack) -> str:
366371 args = [ast .List ([ast .Name (id_ )
367372 for id_ in rec_ids ])],
368373 keywords = [ast .keyword (arg = "axis" ,
369- value = ast . Constant (expr .axis ))])
374+ value = _constant (expr .axis ))])
370375
371376 return self ._record_line_and_return_lhs (lhs , rhs )
372377
@@ -379,7 +384,7 @@ def map_concatenate(self, expr: Concatenate) -> str:
379384 args = [ast .List ([ast .Name (id_ )
380385 for id_ in rec_ids ])],
381386 keywords = [ast .keyword (arg = "axis" ,
382- value = ast . Constant (expr .axis ))])
387+ value = _constant (expr .axis ))])
383388
384389 return self ._record_line_and_return_lhs (lhs , rhs )
385390
@@ -389,9 +394,9 @@ def map_roll(self, expr: Roll) -> str:
389394 args = [ast .Name (self .rec (expr .array )),
390395 ],
391396 keywords = [ast .keyword (arg = "shift" ,
392- value = ast . Constant (expr .shift )),
397+ value = _constant (expr .shift )),
393398 ast .keyword (arg = "axis" ,
394- value = ast . Constant (expr .axis ))])
399+ value = _constant (expr .axis ))])
395400
396401 return self ._record_line_and_return_lhs (lhs , rhs )
397402
@@ -404,7 +409,7 @@ def map_axis_permutation(self, expr: AxisPermutation) -> str:
404409 args = [ast .Name (self .rec (expr .array ))],
405410 keywords = [ast .keyword (
406411 arg = "axes" ,
407- value = ast .List (elts = [ast . Constant (a )
412+ value = ast .List (elts = [_constant (a )
408413 for a in expr .axis_permutation ]))
409414 ])
410415
@@ -427,7 +432,7 @@ def _map_index_base(self, expr: IndexBase) -> str:
427432
428433 def _rec_idx (idx : IndexExpr , dim : ShapeComponent ) -> ast .expr :
429434 if isinstance (idx , int ):
430- return ast . Constant (idx )
435+ return _constant (idx )
431436 elif isinstance (idx , NormalizedSlice ):
432437 step = idx .step if idx .step != 1 else None
433438 if idx .step > 0 :
@@ -458,13 +463,13 @@ class SliceKwargs(TypedDict):
458463 kwargs : SliceKwargs = {}
459464 if step is not None :
460465 assert isinstance (step , int )
461- kwargs ["step" ] = ast . Constant (step )
466+ kwargs ["step" ] = _constant (step )
462467 if start is not None :
463468 assert isinstance (start , int )
464- kwargs ["lower" ] = ast . Constant (start )
469+ kwargs ["lower" ] = _constant (start )
465470 if stop is not None :
466471 assert isinstance (stop , int )
467- kwargs ["upper" ] = ast . Constant (stop )
472+ kwargs ["upper" ] = _constant (stop )
468473
469474 return ast .Slice (** kwargs )
470475 else :
@@ -500,7 +505,7 @@ def map_einsum(self, expr: Einsum) -> str:
500505 lhs = self .vng ("_pt_tmp" )
501506 args = [ast .Name (self .rec (arg )) for arg in expr .args ]
502507 rhs = ast .Call (ast .Attribute (ast .Name (self .numpy_backend ), "einsum" ),
503- args = [ast . Constant (get_einsum_specification (expr )),
508+ args = [_constant (get_einsum_specification (expr )),
504509 * args ],
505510 keywords = [],
506511 )
@@ -513,7 +518,7 @@ def map_reshape(self, expr: Reshape) -> str:
513518 raise NotImplementedError ("Non-integral reshapes." )
514519 rhs = ast .Call (ast .Attribute (ast .Name (self .numpy_backend ), "reshape" ),
515520 args = [ast .Name (self .rec (expr .array )),
516- ast .Tuple (elts = [ast . Constant (d )
521+ ast .Tuple (elts = [_constant (d )
517522 for d in expr .shape ])],
518523 keywords = [],
519524 )
@@ -530,7 +535,7 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> str:
530535 keys : list [expr_t | None ] = []
531536 values : list [expr_t ] = []
532537 for name , subexpr in sorted (expr ._data .items ()):
533- keys .append (ast . Constant (name ))
538+ keys .append (_constant (name ))
534539 values .append (ast .Name (self .rec (subexpr )))
535540
536541 rhs = ast .Dict (keys = keys , values = values )
0 commit comments