|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | 19 | import collections |
20 | | -from collections.abc import MutableMapping, MutableSequence, Sequence |
| 20 | +from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence |
21 | 21 | import contextlib |
22 | 22 | import dataclasses |
23 | 23 | import functools |
24 | 24 | import itertools as it |
25 | 25 | import math |
26 | | -from typing import Any, Hashable, Protocol, cast |
| 26 | +from typing import Any, Protocol, cast |
27 | 27 |
|
28 | 28 | import jax |
29 | 29 | from jax import lax |
@@ -192,7 +192,7 @@ def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int |
192 | 192 | @dataclasses.dataclass |
193 | 193 | class ModuleContext: |
194 | 194 | name: str |
195 | | - grid_mapping: pallas_core.GridMapping |
| 195 | + grid_names: Sequence[Hashable] | None |
196 | 196 | program_ids: Sequence[ir.Value] | None |
197 | 197 | approx_math: bool |
198 | 198 | runtime_smem: ir.Value # ir.MemRefType |
@@ -517,7 +517,7 @@ def make_program_ids(step: ir.Value): |
517 | 517 | grouped_barriers[barrier].append(barrier_ref) |
518 | 518 | module_ctx = ModuleContext( |
519 | 519 | name_and_src_info.name, |
520 | | - grid_mapping, |
| 520 | + grid_mapping.grid_names, |
521 | 521 | None, |
522 | 522 | approx_math, |
523 | 523 | runtime_smem, |
@@ -1290,7 +1290,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): |
1290 | 1290 | @register_lowering_rule(lax.axis_index_p) |
1291 | 1291 | def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): |
1292 | 1292 | i32 = ir.IntegerType.get_signless(32) |
1293 | | - grid_names = ctx.module_ctx.grid_mapping.grid_names |
| 1293 | + grid_names = ctx.module_ctx.grid_names |
1294 | 1294 | squashed_dims = ctx.module_ctx.squashed_dims |
1295 | 1295 | if squashed_dims: |
1296 | 1296 | unsquashed_names = grid_names[-3:] |
|
0 commit comments