
Commit 73fa0f4

justinjfu authored and Google-ML-Automation committed
[Pallas] Deprecate dictionary compiler_params in favor of dataclass.
PiperOrigin-RevId: 699057658
1 parent 355589f commit 73fa0f4
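
For context, here is a minimal, self-contained sketch of the migration this commit performs for the Triton/GPU backend. The kernel body, shapes, and import path are illustrative assumptions, not part of this commit; only the plgpu.TritonCompilerParams spelling comes from the diffs below.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu  # import path assumed

def add_one_kernel(x_ref, o_ref):
  # Trivial elementwise kernel, purely for illustration.
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # New style: a typed dataclass replaces the deprecated
    # compiler_params=dict(triton=dict(num_warps=4)) form.
    compiler_params=plgpu.TritonCompilerParams(num_warps=4),
)(x)

A dataclass gives field-name checking and IDE discoverability that the nested dict form could not.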

3 files changed: +8 -9 lines changed


jax/experimental/pallas/ops/gpu/layer_norm.py

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ def layer_norm_forward(
   ]
   method = pl.pallas_call(
       kernel,
-      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape,
       debug=False,
@@ -215,7 +215,7 @@ def layer_norm_backward(
   out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape_dx,
       debug=False,

jax/experimental/pallas/ops/gpu/rms_norm.py

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ def rms_norm_backward(
   out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape_dx,
       debug=False,

tests/pallas/tpu_pallas_pipeline_test.py

Lines changed: 5 additions & 6 deletions
@@ -486,12 +486,11 @@ def _wait_on_prev_dma():
           + [pltpu.SemaphoreType.DMA] * 4
           + inner_allocs
       ),
-      compiler_params=dict(
-          mosaic=dict(collective_id=0,
-                      # must set scoped vmem flag *larger* than below! e.g.:
-                      # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072
-                      vmem_limit_bytes=int(134217728 * 0.9)  # 0.9 * 128MiB
-          )
+      compiler_params=pltpu.TPUCompilerParams(
+          collective_id=0,
+          # must set scoped vmem flag *larger* than below! e.g.:
+          # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072
+          vmem_limit_bytes=int(134217728 * 0.9)  # 0.9 * 128MiB
       ),
   )
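
For completeness, a hedged sketch of the TPU-side equivalent of the new API. The copy kernel and shapes are invented for illustration; pltpu.TPUCompilerParams and its vmem_limit_bytes field mirror the test change above.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref):
  # Trivial copy kernel, purely for illustration.
  o_ref[...] = x_ref[...]

x = jnp.ones((8, 128), jnp.float32)
out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # Flat, typed fields replace the nested dict(mosaic=dict(...)) form.
    compiler_params=pltpu.TPUCompilerParams(
        vmem_limit_bytes=int(134217728 * 0.9),  # 0.9 * 128 MiB
    ),
)(x)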
