Skip to content

Commit 1c06114

Browse files
author
jax authors
committed
Merge pull request #21729 from superbobry:pallas
PiperOrigin-RevId: 642225089
2 parents f147dd2 + 70f6ab3 commit 1c06114

File tree

5 files changed

+24
-34
lines changed

5 files changed

+24
-34
lines changed

jax/_src/pallas/core.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
"""Module for pallas-core functionality."""
1616
from __future__ import annotations
1717

18+
from collections.abc import Iterator, Sequence
1819
import copy
19-
from collections.abc import Sequence
2020
import contextlib
2121
import dataclasses
2222
import functools
2323
from typing import Any, Callable, Union
24-
from collections.abc import Iterator
2524

2625
from jax._src import api_util
2726
from jax._src import core as jax_core
@@ -33,9 +32,6 @@
3332
from jax._src.state import discharge as state_discharge
3433
import jax.numpy as jnp
3534

36-
# TODO(sharadmv): enable type checking
37-
# mypy: ignore-errors
38-
3935
partial = functools.partial
4036
Grid = tuple[Union[int, jax_core.Array, None], ...] # None indicates that the bound is dynamic.
4137
DynamicGrid = tuple[Union[int, jax_core.Array], ...]
@@ -156,6 +152,10 @@ def compute_index(self, *args):
156152
return out
157153

158154

155+
# A PyTree of BlockSpec | NoBlockSpec.
156+
BlockSpecTree = Any
157+
158+
159159
@dataclasses.dataclass(frozen=True)
160160
class BlockMapping:
161161
block_shape: tuple[Mapped | int, ...]
@@ -201,7 +201,7 @@ def num_dynamic_grid_bounds(self):
201201
def static_grid(self) -> StaticGrid:
202202
if self.num_dynamic_grid_bounds:
203203
raise ValueError("Expected a grid with fully static bounds")
204-
return self.grid # typing: ignore
204+
return self.grid # type: ignore
205205

206206

207207
def _preprocess_grid(grid: Grid | int | None) -> Grid:
@@ -213,9 +213,9 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
213213

214214

215215
def _convert_block_spec_to_block_mapping(
216-
in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None,
216+
in_avals: Sequence[jax_core.ShapedArray], block_spec: BlockSpec,
217217
aval: jax_core.ShapedArray, in_tree: Any,
218-
) -> BlockSpec | None:
218+
) -> BlockMapping | None:
219219
if block_spec is no_block_spec:
220220
return None
221221
if block_spec.index_map is None:
@@ -283,12 +283,8 @@ class GridSpec:
283283
def __init__(
284284
self,
285285
grid: Grid | None = None,
286-
in_specs: BlockSpec
287-
| Sequence[BlockSpec | NoBlockSpec]
288-
| NoBlockSpec = no_block_spec,
289-
out_specs: BlockSpec
290-
| Sequence[BlockSpec | NoBlockSpec]
291-
| NoBlockSpec = no_block_spec,
286+
in_specs: BlockSpecTree = no_block_spec,
287+
out_specs: BlockSpecTree = no_block_spec,
292288
):
293289
# Be more lenient for in/out_specs
294290
if isinstance(in_specs, list):

jax/_src/pallas/mosaic/core.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import dataclasses
2020
import enum
2121
import functools
22-
from typing import Any, Union
22+
from typing import Any
2323

2424
from jax._src import core as jax_core
2525
from jax._src import dtypes
@@ -28,15 +28,13 @@
2828
import jax.numpy as jnp
2929
from jax._src.pallas import core as pallas_core
3030

31-
# TODO(sharadmv): enable type checking
32-
# mypy: ignore-errors
33-
3431
map, unsafe_map = util.safe_map, map
3532
zip, unsafe_zip = util.safe_zip, zip
3633

3734
partial = functools.partial
3835
Grid = pallas_core.Grid
3936
BlockSpec = pallas_core.BlockSpec
37+
BlockSpecTree = pallas_core.BlockSpecTree
4038
GridMapping = pallas_core.GridMapping
4139
NoBlockSpec = pallas_core.NoBlockSpec
4240
AbstractMemoryRef = pallas_core.AbstractMemoryRef
@@ -97,6 +95,7 @@ class SemaphoreType(enum.Enum):
9795
BARRIER = "barrier"
9896

