File tree (Expand file tree / Collapse file tree)
3 files changed: +8 -9 lines changed
jax/experimental/pallas/ops/gpu (Expand file tree / Collapse file tree)
3 files changed: +8 -9 lines changed
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def layer_norm_forward(
9494 ]
9595 method = pl .pallas_call (
9696 kernel ,
97- compiler_params = dict ( triton = dict ( num_warps = num_warps ) ),
97+ compiler_params = plgpu . TritonCompilerParams ( num_warps = num_warps ),
9898 grid = (),
9999 out_shape = out_shape ,
100100 debug = False ,
@@ -215,7 +215,7 @@ def layer_norm_backward(
215215 out_shape_dx = jax .ShapeDtypeStruct (shape = (n ,), dtype = x .dtype )
216216 method = pl .pallas_call (
217217 kernel ,
218- compiler_params = dict ( triton = dict ( num_warps = num_warps ) ),
218+ compiler_params = plgpu . TritonCompilerParams ( num_warps = num_warps ),
219219 grid = (),
220220 out_shape = out_shape_dx ,
221221 debug = False ,
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,7 @@ def rms_norm_backward(
196196 out_shape_dx = jax .ShapeDtypeStruct (shape = (n ,), dtype = x .dtype )
197197 method = pl .pallas_call (
198198 kernel ,
199- compiler_params = dict ( triton = dict ( num_warps = num_warps ) ),
199+ compiler_params = plgpu . TritonCompilerParams ( num_warps = num_warps ),
200200 grid = (),
201201 out_shape = out_shape_dx ,
202202 debug = False ,
Original file line number | Diff line number | Diff line change
@@ -486,12 +486,11 @@ def _wait_on_prev_dma():
486486 + [pltpu .SemaphoreType .DMA ] * 4
487487 + inner_allocs
488488 ),
489- compiler_params = dict (
490- mosaic = dict (collective_id = 0 ,
491- # must set scoped vmem flag *larger* than below! e.g.:
492- # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072
493- vmem_limit_bytes = int (134217728 * 0.9 ) # 0.9 * 128MiB
494- )
489+ compiler_params = pltpu .TPUCompilerParams (
490+ collective_id = 0 ,
491+ # must set scoped vmem flag *larger* than below! e.g.:
492+ # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072
493+ vmem_limit_bytes = int (134217728 * 0.9 ) # 0.9 * 128MiB
495494 ),
496495 )
497496
You can’t perform that action at this time.
0 commit comments