Skip to content

Commit 1588911

Browse files
committed
Sharpen types, including precise islpy
Bump islpy dependency
1 parent c59bda6 commit 1588911

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

63 files changed

+50533
-108332
lines changed

.basedpyright/baseline.json

Lines changed: 48023 additions & 107171 deletions
Large diffs are not rendered by default.

doc/ref_internals.rst

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,39 @@ Schedule
6161
.. automodule:: loopy.schedule.tools
6262
.. automodule:: loopy.schedule.tree
6363

64+
References
65+
----------
66+
67+
Mostly things that Sphinx (our documentation tool) should resolve but won't.
68+
69+
.. class:: constantdict
70+
71+
See :class:`constantdict.constantdict`.
72+
73+
.. class:: DTypeLike
74+
75+
See :data:`numpy.typing.DTypeLike`.
76+
77+
.. currentmodule:: p
78+
79+
.. class:: Call
80+
81+
See :class:`pymbolic.primitives.Call`.
82+
83+
.. class:: CallWithKwargs
84+
85+
See :class:`pymbolic.primitives.CallWithKwargs`.
86+
87+
.. currentmodule:: isl
88+
89+
.. class:: Space
90+
91+
See :class:`islpy.Space`.
92+
93+
.. class:: Aff
94+
95+
See :class:`islpy.Aff`.
96+
97+
.. class:: PwAff
6498

99+
See :class:`islpy.PwAff`.

doc/ref_kernel.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ Instructions
270270
Assignment objects
271271
^^^^^^^^^^^^^^^^^^
272272

273+
.. currentmodule:: loopy.kernel.instruction
274+
275+
.. class:: Assignable
276+
277+
.. currentmodule:: loopy
278+
273279
.. autoclass:: Assignment
274280

275281
.. _assignment-syntax:

