Skip to content

Commit 8b160ad

Browse files
committed
Sharpen types, including precise islpy
Bump islpy dependency
1 parent b448071 commit 8b160ad

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+50810
-108199
lines changed

.basedpyright/baseline.json

Lines changed: 48427 additions & 107117 deletions
Large diffs are not rendered by default.

loopy/auto_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"""
2525

2626
from dataclasses import dataclass
27-
from typing import TYPE_CHECKING
27+
from typing import TYPE_CHECKING, cast
2828
from warnings import warn
2929

3030
import numpy as np
@@ -37,6 +37,8 @@
3737
if TYPE_CHECKING:
3838
import pyopencl.array as cla
3939

40+
from loopy.types import NumpyType
41+
4042

4143
AUTO_TEST_SKIP_RUN = False
4244

@@ -142,7 +144,7 @@ def make_ref_args(kernel, queue, parameters):
142144
"testing" % arg.name)
143145

144146
shape = evaluate_shape(arg.shape, parameters)
145-
dtype = arg.dtype
147+
dtype = cast("NumpyType", arg.dtype).dtype
146148

147149
is_output = arg.is_output
148150

loopy/check.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525

2626
import logging
2727
from collections import defaultdict
28+
from dataclasses import dataclass
2829
from functools import reduce
29-
from typing import TYPE_CHECKING
30+
from typing import TYPE_CHECKING, cast
3031

3132
import numpy as np
33+
from typing_extensions import override
3234

3335
import islpy as isl
3436
from islpy import dim_type
@@ -76,6 +78,7 @@
7678
if TYPE_CHECKING:
7779
from collections.abc import Mapping, Sequence
7880

81+
import pymbolic.primitives as p
7982
from pymbolic.typing import Expression
8083

8184
from loopy.kernel import LoopKernel
@@ -707,10 +710,12 @@ def subst_func(x):
707710
# }}}
708711

709712

710-
class _AccessCheckMapper(WalkMapper):
711-
def __init__(self, kernel, callables_table):
712-
self.kernel = kernel
713-
self.callables_table = callables_table
713+
@dataclass
714+
class _AccessCheckMapper(WalkMapper[[isl.Set, str]]):
715+
kernel: LoopKernel
716+
callables_table: CallablesTable
717+
718+
def __post_init__(self) -> None:
714719
super().__init__()
715720

716721
@memoize_method
@@ -719,14 +724,20 @@ def _make_slab(self, space, iname, start, stop):
719724
return make_slab(space, iname, start, stop)
720725

721726
@memoize_method
722-
def _get_access_range(self, domain, subscript):
723-
from loopy.symbolic import UnableToDetermineAccessRangeError, get_access_map
727+
def _get_access_range(
728+
self,
729+
domain: isl.Set,
730+
subscript: tuple[Expression, ...]
731+
):
732+
from loopy.diagnostic import UnableToDetermineAccessRangeError
733+
from loopy.symbolic import get_access_map
724734
try:
725735
return get_access_map(domain, subscript).range()
726736
except UnableToDetermineAccessRangeError:
727737
return None
728738

729-
def map_subscript(self, expr, domain, insn_id):
739+
@override
740+
def map_subscript(self, expr: p.Subscript, domain: isl.Set, insn_id: str):
730741
WalkMapper.map_subscript(self, expr, domain, insn_id)
731742

732743
from pymbolic.primitives import Variable
@@ -742,6 +753,7 @@ def map_subscript(self, expr, domain, insn_id):
742753
shape = tv.shape
743754

744755
if shape is not None:
756+
assert isinstance(shape, tuple)
745757
subscript = expr.index
746758

747759
if not isinstance(subscript, tuple):
@@ -787,7 +799,8 @@ def map_subscript(self, expr, domain, insn_id):
787799
" establish '%s' is a subset of '%s')."
788800
% (expr, insn_id, access_range, shape_domain))
789801

790-
def map_if(self, expr, domain, insn_id):
802+
@override
803+
def map_if(self, expr: p. If, domain: isl.Set, insn_id: str):
791804
from loopy.symbolic import condition_to_set
792805
then_set = condition_to_set(domain.space, expr.condition)
793806
if then_set is None:
@@ -800,7 +813,8 @@ def map_if(self, expr, domain, insn_id):
800813
self.rec(expr.then, domain & then_set, insn_id)
801814
self.rec(expr.else_, domain & else_set, insn_id)
802815

803-
def map_call(self, expr, domain, insn_id):
816+
@override
817+
def map_call(self, expr: p.Call, domain: isl.Set, insn_id: str):
804818
# perform access checks on the call arguments
805819
super().map_call(expr, domain, insn_id)
806820

@@ -817,7 +831,9 @@ def map_call(self, expr, domain, insn_id):
817831
and isinstance(self.callables_table[expr.function.name],
818832
CallableKernel)):
819833

820-
subkernel = self.callables_table[expr.function.name].subkernel
834+
subkernel = cast(
835+
"CallableKernel",
836+
self.callables_table[expr.function.name]).subkernel
821837

822838
# The plan here is to add the constraints coming from the values
823839
# args passed at a call-site as assumptions to the callee. To avoid
@@ -835,8 +851,8 @@ def map_call(self, expr, domain, insn_id):
835851

836852
kw_space = isl.Space.create_from_names(
837853
subkernel.isl_context, set=[],
838-
params=(get_dependencies(tuple(kwargs.values()))
839-
| set(kwargs.keys())))
854+
params=[*get_dependencies(tuple(kwargs.values())),
855+
*kwargs.keys()])
840856

841857
extra_assumptions = isl.BasicSet.universe(kw_space).params()
842858

@@ -894,8 +910,8 @@ def _check_bounds_inner(kernel: LoopKernel, callables_table: CallablesTable) ->
894910
domain, assumptions = isl.align_two(domain, kernel.assumptions)
895911
domain_with_assumptions = domain & assumptions
896912

897-
def run_acm(expr):
898-
acm(expr, domain_with_assumptions, insn.id) # noqa: B023
913+
def run_acm(expr: Expression):
914+
acm(expr, domain_with_assumptions, not_none(insn.id)) # noqa: B023
899915
return expr
900916

901917
insn.with_transformed_expressions(run_acm)
@@ -1659,7 +1675,10 @@ def _get_sub_array_ref_swept_range(
16591675
from loopy.symbolic import get_access_map
16601676
domain = kernel.get_inames_domain(frozenset({iname_var.name
16611677
for iname_var in sar.swept_inames}))
1662-
return get_access_map(domain, sar.swept_inames, kernel.assumptions).range()
1678+
return get_access_map(
1679+
domain.to_set(),
1680+
sar.swept_inames,
1681+
kernel.assumptions.to_set()).range()
16631682

16641683

16651684
def _are_sub_array_refs_equivalent(
@@ -1875,7 +1894,7 @@ def pre_codegen_checks(t_unit: TranslationUnit) -> None:
18751894

18761895
def check_implemented_domains(
18771896
kernel: LoopKernel,
1878-
implemented_domains: Mapping[str, isl.Set],
1897+
implemented_domains: Mapping[str, Sequence[isl.Set]],
18791898
code: str | None = None,
18801899
) -> bool:
18811900
from islpy import align_two, dim_type
@@ -1942,7 +1961,7 @@ def check_implemented_domains(
19421961
d_minus_i = desired_domain - insn_impl_domain
19431962

19441963
parameter_inames = {
1945-
insn_domain.get_dim_name(dim_type.param, i)
1964+
not_none(insn_domain.get_dim_name(dim_type.param, i))
19461965
for i in range(insn_impl_domain.dim(dim_type.param))}
19471966

19481967
lines = []

loopy/codegen/__init__.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434

3535
import constantdict
3636

37+
from loopy.types import NumpyType
38+
from loopy.typing import not_none
39+
3740

3841
logger = logging.getLogger(__name__)
3942

@@ -53,14 +56,15 @@
5356

5457

5558
if TYPE_CHECKING:
59+
from pymbolic import Expression
60+
5661
from loopy.codegen.result import CodeGenerationResult, GeneratedProgram
5762
from loopy.codegen.tools import CodegenOperationCacheManager
5863
from loopy.kernel import LoopKernel
5964
from loopy.library.reduction import ReductionOpFunction
6065
from loopy.target import TargetBase
6166
from loopy.translation_unit import CallablesTable, TranslationUnit
6267
from loopy.types import LoopyType
63-
from loopy.typing import Expression
6468

6569

6670
__doc__ = """
@@ -206,7 +210,7 @@ def intersect(self, other):
206210
new_impl, new_other = isl.align_two(self.implemented_domain, other)
207211
return self.copy(implemented_domain=new_impl & new_other)
208212

