Skip to content

Commit ed13e83

Browse files
committed
Misc bpr typing fixes, update baseline
1 parent 70c79ad commit ed13e83

File tree

10 files changed

+5720
-24108
lines changed

10 files changed

+5720
-24108
lines changed

.basedpyright/baseline.json

Lines changed: 5640 additions & 24070 deletions
Large diffs are not rendered by default.

doc/codegen.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,12 @@ Generic target support
1212
----------------------
1313

1414
.. automodule:: pytato.target
15+
16+
References
17+
----------
18+
19+
.. currentmodule:: cl_array
20+
21+
.. class:: Allocator
22+
23+
See :class:`pyopencl.array.Allocator`.

pyproject.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,30 @@ ignore = [
178178
]
179179

180180

181+
[[tool.basedpyright.executionEnvironments]]
182+
root = "test"
183+
reportUnknownArgumentType = "none"
184+
reportUnknownVariableType = "none"
185+
reportMissingParameterType = "none"
186+
reportAttributeAccessIssue = "none"
187+
reportMissingImports = "none"
188+
reportArgumentType = "hint"
189+
reportUnknownMemberType = "hint"
190+
reportUnknownParameterType = "hint"
191+
reportAny = "none"
192+
193+
[[tool.basedpyright.executionEnvironments]]
194+
root = "examples"
195+
reportUnknownArgumentType = "none"
196+
reportUnknownVariableType = "none"
197+
reportMissingParameterType = "none"
198+
reportAttributeAccessIssue = "none"
199+
reportMissingImports = "none"
200+
reportArgumentType = "hint"
201+
reportUnknownMemberType = "hint"
202+
reportUnknownParameterType = "hint"
203+
reportAny = "none"
204+
181205
[tool.typos.default]
182206
extend-ignore-re = [
183207
"(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$"

pytato/loopy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161

6262
if TYPE_CHECKING:
63-
from collections.abc import Iterable, Iterator, Mapping, Sequence
63+
from collections.abc import Iterator, Mapping, Sequence
6464

6565
from pymbolic.typing import ArithmeticExpression, Expression, Integer
6666

@@ -338,8 +338,8 @@ def _get_val_in_bset(bset: isl.BasicSet, idim: int) -> ScalarExpression:
338338
return aff_to_expr(aff)
339339

340340

341-
def solve_constraints(variables: Iterable[str],
342-
parameters: Iterable[str],
341+
def solve_constraints(variables: Sequence[str],
342+
parameters: Sequence[str],
343343
constraints: Sequence[tuple[ArithmeticExpression,
344344
ArithmeticExpression]],
345345

@@ -520,10 +520,10 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel,
520520

521521
# }}}
522522

523-
solutions = solve_constraints(variables={_lp_var_to_global_namespace(var)
524-
for var in lp_size_params},
525-
parameters={_pt_var_to_global_namespace(var.name)
526-
for var in pt_size_params},
523+
solutions = solve_constraints(variables=list({_lp_var_to_global_namespace(var)
524+
for var in lp_size_params}),
525+
parameters=list({_pt_var_to_global_namespace(var.name)
526+
for var in pt_size_params}),
527527
constraints=constraints)
528528

529529
as_pt_size_param = EvaluationMapper({_pt_var_to_global_namespace(arg.name): arg

pytato/pad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _normalize_pad_width(
130130

131131

132132
def pad(array: Array,
133-
pad_width: Integer | Sequence[Integer],
133+
pad_width: int | Sequence[int],
134134
mode: str = "constant",
135135
**kwargs: Any) -> Array:
136136
r"""

pytato/stringifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def map_foreign(self, expr: Any, depth: int) -> str:
109109
elif isinstance(expr, frozenset | set):
110110
return "{" + ", ".join(self.rec(el, depth) for el in expr) + "}"
111111
elif isinstance(expr, np.dtype):
112-
return f"'{expr.name}'"
112+
dtype_name = cast("str", expr.name)
113+
return f"'{dtype_name}'"
113114
else:
114115
return repr(expr)
115116

pytato/target/loopy/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from collections.abc import Callable, Mapping
6868

6969
import pyopencl
70+
import pyopencl.array as cl_array
7071

7172

7273
class ImplSubstitution(ImplementationStrategy):
@@ -176,7 +177,7 @@ def with_transformed_program(self, f: Callable[[loopy.TranslationUnit],
176177
def _get_processed_bound_arguments(
177178
self,
178179
queue: pyopencl.CommandQueue,
179-
allocator: Callable[[int], pyopencl.MemoryObject] | None,
180+
allocator: cl_array.Allocator | None,
180181
) -> Mapping[str, Any]:
181182
import pyopencl.array as cla
182183

@@ -216,7 +217,8 @@ def all_bound_args_on_host(self) -> bool:
216217
for arg in self.bound_arguments.values())
217218

218219
def __call__(self, queue: pyopencl.CommandQueue, # type: ignore[no-untyped-def,no-any-unimported]
219-
allocator=None, wait_for=None, out_host: bool | None = None,
220+
allocator: cl_array.Allocator | None = None,
221+
wait_for: pyopencl.WaitList = None, out_host: bool | None = None,
220222
**kwargs: Any) -> Any:
221223
"""Convenience function for launching a :mod:`pyopencl` computation."""
222224

pytato/target/python/numpy_like.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878

7979

8080
if TYPE_CHECKING:
81+
from ast import _ConstantValue # pyright: ignore[reportPrivateUsage]
8182
from collections.abc import Callable, Iterable
8283

8384
from pytato.target.python import BoundPythonProgram, NumpyLikePythonTarget
@@ -127,6 +128,10 @@ def first_true(iterable: Iterable[T], default: T,
127128
return next(filter(pred, iterable), default)
128129

129130

131+
def _constant(value: object) -> ast.Constant:
132+
return ast.Constant(cast("_ConstantValue", value))
133+
134+
130135
def _is_slice_trivial(slice_: NormalizedSlice,
131136
dim: ShapeComponent) -> bool:
132137
"""
@@ -217,26 +222,26 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
217222
# generates code like: `np.float64("nan")`.
218223
return ast.Call(
219224
func=ast.Attribute(value=ast.Name(self.numpy),
220-
attr=e_np.dtype.name),
221-
args=[ast.Constant(value="nan")],
225+
attr=cast("str", e_np.dtype.name)),
226+
args=[_constant(value="nan")],
222227
keywords=[])
223228
else:
224-
return ast.Constant(e)
229+
return _constant(e)
225230

226231
if isinstance(hlo, FullOp):
227232
if hlo.fill_value == 1:
228233
if expr.dtype == np.dtype(float):
229234
rhs = ast.Call(
230235
ast.Attribute(ast.Name(self.numpy_backend),
231236
"ones"),
232-
args=[ast.Tuple(elts=[ast.Constant(d)
237+
args=[ast.Tuple(elts=[_constant(d)
233238
for d in expr.shape])],
234239
keywords=[])
235240
else:
236241
rhs = ast.Call(
237242
ast.Attribute(ast.Name(self.numpy_backend),
238243
"ones"),
239-
args=[ast.Tuple(elts=[ast.Constant(d)
244+
args=[ast.Tuple(elts=[_constant(d)
240245
for d in expr.shape])],
241246
keywords=[ast.keyword(
242247
arg="dtype",
@@ -248,14 +253,14 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
248253
rhs = ast.Call(
249254
ast.Attribute(ast.Name(self.numpy_backend),
250255
"zeros"),
251-
args=[ast.Tuple(elts=[ast.Constant(d)
256+
args=[ast.Tuple(elts=[_constant(d)
252257
for d in expr.shape])],
253258
keywords=[])
254259
else:
255260
rhs = ast.Call(
256261
ast.Attribute(ast.Name(self.numpy_backend),
257262
"zeros"),
258-
args=[ast.Tuple(elts=[ast.Constant(d)
263+
args=[ast.Tuple(elts=[_constant(d)
259264
for d in expr.shape])],
260265
keywords=[ast.keyword(
261266
arg="dtype",
@@ -266,7 +271,7 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
266271
rhs = ast.Call(
267272
ast.Attribute(ast.Name(self.numpy_backend),
268273
"full"),
269-
args=[ast.Tuple(elts=[ast.Constant(d)
274+
args=[ast.Tuple(elts=[_constant(d)
270275
for d in expr.shape]),
271276
_rec_ary_or_constant(hlo.fill_value),
272277
],
@@ -324,7 +329,7 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
324329
rhs = ast.Call(ast.Attribute(ast.Name(self.numpy_backend),
325330
"broadcast_to"),
326331
args=[ast.Name(self.rec(hlo.x)),
327-
ast.Tuple(elts=[ast.Constant(d)
332+
ast.Tuple(elts=[_constant(d)
328333
for d in expr.shape])],
329334
keywords=[])
330335
elif isinstance(hlo, ReduceOp):
@@ -339,9 +344,9 @@ def _rec_ary_or_constant(e: ArrayOrScalar) -> ast.expr:
339344
else:
340345
if len(hlo.axes) == 1:
341346
axis, = hlo.axes.keys()
342-
axis_ast: ast.expr = ast.Constant(axis)
347+
axis_ast: ast.expr = _constant(axis)
343348
else:
344-
axis_ast = ast.Tuple(elts=[ast.Constant(e)
349+
axis_ast = ast.Tuple(elts=[_constant(e)
345350
for e in sorted(hlo.axes.keys())])
346351
rhs = ast.Call(ast.Attribute(ast.Name(self.numpy_backend),
347352
np_fn_name),
@@ -366,7 +371,7 @@ def map_stack(self, expr: Stack) -> str:
366371
args=[ast.List([ast.Name(id_)
367372
for id_ in rec_ids])],
368373
keywords=[ast.keyword(arg="axis",
369-
value=ast.Constant(expr.axis))])
374+
value=_constant(expr.axis))])
370375

371376
return self._record_line_and_return_lhs(lhs, rhs)
372377

@@ -379,7 +384,7 @@ def map_concatenate(self, expr: Concatenate) -> str:
379384
args=[ast.List([ast.Name(id_)
380385
for id_ in rec_ids])],
381386
keywords=[ast.keyword(arg="axis",
382-
value=ast.Constant(expr.axis))])
387+
value=_constant(expr.axis))])
383388

384389
return self._record_line_and_return_lhs(lhs, rhs)
385390

@@ -389,9 +394,9 @@ def map_roll(self, expr: Roll) -> str:
389394
args=[ast.Name(self.rec(expr.array)),
390395
],
391396
keywords=[ast.keyword(arg="shift",
392-
value=ast.Constant(expr.shift)),
397+
value=_constant(expr.shift)),
393398
ast.keyword(arg="axis",
394-
value=ast.Constant(expr.axis))])
399+
value=_constant(expr.axis))])
395400

396401
return self._record_line_and_return_lhs(lhs, rhs)
397402

@@ -404,7 +409,7 @@ def map_axis_permutation(self, expr: AxisPermutation) -> str:
404409
args=[ast.Name(self.rec(expr.array))],
405410
keywords=[ast.keyword(
406411
arg="axes",
407-
value=ast.List(elts=[ast.Constant(a)
412+
value=ast.List(elts=[_constant(a)
408413
for a in expr.axis_permutation]))
409414
])
410415

@@ -427,7 +432,7 @@ def _map_index_base(self, expr: IndexBase) -> str:
427432

428433
def _rec_idx(idx: IndexExpr, dim: ShapeComponent) -> ast.expr:
429434
if isinstance(idx, int):
430-
return ast.Constant(idx)
435+
return _constant(idx)
431436
elif isinstance(idx, NormalizedSlice):
432437
step = idx.step if idx.step != 1 else None
433438
if idx.step > 0:
@@ -458,13 +463,13 @@ class SliceKwargs(TypedDict):
458463
kwargs: SliceKwargs = {}
459464
if step is not None:
460465
assert isinstance(step, int)
461-
kwargs["step"] = ast.Constant(step)
466+
kwargs["step"] = _constant(step)
462467
if start is not None:
463468
assert isinstance(start, int)
464-
kwargs["lower"] = ast.Constant(start)
469+
kwargs["lower"] = _constant(start)
465470
if stop is not None:
466471
assert isinstance(stop, int)
467-
kwargs["upper"] = ast.Constant(stop)
472+
kwargs["upper"] = _constant(stop)
468473

469474
return ast.Slice(**kwargs)
470475
else:
@@ -500,7 +505,7 @@ def map_einsum(self, expr: Einsum) -> str:
500505
lhs = self.vng("_pt_tmp")
501506
args = [ast.Name(self.rec(arg)) for arg in expr.args]
502507
rhs = ast.Call(ast.Attribute(ast.Name(self.numpy_backend), "einsum"),
503-
args=[ast.Constant(get_einsum_specification(expr)),
508+
args=[_constant(get_einsum_specification(expr)),
504509
*args],
505510
keywords=[],
506511
)
@@ -513,7 +518,7 @@ def map_reshape(self, expr: Reshape) -> str:
513518
raise NotImplementedError("Non-integral reshapes.")
514519
rhs = ast.Call(ast.Attribute(ast.Name(self.numpy_backend), "reshape"),
515520
args=[ast.Name(self.rec(expr.array)),
516-
ast.Tuple(elts=[ast.Constant(d)
521+
ast.Tuple(elts=[_constant(d)
517522
for d in expr.shape])],
518523
keywords=[],
519524
)
@@ -530,7 +535,7 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> str:
530535
keys: list[expr_t | None] = []
531536
values: list[expr_t] = []
532537
for name, subexpr in sorted(expr._data.items()):
533-
keys.append(ast.Constant(name))
538+
keys.append(_constant(name))
534539
values.append(ast.Name(self.rec(subexpr)))
535540

536541
rhs = ast.Dict(keys=keys, values=values)

pytato/transform/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2366,7 +2366,7 @@ def _get_data_dedup_cache_key(self, ary: DataInterface) -> CacheKeyT:
23662366

23672367
if isinstance(ary, CLArray):
23682368
base_data = ary.base_data
2369-
if isinstance(ary.base_data, MemoryObjectHolder):
2369+
if isinstance(base_data, MemoryObjectHolder):
23702370
ptr = base_data.int_ptr
23712371
elif SVMPointer is not None and isinstance(base_data, SVMPointer):
23722372
ptr = base_data.svm_ptr

pytato/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,7 @@ def _is_non_negative(expr: ShapeComponent) -> Bool:
423423
for expr in InputGatherer()(expr)})
424424
aff = ShapeToISLExpressionMapper(space)(expr)
425425

426-
# type-ignore because islpy is not typed yet
427-
return (aff.ge_set(aff * 0) >= _get_size_params_assumptions_bset(space)) # type: ignore[no-any-return]
426+
return aff.ge_set(aff * 0) >= _get_size_params_assumptions_bset(space).to_set()
428427

429428

430429
def _is_non_positive(expr: ShapeComponent) -> Bool:
@@ -485,7 +484,9 @@ def _normalize_slice(slice_: slice,
485484

486485

487486
def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent:
488-
start, stop, step = slice_.start, slice_.stop, slice_.step
487+
start, stop, step = cast(
488+
"tuple[int, int, int]",
489+
(slice_.start, slice_.stop, slice_.step))
489490

490491
if step > 0:
491492
if _is_non_negative(stop - start):

0 commit comments

Comments
 (0)