Skip to content

Commit f1f98af

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Fix the tests following the changes to pl.core_map
PiperOrigin-RevId: 713283207
1 parent 51b9fe3 commit f1f98af

File tree

3 files changed

+31
-17
lines changed

3 files changed

+31
-17
lines changed

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,6 @@ class GPUMesh:
483483
# Those are NOT CUDA threads. On Hopper they correspond to warpgroups.
484484
num_threads: int | None = None
485485
axis_names: tuple[str, ...] = ()
486-
approx_math: bool = False
487486

488487
def __post_init__(self):
489488
if len(self.axis_names) != len(self.grid) + (self.num_threads is not None):
@@ -521,21 +520,38 @@ def _gpu_mesh_discharge_rule(
521520
*args,
522521
mesh,
523522
jaxpr,
523+
compiler_params,
524+
interpret,
525+
debug,
526+
cost_estimate,
524527
):
525-
assert isinstance(mesh, GPUMesh)
528+
if not isinstance(mesh, GPUMesh):
529+
raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}")
526530
if mesh.cluster:
527531
raise NotImplementedError
528532
if mesh.num_threads is None:
529533
raise NotImplementedError
534+
if compiler_params and not isinstance(compiler_params, GPUCompilerParams):
535+
raise TypeError(
536+
"Compiler params must be a GPUCompilerParams, got"
537+
f" {type(compiler_params)}"
538+
)
539+
if not compiler_params:
540+
compiler_params = GPUCompilerParams()
530541
return pallas_core.default_mesh_discharge_rule(
531542
in_avals,
532543
out_avals,
533544
*args,
534545
jaxpr=jaxpr,
535546
grid=tuple(mesh.shape.items()),
536547
backend="mosaic_gpu",
537-
compiler_params=GPUCompilerParams(approx_math=mesh.approx_math),
548+
compiler_params=compiler_params,
549+
debug=debug,
550+
interpret=interpret,
551+
cost_estimate=cost_estimate,
538552
)
553+
554+
539555
pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule
540556

541557

jax/experimental/pallas/ops/gpu/attention_mgpu.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,11 @@ def run(refs):
200200
grid=(batch_size, num_q_tiles, num_q_heads),
201201
num_threads=3,
202202
axis_names=("batch", "q_seq", "heads", "wg"),
203-
approx_math=True,
204203
)
205-
@pl.core_map(mesh)
204+
205+
@pl.core_map(
206+
mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
207+
)
206208
def _kernel_entry():
207209
compute_wgs = 2
208210
tiling = plgpu.TilingTransform((64, 64))

tests/pallas/mosaic_gpu_test.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,14 +1481,11 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem):
14811481
index_map=lambda i, j: (0, 0))
14821482
],
14831483
)
1484-
mesh = plgpu.GPUMesh(
1485-
grid=(1,),
1486-
num_threads=3,
1487-
axis_names=("_", "wg",),
1488-
approx_math=True,
1489-
)
1484+
mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg"))
14901485
def run(refs):
1491-
@pl.core_map(mesh)
1486+
@pl.core_map(
1487+
mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
1488+
)
14921489
def _kernel_entry():
14931490
pipeline(*refs)
14941491
@jax.jit
@@ -1535,13 +1532,12 @@ def tiled_add_kernel(x_smem, y_smem, o_smem):
15351532
transforms=[])],
15361533
)
15371534
mesh = plgpu.GPUMesh(
1538-
grid=(1,),
1539-
num_threads=num_compute_wgs + 1,
1540-
axis_names=("_", "wg",),
1541-
approx_math=True,
1535+
grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg")
15421536
)
15431537
def run(refs):
1544-
@pl.core_map(mesh)
1538+
@pl.core_map(
1539+
mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
1540+
)
15451541
def _kernel_entry():
15461542
pipeline(*refs)
15471543
@jax.jit

0 commit comments

Comments
 (0)