Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/ref_internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ the codegen pipeline user-provided types are converted to

.. automodule:: loopy.types

Type inference
^^^^^^^^^^^^^^

.. automodule:: loopy.type_inference

Codegen
-------

Expand Down
2 changes: 2 additions & 0 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
AddressSpace,
ArrayArg,
ArrayDimImplementationTag,
AxisTag,
InameImplementationTag,
TemporaryVariable,
auto,
Expand Down Expand Up @@ -1426,6 +1427,7 @@ def _check_for_unused_hw_axes_in_kernel_chunk(
iname, AutoLocalInameTagBase, max_num=1)

if ltags:
tag: AxisTag
tag, = ltags
local_axes_used.add(tag.axis)
elif gtags:
Expand Down
13 changes: 8 additions & 5 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sys import intern
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
ClassVar,
Expand All @@ -59,7 +60,7 @@
memoize_method,
natsorted,
)
from pytools.tag import Tag, Taggable
from pytools.tag import Tag, Taggable, TagT

import loopy.codegen
import loopy.kernel.data # to help out Sphinx
Expand Down Expand Up @@ -539,14 +540,16 @@ def _get_inames_domain_backend(self, inames):
def iname_tags(self, iname):
return self.inames[iname].tags

def iname_tags_of_type(self, iname, tag_type_or_types,
max_num=None, min_num=None):
def iname_tags_of_type(
self, iname: str,
tag_type_or_types: type[TagT] | tuple[type[TagT], ...],
max_num: int | None = None,
min_num: int | None = None
) -> AbstractSet[TagT]:
"""Return a subset of *tags* that matches type *tag_type*. Raises exception
if the number of tags found were greater than *max_num* or less than
*min_num*.

:arg tags: An iterable of tags.
:arg tag_type_or_types: a subclass of :class:`loopy.kernel.data.InameTag`.
:arg max_num: the maximum number of tags expected to be found.
:arg min_num: the minimum number of tags expected to be found.
"""
Expand Down
9 changes: 5 additions & 4 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,10 +1217,11 @@ def _apply_offset(sub: Expression, ary: ArrayBase) -> Expression:


def get_access_info(kernel: LoopKernel,
ary: ArrayArg | TemporaryVariable,
index: Expression | tuple[Expression, ...],
eval_expr: Callable[[Expression], int],
vectorization_info: VectorizationInfo) -> AccessInfo:
ary: ArrayArg | TemporaryVariable,
index: Expression | tuple[Expression, ...],
eval_expr: Callable[[Expression], int],
vectorization_info: VectorizationInfo | None
) -> AccessInfo:
"""
:arg ary: an object of type :class:`ArrayBase`
:arg index: a tuple of indices representing a subscript into ary
Expand Down
20 changes: 16 additions & 4 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import numpy as np

from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase
from pytools.tag import Tag, Taggable, TagT, UniqueTag as UniqueTagBase

from loopy.diagnostic import LoopyError
from loopy.kernel.array import ArrayBase, ArrayDimImplementationTag
Expand All @@ -64,7 +64,7 @@


if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Iterable, Mapping

from pymbolic import ArithmeticExpression, Variable

Expand Down Expand Up @@ -98,6 +98,10 @@
.. class:: ToLoopyTypeConvertible

See :class:`loopy.ToLoopyTypeConvertible`.

.. class:: TagT

A type variable with a lower bound of :class:`pytools.tag.Tag`.
"""

# This docstring is included in ref_internals. Do not include parts of the public
Expand Down Expand Up @@ -143,7 +147,12 @@ def _names_from_dim_tags(

# {{{ iname tags

def filter_iname_tags_by_type(tags, tag_type, max_num=None, min_num=None):
def filter_iname_tags_by_type(
tags: Iterable[Tag],
tag_type: type[TagT] | tuple[type[TagT], ...],
max_num: int | None = None,
min_num: int | None = None,
) -> set[TagT]:
"""Return a subset of *tags* that matches type *tag_type*. Raises exception
if the number of tags found were greater than *max_num* or less than
*min_num*.
Expand All @@ -154,7 +163,9 @@ def filter_iname_tags_by_type(tags, tag_type, max_num=None, min_num=None):
:arg min_num: the minimum number of tags expected to be found.
"""

result = {tag for tag in tags if isinstance(tag, tag_type)}
result: set[TagT] = cast(
"set[TagT]",
{tag for tag in tags if isinstance(tag, tag_type)})

def strify_tag_type():
if isinstance(tag_type, tuple):
Expand All @@ -170,6 +181,7 @@ def strify_tag_type():
if len(result) < min_num:
raise LoopyError("must have more than {} tags "
"of type(s): {}".format(max_num, strify_tag_type()))

return result


Expand Down
9 changes: 7 additions & 2 deletions loopy/target/c/codegen/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@
from loopy.diagnostic import LoopyError
from loopy.expression import dtype_to_type_context
from loopy.target.c import CExpression
from loopy.type_inference import TypeReader
from loopy.type_inference import TypeInferenceMapper, TypeReader
from loopy.types import LoopyType
from loopy.typing import Expression, is_integer


if TYPE_CHECKING:
from loopy.codegen import CodeGenerationState
from loopy.symbolic import TypeCast


Expand All @@ -79,7 +80,11 @@ class ExpressionToCExpressionMapper(IdentityMapper):
expected type for untyped expressions such as python scalars. The
type of the expressions takes precedence over *type_context*.
"""
def __init__(self, codegen_state, fortran_abi=False, type_inf_mapper=None):
def __init__(self,
codegen_state: CodeGenerationState,
fortran_abi: bool = False,
type_inf_mapper: TypeInferenceMapper | None = None
) -> None:
self.kernel = codegen_state.kernel
self.codegen_state = codegen_state

Expand Down
Loading
Loading