15
15
"""Module for pallas-core functionality."""
16
16
from __future__ import annotations
17
17
18
+ from collections .abc import Iterator , Sequence
18
19
import copy
19
- from collections .abc import Sequence
20
20
import contextlib
21
21
import dataclasses
22
22
import functools
23
23
from typing import Any , Callable , Union
24
- from collections .abc import Iterator
25
24
26
25
from jax ._src import api_util
27
26
from jax ._src import core as jax_core
33
32
from jax ._src .state import discharge as state_discharge
34
33
import jax .numpy as jnp
35
34
36
- # TODO(sharadmv): enable type checking
37
- # mypy: ignore-errors
38
-
39
35
partial = functools .partial
40
36
Grid = tuple [Union [int , jax_core .Array , None ], ...] # None indicates that the bound is dynamic.
41
37
DynamicGrid = tuple [Union [int , jax_core .Array ], ...]
@@ -156,6 +152,10 @@ def compute_index(self, *args):
156
152
return out
157
153
158
154
155
+ # A PyTree of BlockSpec | NoBlockSpec.
156
+ BlockSpecTree = Any
157
+
158
+
159
159
@dataclasses .dataclass (frozen = True )
160
160
class BlockMapping :
161
161
block_shape : tuple [Mapped | int , ...]
@@ -201,7 +201,7 @@ def num_dynamic_grid_bounds(self):
201
201
def static_grid (self ) -> StaticGrid :
202
202
if self .num_dynamic_grid_bounds :
203
203
raise ValueError ("Expected a grid with fully static bounds" )
204
- return self .grid # typing : ignore
204
+ return self .grid # type : ignore
205
205
206
206
207
207
def _preprocess_grid (grid : Grid | int | None ) -> Grid :
@@ -213,9 +213,9 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
213
213
214
214
215
215
def _convert_block_spec_to_block_mapping (
216
- in_avals : list [jax_core .ShapedArray ], block_spec : BlockSpec | None ,
216
+ in_avals : Sequence [jax_core .ShapedArray ], block_spec : BlockSpec ,
217
217
aval : jax_core .ShapedArray , in_tree : Any ,
218
- ) -> BlockSpec | None :
218
+ ) -> BlockMapping | None :
219
219
if block_spec is no_block_spec :
220
220
return None
221
221
if block_spec .index_map is None :
@@ -283,12 +283,8 @@ class GridSpec:
283
283
def __init__ (
284
284
self ,
285
285
grid : Grid | None = None ,
286
- in_specs : BlockSpec
287
- | Sequence [BlockSpec | NoBlockSpec ]
288
- | NoBlockSpec = no_block_spec ,
289
- out_specs : BlockSpec
290
- | Sequence [BlockSpec | NoBlockSpec ]
291
- | NoBlockSpec = no_block_spec ,
286
+ in_specs : BlockSpecTree = no_block_spec ,
287
+ out_specs : BlockSpecTree = no_block_spec ,
292
288
):
293
289
# Be more lenient for in/out_specs
294
290
if isinstance (in_specs , list ):
0 commit comments