loopy/auto_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"""
2525

2626
from dataclasses import dataclass
27-
from typing import TYPE_CHECKING
27+
from typing import TYPE_CHECKING, cast
2828
from warnings import warn
2929

3030
import numpy as np
@@ -37,6 +37,8 @@
3737
if TYPE_CHECKING:
3838
import pyopencl.array as cla
3939

40+
from loopy.types import NumpyType
41+
4042

4143
AUTO_TEST_SKIP_RUN = False
4244

@@ -142,7 +144,7 @@ def make_ref_args(kernel, queue, parameters):
142144
"testing" % arg.name)
143145

144146
shape = evaluate_shape(arg.shape, parameters)
145-
dtype = arg.dtype
147+
dtype = cast("NumpyType", arg.dtype).dtype
146148

147149
is_output = arg.is_output
148150

loopy/check.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525

2626
import logging
2727
from collections import defaultdict
28+
from dataclasses import dataclass
2829
from functools import reduce
29-
from typing import TYPE_CHECKING
30+
from typing import TYPE_CHECKING, cast
3031

3132
import numpy as np
33+
from typing_extensions import override
3234

3335
import islpy as isl
3436
from islpy import dim_type
@@ -76,6 +78,7 @@
7678
if TYPE_CHECKING:
7779
from collections.abc import Mapping, Sequence
7880

81+
import pymbolic.primitives as p
7982
from pymbolic.typing import Expression
8083

8184
from loopy.kernel import LoopKernel
@@ -707,10 +710,12 @@ def subst_func(x):
707710
# }}}
708711

709712

710-
class _AccessCheckMapper(WalkMapper):
711-
def __init__(self, kernel, callables_table):
712-
self.kernel = kernel
713-
self.callables_table = callables_table
713+
@dataclass
714+
class _AccessCheckMapper(WalkMapper[[isl.Set, str]]):
715+
kernel: LoopKernel
716+
callables_table: CallablesTable
717+
718+
def __post_init__(self) -> None:
714719
super().__init__()
715720

716721
@memoize_method
@@ -719,14 +724,20 @@ def _make_slab(self, space, iname, start, stop):
719724
return make_slab(space, iname, start, stop)
720725

721726
@memoize_method
722-
def _get_access_range(self, domain, subscript):
723-
from loopy.symbolic import UnableToDetermineAccessRangeError, get_access_map
727+
def _get_access_range(
728+
self,
729+
domain: isl.Set,
730+
subscript: tuple[Expression, ...]
731+
):
732+
from loopy.diagnostic import UnableToDetermineAccessRangeError
733+
from loopy.symbolic import get_access_map
724734
try:
725735
return get_access_map(domain, subscript).range()
726736
except UnableToDetermineAccessRangeError:
727737
return None
728738

729-
def map_subscript(self, expr, domain, insn_id):
739+
@override
740+
def map_subscript(self, expr: p.Subscript, domain: isl.Set, insn_id: str):
730741
WalkMapper.map_subscript(self, expr, domain, insn_id)
731742

732743
from pymbolic.primitives import Variable
@@ -742,6 +753,7 @@ def map_subscript(self, expr, domain, insn_id):
742753
shape = tv.shape
743754

744755
if shape is not None:
756+
assert isinstance(shape, tuple)
745757
subscript = expr.index
746758

747759
if not isinstance(subscript, tuple):
@@ -787,7 +799,8 @@ def map_subscript(self, expr, domain, insn_id):
787799
" establish '%s' is a subset of '%s')."
788800
% (expr, insn_id, access_range, shape_domain))
789801

790-
def map_if(self, expr, domain, insn_id):
802+
@override
803+
def map_if(self, expr: p. If, domain: isl.Set, insn_id: str):
791804
from loopy.symbolic import condition_to_set
792805
then_set = condition_to_set(domain.space, expr.condition)
793806
if then_set is None:
@@ -800,7 +813,8 @@ def map_if(self, expr, domain, insn_id):
800813
self.rec(expr.then, domain & then_set, insn_id)
801814
self.rec(expr.else_, domain & else_set, insn_id)
802815

803-
def map_call(self, expr, domain, insn_id):
816+
@override
817+
def map_call(self, expr: p.Call, domain: isl.Set, insn_id: str):
804818
# perform access checks on the call arguments
805819
super().map_call(expr, domain, insn_id)
806820

@@ -817,7 +831,9 @@ def map_call(self, expr, domain, insn_id):
817831
and isinstance(self.callables_table[expr.function.name],
818832
CallableKernel)):
819833

820-
subkernel = self.callables_table[expr.function.name].subkernel
834+
subkernel = cast(
835+
"CallableKernel",
836+
self.callables_table[expr.function.name]).subkernel
821837

822838
# The plan here is to add the constraints coming from the values
823839
# args passed at a call-site as assumptions to the callee. To avoid
@@ -835,8 +851,8 @@ def map_call(self, expr, domain, insn_id):
835851

836852
kw_space = isl.Space.create_from_names(
837853
subkernel.isl_context, set=[],
838-
params=(get_dependencies(tuple(kwargs.values()))
839-
| set(kwargs.keys())))
854+
params=[*get_dependencies(tuple(kwargs.values())),
855+
*kwargs.keys()])
840856

841857
extra_assumptions = isl.BasicSet.universe(kw_space).params()
842858

@@ -894,8 +910,8 @@ def _check_bounds_inner(kernel: LoopKernel, callables_table: CallablesTable) ->
894910
domain, assumptions = isl.align_two(domain, kernel.assumptions)
895911
domain_with_assumptions = domain & assumptions
896912

897-
def run_acm(expr):
898-
acm(expr, domain_with_assumptions, insn.id) # noqa: B023
913+
def run_acm(expr: Expression):
914+
acm(expr, domain_with_assumptions, not_none(insn.id)) # noqa: B023
899915
return expr
900916

901917
insn.with_transformed_expressions(run_acm)
@@ -1659,7 +1675,10 @@ def _get_sub_array_ref_swept_range(
16591675
from loopy.symbolic import get_access_map
16601676
domain = kernel.get_inames_domain(frozenset({iname_var.name
16611677
for iname_var in sar.swept_inames}))
1662-
return get_access_map(domain, sar.swept_inames, kernel.assumptions).range()
1678+
return get_access_map(
1679+
domain.to_set(),
1680+
sar.swept_inames,
1681+
kernel.assumptions.to_set()).range()
16631682

16641683

16651684
def _are_sub_array_refs_equivalent(
@@ -1875,7 +1894,7 @@ def pre_codegen_checks(t_unit: TranslationUnit) -> None:
18751894

18761895
def check_implemented_domains(
18771896
kernel: LoopKernel,
1878-
implemented_domains: Mapping[str, isl.Set],
1897+
implemented_domains: Mapping[str, Sequence[isl.Set]],
18791898
code: str | None = None,
18801899
) -> bool:
18811900
from islpy import align_two, dim_type
@@ -1942,7 +1961,7 @@ def check_implemented_domains(
19421961
d_minus_i = desired_domain - insn_impl_domain
19431962

19441963
parameter_inames = {
1945-
insn_domain.get_dim_name(dim_type.param, i)
1964+
not_none(insn_domain.get_dim_name(dim_type.param, i))
19461965
for i in range(insn_impl_domain.dim(dim_type.param))}
19471966

19481967
lines = []

loopy/codegen/__init__.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,15 @@
2828
from typing import (
2929
TYPE_CHECKING,
3030
Any,
31-
Mapping,
32-
Sequence,
3331
)
3432

3533
import constantdict
3634

35+
from loopy.typing import not_none
36+
3737

3838
logger = logging.getLogger(__name__)
3939

40-
from functools import reduce
4140

4241
import islpy # to help out Sphinx
4342
import islpy as isl
@@ -47,20 +46,22 @@
4746

4847
from loopy.diagnostic import LoopyError, warn
4948
from loopy.kernel.function_interface import CallableKernel
50-
from loopy.symbolic import CombineMapper
5149
from loopy.tools import LoopyKeyBuilder, caches
5250
from loopy.version import DATA_MODEL_VERSION
5351

5452

5553
if TYPE_CHECKING:
54+
from collections.abc import Mapping, Sequence
55+
56+
from pymbolic import Expression
57+
5658
from loopy.codegen.result import CodeGenerationResult, GeneratedProgram
5759
from loopy.codegen.tools import CodegenOperationCacheManager
5860
from loopy.kernel import LoopKernel
5961
from loopy.library.reduction import ReductionOpFunction
6062
from loopy.target import TargetBase
6163
from loopy.translation_unit import CallablesTable, TranslationUnit
6264
from loopy.types import LoopyType
63-
from loopy.typing import Expression
6465

6566

6667
__doc__ = """
@@ -206,7 +207,7 @@ def intersect(self, other):
206207
new_impl, new_other = isl.align_two(self.implemented_domain, other)
207208
return self.copy(implemented_domain=new_impl & new_other)
208209

