Skip to content

Commit bf7c1b9

Browse files
committed
Update TVM-FFI to v0.1.7-rc0
1 parent ec0fed0 commit bf7c1b9

File tree

158 files changed

+5885
-101
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

158 files changed

+5885
-101
lines changed

3rdparty/tvm-ffi

Submodule tvm-ffi updated 147 files

python/tvm/arith/_ffi_api.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,71 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""FFI APIs for tvm.arith"""
18-
import tvm_ffi
18+
# tvm-ffi-stubgen(begin): import-section
19+
# fmt: off
20+
# isort: off
21+
from __future__ import annotations
22+
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
23+
from typing import TYPE_CHECKING
24+
if TYPE_CHECKING:
25+
from arith import ConstIntBound, IntConstraints, IntConstraintsTransform, IntGroupBounds, IntervalSet, IterMark, IterSplitExpr, IterSumExpr, ModularSet
26+
from collections.abc import Mapping, Sequence
27+
from ir import IntImm, IntSet, PrimExpr, Range
28+
from tir import Buffer, PrimFunc, Stmt, Var
29+
from tvm_ffi import Object
30+
from typing import Any
31+
# isort: on
32+
# fmt: on
33+
# tvm-ffi-stubgen(end)
1934

2035

21-
tvm_ffi.init_ffi_api("arith", __name__)
36+
37+
# tvm-ffi-stubgen(begin): global/arith
38+
# fmt: off
39+
_FFI_INIT_FUNC("arith", __name__)
40+
if TYPE_CHECKING:
41+
def ConstIntBound(_0: int, _1: int, /) -> ConstIntBound: ...
42+
def CreateAnalyzer(*args: Any) -> Any: ...
43+
def DeduceBound(_0: PrimExpr, _1: PrimExpr, _2: Mapping[Var, IntSet], _3: Mapping[Var, IntSet], /) -> IntSet: ...
44+
def DetectClipBound(_0: PrimExpr, _1: Sequence[Var], /) -> Sequence[PrimExpr]: ...
45+
def DetectCommonSubExpr(_0: PrimExpr, _1: int, /) -> Mapping[PrimExpr, IntImm]: ...
46+
def DetectIterMap(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: PrimExpr, _3: int, _4: bool, /) -> Object: ...
47+
def DetectLinearEquation(_0: PrimExpr, _1: Sequence[Var], /) -> Sequence[PrimExpr]: ...
48+
def DomainTouched(_0: Stmt, _1: Buffer, _2: bool, _3: bool, /) -> Sequence[Range]: ...
49+
def DomainTouchedAccessMap(_0: PrimFunc, /) -> Mapping[Buffer, Sequence[Object]]: ...
50+
def EstimateRegionLowerBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ...
51+
def EstimateRegionStrictBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ...
52+
def EstimateRegionUpperBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ...
53+
def IntConstraints(_0: Sequence[Var], _1: Mapping[Var, Range], _2: Sequence[PrimExpr], /) -> IntConstraints: ...
54+
def IntConstraintsTransform(_0: IntConstraints, _1: IntConstraints, _2: Mapping[Var, PrimExpr], _3: Mapping[Var, PrimExpr], /) -> IntConstraintsTransform: ...
55+
def IntGroupBounds(_0: PrimExpr, _1: Sequence[PrimExpr], _2: Sequence[PrimExpr], _3: Sequence[PrimExpr], /) -> IntGroupBounds: ...
56+
def IntGroupBounds_FindBestRange(*args: Any) -> Any: ...
57+
def IntGroupBounds_from_range(_0: Range, /) -> IntGroupBounds: ...
58+
def IntSetIsEverything(_0: IntSet, /) -> bool: ...
59+
def IntSetIsNothing(_0: IntSet, /) -> bool: ...
60+
def IntervalSet(_0: PrimExpr, _1: PrimExpr, /) -> IntervalSet: ...
61+
def IntervalSetGetMax(_0: IntSet, /) -> PrimExpr: ...
62+
def IntervalSetGetMin(_0: IntSet, /) -> PrimExpr: ...
63+
def InverseAffineIterMap(_0: Sequence[IterSumExpr], _1: Sequence[PrimExpr], /) -> Mapping[Var, PrimExpr]: ...
64+
def IterMapSimplify(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: PrimExpr, _3: int, _4: bool, /) -> Sequence[PrimExpr]: ...
65+
def IterMark(_0: PrimExpr, _1: PrimExpr, /) -> IterMark: ...
66+
def IterSplitExpr(_0: IterMark, _1: PrimExpr, _2: PrimExpr, _3: PrimExpr, /) -> IterSplitExpr: ...
67+
def IterSumExpr(_0: Sequence[IterSplitExpr], _1: PrimExpr, /) -> IterSumExpr: ...
68+
def ModularSet(_0: int, _1: int, /) -> ModularSet: ...
69+
def NarrowPredicateExpression(_0: PrimExpr, _1: Mapping[Var, Range], /) -> PrimExpr: ...
70+
def NegInf() -> PrimExpr: ...
71+
def NormalizeIterMapToExpr(_0: PrimExpr, /) -> PrimExpr: ...
72+
def NormalizeToIterSum(_0: PrimExpr, _1: Mapping[Var, Range], /) -> IterSumExpr: ...
73+
def PosInf() -> PrimExpr: ...
74+
def PresburgerSet(_0: PrimExpr, /) -> IntSet: ...
75+
def SolveInequalitiesAsCondition(*args: Any) -> Any: ...
76+
def SolveInequalitiesDeskewRange(*args: Any) -> Any: ...
77+
def SolveInequalitiesToRange(*args: Any) -> Any: ...
78+
def SolveLinearEquations(*args: Any) -> Any: ...
79+
def SubspaceDivide(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: Sequence[Var], _3: PrimExpr, _4: int, _5: bool, /) -> Sequence[Sequence[IterMark]]: ...
80+
def UnionLowerBound(_0: Sequence[IntSet], /) -> IntSet: ...
81+
def intset_interval(_0: PrimExpr, _1: PrimExpr, /) -> IntSet: ...
82+
def intset_single_point(_0: PrimExpr, /) -> IntSet: ...
83+
def intset_vector(_0: PrimExpr, /) -> IntSet: ...
84+
# fmt: on
85+
# tvm-ffi-stubgen(end)

python/tvm/arith/analyzer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""Arithmetic data structure and utility"""
19+
# tvm-ffi-stubgen(begin): import-section
20+
# tvm-ffi-stubgen(end)
1921
import enum
2022
from typing import Union
2123

