Skip to content

Commit e510295

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Do not store the grid mapping in ModuleContext
We really only ever use the grid names. PiperOrigin-RevId: 703108864
1 parent d034680 commit e510295

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from __future__ import annotations
1818

1919
import collections
20-
from collections.abc import MutableMapping, MutableSequence, Sequence
20+
from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence
2121
import contextlib
2222
import dataclasses
2323
import functools
2424
import itertools as it
2525
import math
26-
from typing import Any, Hashable, Protocol, cast
26+
from typing import Any, Protocol, cast
2727

2828
import jax
2929
from jax import lax
@@ -192,7 +192,7 @@ def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int
192192
@dataclasses.dataclass
193193
class ModuleContext:
194194
name: str
195-
grid_mapping: pallas_core.GridMapping
195+
grid_names: Sequence[Hashable] | None
196196
program_ids: Sequence[ir.Value] | None
197197
approx_math: bool
198198
runtime_smem: ir.Value # ir.MemRefType
@@ -517,7 +517,7 @@ def make_program_ids(step: ir.Value):
517517
grouped_barriers[barrier].append(barrier_ref)
518518
module_ctx = ModuleContext(
519519
name_and_src_info.name,
520-
grid_mapping,
520+
grid_mapping.grid_names,
521521
None,
522522
approx_math,
523523
runtime_smem,
@@ -1290,7 +1290,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
12901290
@register_lowering_rule(lax.axis_index_p)
12911291
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
12921292
i32 = ir.IntegerType.get_signless(32)
1293-
grid_names = ctx.module_ctx.grid_mapping.grid_names
1293+
grid_names = ctx.module_ctx.grid_names
12941294
squashed_dims = ctx.module_ctx.squashed_dims
12951295
if squashed_dims:
12961296
unsquashed_names = grid_names[-3:]

0 commit comments

Comments
 (0)