209-
def fix(self, iname, aff):
210+
def fix(self, iname: str, aff: isl.Aff) -> CodeGenerationState:
210211
new_impl_domain = self.implemented_domain
211212

212213
impl_space = self.implemented_domain.get_space()
@@ -296,32 +297,6 @@ def ast_builder(self):
296297
caches.append(code_gen_cache)
297298

298299

299-
class InKernelCallablesCollector(CombineMapper):
300-
"""
301-
Returns an instance of :class:`frozenset` containing instances of
302-
:class:`loopy.kernel.function_interface.InKernelCallable` in the
303-
:attr:``kernel`.
304-
"""
305-
def __init__(self, kernel):
306-
self.kernel = kernel
307-
308-
def combine(self, values):
309-
import operator
310-
return reduce(operator.or_, values, frozenset())
311-
312-
def map_resolved_function(self, expr):
313-
return frozenset([self.kernel.scoped_functions[
314-
expr.name]])
315-
316-
def map_constant(self, expr):
317-
return frozenset()
318-
319-
map_variable = map_constant
320-
map_function_symbol = map_constant
321-
map_tagged_variable = map_constant
322-
map_type_cast = map_constant
323-
324-
325300
@dataclass(frozen=True)
326301
class PreambleInfo:
327302
"""
@@ -341,8 +316,12 @@ class PreambleInfo:
341316

342317
# {{{ main code generation entrypoint
343318

344-
def generate_code_for_a_single_kernel(kernel, callables_table, target,
345-
is_entrypoint):
319+
def generate_code_for_a_single_kernel(
320+
kernel: LoopKernel,
321+
callables_table: CallablesTable,
322+
target: TargetBase,
323+
is_entrypoint: bool,
324+
) -> CodeGenerationResult:
346325
"""
347326
:returns: a :class:`CodeGenerationResult`
348327
@@ -359,8 +338,8 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
359338
# {{{ examine arg list
360339

361340
allow_complex = False
362-
for var in kernel.args + list(kernel.temporary_variables.values()):
363-
if var.dtype.involves_complex():
341+
for var in [*kernel.args, *kernel.temporary_variables.values()]:
342+
if not_none(var.dtype).involves_complex():
364343
allow_complex = True
365344

366345
# }}}
@@ -376,7 +355,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
376355
codegen_state = CodeGenerationState(
377356
kernel=kernel,
378357
target=target,
379-
implemented_domain=initial_implemented_domain,
358+
implemented_domain=isl.Set.from_basic_set(initial_implemented_domain),
380359
implemented_predicates=frozenset(),
381360
seen_dtypes=seen_dtypes,
382361
seen_functions=seen_functions,
@@ -389,7 +368,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
389368
target.host_program_name_prefix
390369
+ kernel.name
391370
+ kernel.target.host_program_name_suffix),
392-
schedule_index_end=len(kernel.linearization),
371+
schedule_index_end=len(not_none(kernel.linearization)),
393372
callables_table=callables_table,
394373
is_entrypoint=is_entrypoint,
395374
codegen_cache_manager=CodegenOperationCacheManager.from_kernel(kernel),
@@ -418,7 +397,8 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
418397
if kernel.all_inames():
419398
seen_dtypes.add(kernel.index_dtype)
420399

421-
preambles = kernel.preambles + codegen_result.device_preambles
400+
preambles = [
401+
*kernel.preambles, *codegen_result.device_preambles]
422402

423403
preamble_info = PreambleInfo(
424404
kernel=kernel,
@@ -429,10 +409,10 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
429409
codegen_state=codegen_state
430410
)
431411

432-
preamble_generators = (list(kernel.preamble_generators)
433-
+ list(target.get_device_ast_builder().preamble_generators()))
434-
for prea_gen in preamble_generators:
435-
preambles = preambles + tuple(prea_gen(preamble_info))
412+
for prea_gen in [
413+
*kernel.preamble_generators,
414+
*target.get_device_ast_builder().preamble_generators()]:
415+
preambles.extend(prea_gen(preamble_info))
436416

437417
codegen_result = codegen_result.copy(device_preambles=preambles)
438418

0 commit comments

Comments
 (0)