9997
def __call__(self, shape: tuple[int, ...]):
98+
dtype: Any
10099
if self == SemaphoreType.DMA:
101100
dtype = DmaSemaphoreTy()
102101
elif self == SemaphoreType.BARRIER:
@@ -143,9 +142,6 @@ def _make_aval(obj: object) -> jax_core.AbstractValue:
143142
"Only VMEM and SemaphoreType are supported.")
144143

145144

146-
BlockSpecTree = Union[BlockSpec, NoBlockSpec, Sequence["BlockSpecTree"]]
147-
148-
149145
@dataclasses.dataclass(init=False, unsafe_hash=True)
150146
class PrefetchScalarGridSpec(pallas_core.GridSpec):
151147
grid: Grid

jax/_src/pallas/pallas_call.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@
5555
zip, unsafe_zip = safe_zip, zip
5656

5757
Grid = pallas_core.Grid
58-
BlockSpec = pallas_core.BlockSpec
5958
GridSpec = pallas_core.GridSpec
6059
BlockMapping = pallas_core.BlockMapping
6160
GridMapping = pallas_core.GridMapping
61+
BlockSpec = pallas_core.BlockSpec
62+
BlockSpecTree = pallas_core.BlockSpecTree
6263
NoBlockSpec = pallas_core.NoBlockSpec
6364
no_block_spec = pallas_core.no_block_spec
6465

@@ -763,14 +764,13 @@ def pallas_call(
763764
grid_spec: GridSpec | None = None,
764765
debug: bool = False,
765766
grid: Grid | None = None,
766-
in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec,
767-
out_specs: BlockSpec | NoBlockSpec
768-
| Sequence[BlockSpec | NoBlockSpec] = no_block_spec,
767+
in_specs: BlockSpecTree = no_block_spec,
768+
out_specs: BlockSpecTree = no_block_spec,
769769
input_output_aliases: dict[int, int] = {},
770770
interpret: bool = False,
771771
name: str | None = None,
772772
compiler_params: dict[str, Any] | None = None,
773-
):
773+
) -> Callable[..., Any]:
774774
name = _extract_function_name(f, name)
775775
if compiler_params is None:
776776
compiler_params = {}

jax/_src/pallas/primitives.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@
3939
from jax.interpreters import mlir
4040
import jax.numpy as jnp
4141

42-
43-
# TODO(sharadmv): enable type checking
44-
# mypy: ignore-errors
45-
4642
partial = functools.partial
4743
Slice = indexing.Slice
4844
NDIndexer = indexing.NDIndexer
@@ -64,6 +60,7 @@ def program_id_bind(*, axis: int):
6460

6561
def _program_id_impl(*, axis: int):
6662
grid_env = pallas_core.current_grid_env()
63+
assert grid_env
6764
return grid_env[axis].axis_index
6865
program_id_p.def_impl(_program_id_impl)
6966

@@ -87,6 +84,7 @@ def _num_programs_bind(*, axis: int):
8784
@num_programs_p.def_impl
8885
def _num_programs_impl(*, axis: int):
8986
grid_env = pallas_core.current_grid_env()
87+
assert grid_env
9088
return jnp.asarray(grid_env[axis].axis_size, dtype=jnp.int32)
9189

9290
@num_programs_p.def_abstract_eval
@@ -569,7 +567,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
569567
""" # fmt: skip
570568
has_placeholders = False
571569
if fmt:
572-
_, field_name, *_ = next(string.Formatter().parse(fmt))
570+
_, field_name, *_ = next(iter(string.Formatter().parse(fmt)))
573571
has_placeholders = field_name is not None
574572
return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders)
575573

jax/experimental/pallas/ops/tpu/flash_attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
10991099
grid_spec=pltpu.PrefetchScalarGridSpec(
11001100
num_scalar_prefetch=0,
11011101
grid=grid,
1102-
in_specs=in_specs, # type: ignore
1102+
in_specs=in_specs,
11031103
out_specs=out_specs,
11041104
scratch_shapes=scratch_shapes,
11051105
),
@@ -1444,8 +1444,8 @@ def kv_segment_ids_index_map(
14441444
grid_spec=pltpu.PrefetchScalarGridSpec(
14451445
num_scalar_prefetch=0,
14461446
grid=grid,
1447-
in_specs=in_specs, # type: ignore
1448-
out_specs=out_specs, # type: ignore
1447+
in_specs=in_specs,
1448+
out_specs=out_specs,
14491449
scratch_shapes=scratch_shapes,
14501450
),
14511451
out_shape=out_shapes,

0 commit comments

Comments
 (0)