Skip to content

Commit 2287580

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[jaxlib] Bundle _{tpu,triton,mosaic_gpu}_ext.pyi with jaxlib
PiperOrigin-RevId: 874094949
1 parent db34816 commit 2287580

File tree

11 files changed

+348
-97
lines changed

11 files changed

+348
-97
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3587,16 +3587,18 @@ def tree_unflatten(cls, aux, flat_registers):
35873587

35883588
@runtime_checkable
35893589
class TransferPlan(Protocol):
3590-
tile_index_transforms: tuple[IndexTransform, ...]
3590+
@property
3591+
def tile_index_transforms(self) -> tuple[IndexTransform, ...]:
3592+
raise NotImplementedError
35913593

3592-
def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
3594+
def select(self, group_elems: Sequence[ir.Value], /) -> ir.Value:
35933595
"""Selects the value corresponding to the group of the current thread.
35943596
35953597
The argument must be of the same length as tile_index_transforms.
35963598
"""
35973599
raise NotImplementedError
35983600

3599-
def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
3601+
def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value, /) -> ir.Value:
36003602
"""Returns `new` if the current thread belongs to the given group and `old` otherwise.
36013603
36023604
group_idx must be between 0 and len(tile_index_transforms) - 1.

jaxlib/mlir/_mlir_libs/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ nanobind_pywrap_extension(
194194
name = "_mosaic_gpu_ext",
195195
srcs = ["mosaic_gpu_ext.cc"],
196196
copts = COPTS,
197+
pytype_srcs = ["_mosaic_gpu_ext.pyi"],
197198
deps = [
198199
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
199200
"//jaxlib/mosaic/gpu:tiled_layout",
@@ -216,6 +217,7 @@ nanobind_pywrap_extension(
216217
name = "_tpu_ext",
217218
srcs = ["tpu_ext.cc"],
218219
copts = COPTS,
220+
pytype_srcs = ["_tpu_ext.pyi"],
219221
deps = [
220222
"//jaxlib/mosaic:tpu_dialect_capi",
221223
"@com_google_absl//absl/log:check",
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from collections.abc import Iterable, Sequence
16+
import enum
17+
from mlir import ir
18+
19+
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
20+
def register_inliner_extensions(arg: ir.Context, /) -> None: ...
21+
22+
class TileTransformAttr(ir.Attribute):
23+
@staticmethod
24+
def isinstance(other_attribute: ir.Attribute) -> bool: ...
25+
def __repr__(self) -> str: ...
26+
@staticmethod
27+
def get(
28+
tiling: Sequence[int], context: ir.Context | None = None
29+
) -> TileTransformAttr:
30+
"""Creates a TileTransformAttr with the given tiling."""
31+
32+
@property
33+
def tiling(self) -> list[int]: ...
34+
35+
class TransposeTransformAttr(ir.Attribute):
36+
@staticmethod
37+
def isinstance(other_attribute: ir.Attribute) -> bool: ...
38+
def __repr__(self) -> str: ...
39+
@staticmethod
40+
def get(
41+
permutation: Sequence[int], context: ir.Context | None = None
42+
) -> TransposeTransformAttr:
43+
"""Creates a TransposeTransformAttr with the given permutation."""
44+
45+
@property
46+
def permutation(self) -> list[int]: ...
47+
48+
class SwizzleTransformAttr(ir.Attribute):
49+
@staticmethod
50+
def isinstance(other_attribute: ir.Attribute) -> bool: ...
51+
def __repr__(self) -> str: ...
52+
@staticmethod
53+
def get(
54+
swizzle: int, context: ir.Context | None = None
55+
) -> SwizzleTransformAttr:
56+
"""Creates a SwizzleTransformAttr with the given swizzle."""
57+
58+
@property
59+
def swizzle(self) -> int: ...
60+
61+
def init_cc_mlir(arg: object, /) -> bool: ...
62+
63+
class Tiling:
64+
def __init__(self, tiles: Iterable) -> None: ...
65+
def tile_shape(self, shape: Sequence[int]) -> tuple: ...
66+
def untile_shape(self, shape: Sequence[int]) -> tuple: ...
67+
def tile_strides(self, strides: Sequence[int]) -> tuple: ...
68+
def tile_indices(self, indices: Sequence[int]) -> tuple: ...
69+
def untile_indices(self, indices: Sequence[int]) -> tuple: ...
70+
def tile_nested_shape_strides(
71+
self, shape: Sequence[Sequence[int]], strides: Sequence[Sequence[int]]
72+
) -> tuple: ...
73+
def tile_dimension(self, dim: int) -> tuple: ...
74+
def remove_dimension(self, dim: int) -> Tiling: ...
75+
def canonicalize(self) -> Tiling: ...
76+
@property
77+
def tiles(self) -> tuple: ...
78+
def __str__(self) -> str: ...
79+
def __repr__(self) -> str: ...
80+
def __eq__(self, other: object) -> bool: ...
81+
def __hash__(self) -> int: ...
82+
83+
class Replicated:
84+
def __init__(self, times: int) -> None: ...
85+
@property
86+
def times(self) -> int: ...
87+
@times.setter
88+
def times(self, arg: int, /) -> None: ...
89+
def __repr__(self) -> str: ...
90+
def __hash__(self) -> int: ...
91+
def __eq__(self, arg: object, /) -> bool: ...
92+
93+
class TiledLayout:
94+
def __init__(
95+
self,
96+
tiling: Tiling,
97+
warp_dims: Iterable,
98+
lane_dims: Iterable,
99+
vector_dim: int,
100+
_check_canonical: bool = ...,
101+
) -> None: ...
102+
@property
103+
def warp_dims(self) -> tuple: ...
104+
@property
105+
def lane_dims(self) -> tuple: ...
106+
@property
107+
def partitioned_warp_dims(self) -> tuple: ...
108+
@property
109+
def partitioned_lane_dims(self) -> tuple: ...
110+
@property
111+
def vector_length(self) -> int: ...
112+
@property
113+
def vector_dim(self) -> int: ...
114+
@property
115+
def tiling(self) -> Tiling: ...
116+
@property
117+
def tiled_tiling_shape(self) -> tuple: ...
118+
@property
119+
def tiled_tiling_rank(self) -> int: ...
120+
def warp_indices(self) -> tuple: ...
121+
def lane_indices(self) -> tuple: ...
122+
def canonicalize(self) -> TiledLayout: ...
123+
def registers_shape(self, shape: Sequence[int]) -> tuple: ...
124+
def registers_element_type(self, t: ir.Type) -> object: ...
125+
def shape_from_registers_shape(self, shape: Sequence[int]) -> tuple: ...
126+
@property
127+
def base_tile_shape(self) -> tuple: ...
128+
def remove_dimension(self, dim: int) -> TiledLayout: ...
129+
def reduce(self, axes: Iterable) -> TiledLayout: ...
130+
def thread_idxs(self, arg: Sequence[int], /) -> list: ...
131+
def __str__(self) -> str: ...
132+
def __repr__(self) -> str: ...
133+
def __hash__(self) -> int: ...
134+
def __eq__(self, other: object | None) -> bool: ...
135+
136+
class Rounding(enum.Enum):
137+
UP = 0
138+
139+
DOWN = 1
140+
141+
class TileTransform:
142+
def __init__(
143+
self, tiling: Sequence[int], rounding: Rounding | None = ...
144+
) -> None: ...
145+
def apply(self, arg: object, /) -> ir.Value: ...
146+
def transform_index(self, arg: Iterable, /) -> tuple: ...
147+
def transform_shape(self, arg: Sequence[int], /) -> tuple: ...
148+
def transform_strides(self, arg: Sequence[int], /) -> tuple: ...
149+
150+
class TrivialTransferPlan:
151+
def __init__(self) -> None: ...
152+
@property
153+
def tile_index_transforms(self) -> tuple: ...
154+
def select(self, arg: Iterable, /) -> ir.Value: ...
155+
def select_if_group(
156+
self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
157+
) -> ir.Value: ...
158+
159+
class StaggeredTransferPlan:
160+
def __init__(
161+
self, stagger: int, dim: int, size: int, group_pred: object
162+
) -> None: ...
163+
@property
164+
def stagger(self) -> int: ...
165+
@property
166+
def dim(self) -> int: ...
167+
@property
168+
def size(self) -> int: ...
169+
@property
170+
def group_pred(self) -> ir.Value: ...
171+
@property
172+
def tile_index_transforms(self) -> tuple: ...
173+
def select(self, arg: Iterable, /) -> ir.Value: ...
174+
def select_if_group(
175+
self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
176+
) -> ir.Value: ...
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from mlir import ir
16+
17+
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
18+
def private_has_communication(
19+
arg: ir.Operation, /
20+
) -> tuple[bool, bool]: ...
21+
def private_set_arg_attr(
22+
arg0: ir.Operation, arg1: int, arg2: str, arg3: ir.Attribute, /
23+
) -> None: ...
24+
25+
class Float8EXMYType(ir.Type):
26+
@staticmethod
27+
def isinstance(other_type: ir.Type) -> bool: ...
28+
def __repr__(self) -> str: ...
29+
@staticmethod
30+
def get(
31+
exmy_type: ir.Type | None = None, ctx: ir.Context | None = None
32+
) -> Float8EXMYType: ...
33+
@property
34+
def underlying_type(self) -> ir.Type: ...

jaxlib/mlir/_mlir_libs/_triton_ext.pyi

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 The JAX Authors.
1+
# Copyright 2025 The JAX Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -17,19 +17,20 @@ from jaxlib.mlir import ir
1717
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
1818

1919
class PointerType(ir.Type):
20-
@classmethod
21-
def get(cls, pointee_type: ir.Type, address_space: int) -> PointerType: ...
22-
2320
@staticmethod
24-
def isinstance(other: ir.Type) -> bool: ...
21+
def isinstance(other_type: ir.Type) -> bool: ...
22+
def __repr__(self) -> str: ...
23+
@staticmethod
24+
def get_static_typeid() -> ir.TypeID: ...
25+
@staticmethod
26+
def get(pointee_type: ir.Type, address_space: int) -> PointerType:
27+
"""Creates a PointerType type."""
2528

2629
@property
2730
def pointee_type(self) -> ir.Type: ...
28-
2931
@property
3032
def address_space(self) -> int: ...
3133

3234
def infer_reduce_op_encoding(
33-
op_attribute: ir.Attribute,
34-
axis: int,
35-
) -> ir.Attribute: ...
35+
arg0: ir.Attribute, arg1: int, /
36+
) -> ir.Attribute | None: ...

0 commit comments

Comments
 (0)