@@ -51,6 +53,13 @@ class Extension(enum.Flag):
5153
class ModularSet(Object):
5254
"""Represent range of (coeff * x + base) for x in Z"""
5355

56+
# tvm-ffi-stubgen(begin): object/arith.ModularSet
57+
# fmt: off
58+
coeff: int
59+
base: int
60+
# fmt: on
61+
# tvm-ffi-stubgen(end)
62+
5463
def __init__(self, coeff, base):
5564
self.__init_handle_by_constructor__(_ffi_api.ModularSet, coeff, base)
5665

@@ -68,6 +77,13 @@ class ConstIntBound(Object):
6877
The maximum value of the bound.
6978
"""
7079

80+
# tvm-ffi-stubgen(begin): object/arith.ConstIntBound
81+
# fmt: off
82+
min_value: int
83+
max_value: int
84+
# fmt: on
85+
# tvm-ffi-stubgen(end)
86+
7187
POS_INF = (1 << 63) - 1
7288
NEG_INF = -POS_INF
7389

python/tvm/arith/int_set.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Integer set."""
18+
# tvm-ffi-stubgen(begin): import-section
19+
# fmt: off
20+
# isort: off
21+
from __future__ import annotations
22+
from typing import TYPE_CHECKING
23+
if TYPE_CHECKING:
24+
from ir import PrimExpr
25+
# isort: on
26+
# fmt: on
27+
# tvm-ffi-stubgen(end)
1828
import tvm_ffi
1929
from tvm.runtime import Object
2030
from . import _ffi_api
@@ -24,6 +34,11 @@
2434
class IntSet(Object):
2535
"""Represent a set of integer in one dimension."""
2636