209-
def fix(self, iname, aff):
213+
def fix(self, iname: str, aff: isl.Aff) -> CodeGenerationState:
210214
new_impl_domain = self.implemented_domain
211215

212216
impl_space = self.implemented_domain.get_space()
@@ -341,8 +345,12 @@ class PreambleInfo:
341345

342346
# {{{ main code generation entrypoint
343347

344-
def generate_code_for_a_single_kernel(kernel, callables_table, target,
345-
is_entrypoint):
348+
def generate_code_for_a_single_kernel(
349+
kernel: LoopKernel,
350+
callables_table: CallablesTable,
351+
target: TargetBase,
352+
is_entrypoint: bool,
353+
) -> CodeGenerationResult:
346354
"""
347355
:returns: a :class:`CodeGenerationResult`
348356
@@ -359,7 +367,8 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
359367
# {{{ examine arg list
360368

361369
allow_complex = False
362-
for var in kernel.args + list(kernel.temporary_variables.values()):
370+
for var in [*kernel.args, *kernel.temporary_variables.values()]:
371+
assert isinstance(var.dtype, NumpyType)
363372
if var.dtype.involves_complex():
364373
allow_complex = True
365374

@@ -376,7 +385,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
376385
codegen_state = CodeGenerationState(
377386
kernel=kernel,
378387
target=target,
379-
implemented_domain=initial_implemented_domain,
388+
implemented_domain=isl.Set.from_basic_set(initial_implemented_domain),
380389
implemented_predicates=frozenset(),
381390
seen_dtypes=seen_dtypes,
382391
seen_functions=seen_functions,
@@ -389,7 +398,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
389398
target.host_program_name_prefix
390399
+ kernel.name
391400
+ kernel.target.host_program_name_suffix),
392-
schedule_index_end=len(kernel.linearization),
401+
schedule_index_end=len(not_none(kernel.linearization)),
393402
callables_table=callables_table,
394403
is_entrypoint=is_entrypoint,
395404
codegen_cache_manager=CodegenOperationCacheManager.from_kernel(kernel),
@@ -418,7 +427,8 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
418427
if kernel.all_inames():
419428
seen_dtypes.add(kernel.index_dtype)
420429

421-
preambles = kernel.preambles + codegen_result.device_preambles
430+
preambles = [
431+
*kernel.preambles, *codegen_result.device_preambles]
422432

423433
preamble_info = PreambleInfo(
424434
kernel=kernel,
@@ -429,10 +439,10 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
429439
codegen_state=codegen_state
430440
)
431441

432-
preamble_generators = (list(kernel.preamble_generators)
433-
+ list(target.get_device_ast_builder().preamble_generators()))
434-
for prea_gen in preamble_generators:
435-
preambles = preambles + tuple(prea_gen(preamble_info))
442+
for prea_gen in [
443+
*kernel.preamble_generators,
444+
*target.get_device_ast_builder().preamble_generators()]:
445+
preambles.extend(prea_gen(preamble_info))
436446

437447
codegen_result = codegen_result.copy(device_preambles=preambles)
438448

loopy/codegen/control.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
"""
2626

2727
from functools import partial
28+
from typing import TYPE_CHECKING
2829

2930
import islpy as isl
3031

31-
from loopy.codegen.result import merge_codegen_results, wrap_in_if
32+
from loopy.codegen.result import CodeGenerationResult, merge_codegen_results, wrap_in_if
3233
from loopy.diagnostic import LoopyError
3334
from loopy.schedule import (
3435
Barrier,
@@ -41,8 +42,17 @@
4142
)
4243

4344

44-
def generate_code_for_sched_index(codegen_state, sched_index):
45+
if TYPE_CHECKING:
46+
from loopy.codegen import CodeGenerationState
47+
48+
49+
def generate_code_for_sched_index(
50+
codegen_state: CodeGenerationState,
51+
sched_index: int
52+
) -> CodeGenerationResult:
4553
kernel = codegen_state.kernel
54+
assert kernel.linearization is not None
55+
4656
sched_item = kernel.linearization[sched_index]
4757

4858
if isinstance(sched_item, CallKernel):

0 commit comments

Comments
 (0)