37+
# tvm-ffi-stubgen(begin): object/ir.IntSet
38+
# fmt: off
39+
# fmt: on
40+
# tvm-ffi-stubgen(end)
41+
2742
def is_nothing(self):
2843
"""Whether the set represent nothing"""
2944
return _ffi_api.IntSetIsNothing(self)
@@ -78,6 +93,13 @@ class IntervalSet(IntSet):
7893
The maximum value in the interval.
7994
"""
8095

96+
# tvm-ffi-stubgen(begin): object/arith.IntervalSet
97+
# fmt: off
98+
min_value: PrimExpr
99+
max_value: PrimExpr
100+
# fmt: on
101+
# tvm-ffi-stubgen(end)
102+
81103
def __init__(self, min_value, max_value):
82104
self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value)
83105

@@ -86,6 +108,11 @@ def __init__(self, min_value, max_value):
86108
class PresburgerSet(IntSet):
87109
"""Represent of Presburger Set"""
88110

111+
# tvm-ffi-stubgen(begin): object/arith.PresburgerSet
112+
# fmt: off
113+
# fmt: on
114+
# tvm-ffi-stubgen(end)
115+
89116
def __init__(self):
90117
self.__init_handle_by_constructor__(_ffi_api.PresburgerSet)
91118

python/tvm/arith/int_solver.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""integer constraints data structures and solvers"""
18+
# tvm-ffi-stubgen(begin): import-section
19+
# fmt: off
20+
# isort: off
21+
from __future__ import annotations
22+
from typing import TYPE_CHECKING
23+
if TYPE_CHECKING:
24+
from collections.abc import Mapping, Sequence
25+
from ir import PrimExpr, Range
26+
from tir import Var
27+
# isort: on
28+
# fmt: on
29+
# tvm-ffi-stubgen(end)
1830
import tvm_ffi
1931
from tvm.runtime import Object
2032
from . import _ffi_api
@@ -40,6 +52,15 @@ class IntGroupBounds(Object):
4052
the upper bounds (include)
4153
"""
4254

55+
# tvm-ffi-stubgen(begin): object/arith.IntGroupBounds
56+
# fmt: off
57+
coef: PrimExpr
58+
lower: Sequence[PrimExpr]
59+
equal: Sequence[PrimExpr]
60+
upper: Sequence[PrimExpr]
61+
# fmt: on
62+
# tvm-ffi-stubgen(end)
63+
4364
def __init__(self, coef, lower, equal, upper):
4465
self.__init_handle_by_constructor__(_ffi_api.IntGroupBounds, coef, lower, equal, upper)
4566

@@ -81,6 +102,14 @@ class IntConstraints(Object):
81102
The relations between the variables (either equations or inequalities)
82103
"""
83104

105+
# tvm-ffi-stubgen(begin): object/arith.IntConstraints
106+
# fmt: off
107+
variables: Sequence[Var]
108+
ranges: Mapping[Var, Range]
109+
relations: Sequence[PrimExpr]
110+
# fmt: on
111+
# tvm-ffi-stubgen(end)
112+
84113
def __init__(self, variables, ranges, relations):
85114
self.__init_handle_by_constructor__(_ffi_api.IntConstraints, variables, ranges, relations)
86115

@@ -113,6 +142,15 @@ class IntConstraintsTransform(Object):
113142
e.g., {m -> a, n -> -b}
114143
"""
115144

145+
# tvm-ffi-stubgen(begin): object/arith.IntConstraintsTransform
146+
# fmt: off
147+
src: IntConstraints
148+
dst: IntConstraints
149+
src_to_dst: Mapping[Var, PrimExpr]
150+
dst_to_src: Mapping[Var, PrimExpr]
151+
# fmt: on
152+
# tvm-ffi-stubgen(end)
153+
116154
def __init__(self, src, dst, src_to_dst, dst_to_src):
117155
self.__init_handle_by_constructor__(
118156
_ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src

python/tvm/arith/iter_affine_map.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Iterator (quasi)affine mapping patterns."""
18+
# tvm-ffi-stubgen(begin): import-section
19+
# fmt: off
20+
# isort: off
21+
from __future__ import annotations
22+
from typing import TYPE_CHECKING
23+
if TYPE_CHECKING:
24+
from collections.abc import Sequence
25+
from ir import PrimExpr
26+
# isort: on
27+
# fmt: on
28+
# tvm-ffi-stubgen(end)
1829
from enum import IntEnum
1930
import tvm_ffi
2031
from tvm.runtime import Object
@@ -26,6 +37,11 @@
2637
class IterMapExpr(PrimExpr):
2738
"""Base class of all IterMap expressions."""
2839

40+
# tvm-ffi-stubgen(begin): object/arith.IterMapExpr
41+
# fmt: off
42+
# fmt: on
43+
# tvm-ffi-stubgen(end)
44+
2945

3046
@tvm_ffi.register_object("arith.IterMark")
3147
class IterMark(Object):
@@ -40,6 +56,13 @@ class IterMark(Object):
4056
The extent of the iterator.
4157
"""
4258

59+
# tvm-ffi-stubgen(begin): object/arith.IterMark
60+
# fmt: off
61+
source: PrimExpr
62+
extent: PrimExpr
63+
# fmt: on
64+
# tvm-ffi-stubgen(end)
65+
4366
def __init__(self, source, extent):
4467
self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent)
4568

@@ -65,6 +88,15 @@ class IterSplitExpr(IterMapExpr):
6588
Additional scale to the split.
6689
"""
6790

91+
# tvm-ffi-stubgen(begin): object/arith.IterSplitExpr
92+
# fmt: off
93+
source: IterMark
94+
lower_factor: PrimExpr
95+
extent: PrimExpr
96+
scale: PrimExpr
97+
# fmt: on
98+
# tvm-ffi-stubgen(end)
99+
68100
def __init__(self, source, lower_factor, extent, scale):
69101
self.__init_handle_by_constructor__(
70102
_ffi_api.IterSplitExpr, source, lower_factor, extent, scale
@@ -86,6 +118,13 @@ class IterSumExpr(IterMapExpr):
86118
The base offset.
87119
"""
88120

121+
# tvm-ffi-stubgen(begin): object/arith.IterSumExpr
122+
# fmt: off
123+
args: Sequence[IterSplitExpr]
124+
base: PrimExpr
125+
# fmt: on
126+
# tvm-ffi-stubgen(end)
127+
89128
def __init__(self, args, base):
90129
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)
91130

@@ -94,6 +133,14 @@ def __init__(self, args, base):
94133
class IterMapResult(Object):
95134
"""Result of iter map detection."""
96135

136+
# tvm-ffi-stubgen(begin): object/arith.IterMapResult
137+
# fmt: off
138+
indices: Sequence[IterSumExpr]
139+
errors: Sequence[str]
140+
padding_predicate: PrimExpr
141+
# fmt: on
142+
# tvm-ffi-stubgen(end)
143+
97144

98145
class IterMapLevel(IntEnum):
99146
"""Possible kinds of iter mapping check level."""

python/tvm/contrib/cutlass/_ffi_api.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""FFI API for CUTLASS BYOC."""
18-
import tvm_ffi
18+
# tvm-ffi-stubgen(begin): import-section
19+
# fmt: off
20+
# isort: off
21+
from __future__ import annotations
22+
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
23+
from typing import TYPE_CHECKING
24+
# isort: on
25+
# fmt: on
26+
# tvm-ffi-stubgen(end)
1927

20-
tvm_ffi.init_ffi_api("contrib.cutlass", __name__)
28+
29+
# tvm-ffi-stubgen(begin): global/contrib.cutlass
30+
_FFI_INIT_FUNC("contrib.cutlass", __name__)
31+
# tvm-ffi-stubgen(end)

python/tvm/contrib/msc/core/_ffi_api.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""tvm.contrib.msc.core._ffi_api"""
18+
# tvm-ffi-stubgen(begin): import-section
19+
# fmt: off
20+
# isort: off
21+
from __future__ import annotations
22+
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
23+
from typing import TYPE_CHECKING
24+
# isort: on
25+
# fmt: on
26+
# tvm-ffi-stubgen(end)
1827

19-
import tvm_ffi
2028

21-
tvm_ffi.init_ffi_api("msc.core", __name__)
29+
# tvm-ffi-stubgen(begin): global/msc.core
30+
_FFI_INIT_FUNC("msc.core", __name__)
31+
# tvm-ffi-stubgen(end)

0 commit comments

Comments
 (0)