Skip to content

Commit 445d75c

Browse files
megrez-yliuGoogle-ML-Automation
authored andcommitted
[Pallas TPU] Add pack_elementwise and unpack_elementwise primitives.
PiperOrigin-RevId: 828697560
1 parent 2138cd0 commit 445d75c

File tree

4 files changed

+198
-1
lines changed

4 files changed

+198
-1
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3580,6 +3580,45 @@ def _stochastic_round_lowering_rule(
35803580
return tpu.stochastic_convert(out_type, x, random_bits)
35813581

35823582

3583+
def _check_elementwise_packing_dtypes(unpacked_dtype, packed_dtype):
3584+
if unpacked_dtype == jnp.float32 and packed_dtype == jnp.bfloat16:
3585+
return
3586+
if unpacked_dtype == jnp.int32 and packed_dtype in [
3587+
jnp.int16, jnp.int8, jnp.int4
3588+
]:
3589+
return
3590+
raise ValueError(
3591+
f"Unsupported elementwise packing: {unpacked_dtype} -> {packed_dtype}. "
3592+
"Only f32 <-> bf16 and i32 <-> i16/i8/i4 are supported."
3593+
)
3594+
3595+
3596+
@register_lowering_rule(tpu_primitives.pack_elementwise_p)
3597+
def _pack_elementwise_lowering_rule(
3598+
ctx: LoweringRuleContext, *xs, packed_dtype
3599+
):
3600+
in_aval = ctx.avals_in[0]
3601+
_check_elementwise_packing_dtypes(in_aval.dtype, packed_dtype)
3602+
packed_ir_type = _dtype_to_ir_type(packed_dtype)
3603+
out_type = ir.VectorType.get(
3604+
in_aval.shape, _dtype_to_ir_type(jnp.uint32)
3605+
)
3606+
return tpu.pack_elementwise(out_type, xs, target_type=packed_ir_type)
3607+
3608+
3609+
@register_lowering_rule(tpu_primitives.unpack_elementwise_p)
3610+
def _unpack_elementwise_lowering_rule(
3611+
ctx: LoweringRuleContext, x, index, packed_dtype, unpacked_dtype
3612+
):
3613+
in_aval = ctx.avals_in[0]
3614+
_check_elementwise_packing_dtypes(unpacked_dtype, packed_dtype)
3615+
out_type = ir.VectorType.get(
3616+
in_aval.shape, _dtype_to_ir_type(unpacked_dtype)
3617+
)
3618+
return tpu.unpack_elementwise(
3619+
out_type, x, source_type=_dtype_to_ir_type(packed_dtype), index=index)
3620+
3621+
35833622
@register_lowering_rule(tpu_primitives.bitcast_p)
35843623
def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
35853624
del ty

jax/_src/pallas/mosaic/primitives.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,60 @@ def _stochastic_round_abstract_eval(x, random_bits, *, target_dtype):
969969
)
970970
return jax_core.ShapedArray(x.shape, target_dtype)
971971

972+
def _get_elementwise_packing_factor(unpacked_dtype, packed_dtype):
973+
unpacked_bitwidth = dtypes.bit_width(unpacked_dtype)
974+
packed_bitwidth = dtypes.bit_width(packed_dtype)
975+
if unpacked_bitwidth % packed_bitwidth != 0:
976+
raise ValueError(
977+
"Unpacked bitwidth must be a multiple of packed bitwidth, got "
978+
f"{unpacked_bitwidth} and {packed_bitwidth}"
979+
)
980+
return unpacked_bitwidth // packed_bitwidth
981+
982+
pack_elementwise_p = jax_core.Primitive("pack_elementwise")
983+
984+
985+
def pack_elementwise(xs, *, packed_dtype):
986+
return pack_elementwise_p.bind(*xs, packed_dtype=packed_dtype)
987+
988+
989+
@pack_elementwise_p.def_abstract_eval
990+
def _pack_elementwise_abstract_eval(*xs, packed_dtype):
991+
if not xs:
992+
raise ValueError("At least one source is required")
993+
first = xs[0]
994+
if not all(x.shape == first.shape for x in xs):
995+
raise ValueError("All sources must have the same shape")
996+
if not all(x.dtype == first.dtype for x in xs):
997+
raise ValueError("All sources must have the same dtype")
998+
packing_factor = _get_elementwise_packing_factor(first.dtype, packed_dtype)
999+
if len(xs) != packing_factor:
1000+
raise ValueError(
1001+
"The number of sources must match the packing factor "
1002+
f"({packing_factor}), got {len(xs)}"
1003+
)
1004+
return jax_core.ShapedArray(first.shape, jnp.uint32)
1005+
1006+
1007+
unpack_elementwise_p = jax_core.Primitive("unpack_elementwise")
1008+
1009+
1010+
def unpack_elementwise(x, *, index, packed_dtype, unpacked_dtype):
1011+
return unpack_elementwise_p.bind(
1012+
x, index=index, packed_dtype=packed_dtype, unpacked_dtype=unpacked_dtype
1013+
)
1014+
1015+
1016+
@unpack_elementwise_p.def_abstract_eval
1017+
def _unpack_elementwise_abstract_eval(x, *, index, packed_dtype, unpacked_dtype):
1018+
if x.dtype != jnp.uint32:
1019+
raise ValueError(f"Source must be uint32, got {x.dtype}")
1020+
packing_factor = _get_elementwise_packing_factor(unpacked_dtype, packed_dtype)
1021+
if index < 0 or index >= packing_factor:
1022+
raise ValueError(
1023+
f"Index {index} is out of bounds for packing factor {packing_factor}")
1024+
return jax_core.ShapedArray(x.shape, unpacked_dtype)
1025+
9721026

9731027
def with_memory_space_constraint(
9741028
x: jax.Array, memory_space: Any

jax/experimental/pallas/tpu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,15 @@
4545
from jax._src.pallas.mosaic.primitives import load as load
4646
from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy
4747
from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy
48+
from jax._src.pallas.mosaic.primitives import pack_elementwise as pack_elementwise
4849
from jax._src.pallas.mosaic.primitives import prng_random_bits as prng_random_bits
4950
from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed
5051
from jax._src.pallas.mosaic.primitives import repeat as repeat
5152
from jax._src.pallas.mosaic.primitives import roll as roll
5253
from jax._src.pallas.mosaic.primitives import stochastic_round as stochastic_round
5354
from jax._src.pallas.mosaic.primitives import store as store
5455
from jax._src.pallas.mosaic.primitives import touch as touch
56+
from jax._src.pallas.mosaic.primitives import unpack_elementwise as unpack_elementwise
5557
from jax._src.pallas.mosaic.primitives import with_memory_space_constraint as with_memory_space_constraint
5658
from jax._src.pallas.mosaic.random import sample_block as sample_block
5759
from jax._src.pallas.mosaic.random import stateful_bernoulli as stateful_bernoulli

tests/pallas/tpu_ops_test.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,6 @@ def kernel(x_ref, b_ref, o_ref):
716716

717717
result = pl.pallas_call(
718718
kernel,
719-
in_specs=[pl.BlockSpec(), pl.BlockSpec()],
720719
out_shape=jax.ShapeDtypeStruct(x.shape, target_dtype),
721720
)(x, bits)
722721

@@ -730,5 +729,108 @@ def kernel(x_ref, b_ref, o_ref):
730729
)
731730
self.assertTrue(jnp.all(is_correct))
732731

732+
def _pack_unpack_elementwise_test_data(
733+
self, shape, unpacked_dtype, packed_dtype):
734+
"""Generates data for test_pack_elementwise and test_unpack_elementwise."""
735+
bitwidth = dtypes.bit_width(packed_dtype)
736+
num_sources = 32 // bitwidth
737+
if unpacked_dtype == jnp.int32:
738+
stacked_sources = jax.random.randint(
739+
jax.random.key(0),
740+
(num_sources, *shape),
741+
minval=-1000,
742+
maxval=1000,
743+
dtype=unpacked_dtype,
744+
)
745+
else:
746+
stacked_sources = jax.random.uniform(
747+
jax.random.key(0), (num_sources, *shape), dtype=unpacked_dtype
748+
)
749+
stacked_results = (
750+
stacked_sources.astype(packed_dtype)
751+
.view(getattr(jnp, f"uint{bitwidth}"))
752+
.astype(jnp.uint32)
753+
)
754+
shifts = jnp.arange(num_sources, dtype=jnp.uint32) * bitwidth
755+
shifts = jnp.expand_dims(shifts, axis=tuple(range(1, stacked_results.ndim)))
756+
packed_data = jnp.bitwise_or.reduce(stacked_results << shifts, axis=0)
757+
return stacked_sources, packed_data
758+
759+
@parameterized.product(
760+
config=[
761+
(jnp.float32, jnp.bfloat16),
762+
(jnp.int32, jnp.int16),
763+
(jnp.int32, jnp.int8),
764+
(jnp.int32, jnp.int4),
765+
],
766+
shape=[(8, 128), (2, 15, 300)],
767+
)
768+
def test_pack_elementwise(self, config, shape):
769+
unpacked_dtype, packed_dtype = config
770+
if not jtu.is_device_tpu_at_least(version=5):
771+
self.skipTest("Requires TPU v5+")
772+
if not jtu.if_cloud_tpu_at_least(2025, 11, 7):
773+
self.skipTest("Test requires libtpu from 2025/11/7 or later")
774+
775+
bitwidth = dtypes.bit_width(packed_dtype)
776+
num_sources = 32 // bitwidth
777+
778+
def kernel(xs_ref, o_ref):
779+
xs = [xs_ref[i] for i in range(num_sources)]
780+
o_ref[...] = pltpu.pack_elementwise(xs, packed_dtype=packed_dtype)
781+
782+
stacked_sources, expected = self._pack_unpack_elementwise_test_data(
783+
shape, unpacked_dtype, packed_dtype
784+
)
785+
786+
result = self.pallas_call(
787+
kernel,
788+
out_shape=jax.ShapeDtypeStruct(shape, jnp.uint32),
789+
)(stacked_sources)
790+
791+
np.testing.assert_array_equal(result, expected)
792+
793+
@parameterized.product(
794+
config=[
795+
(jnp.float32, jnp.bfloat16),
796+
(jnp.int32, jnp.int16),
797+
(jnp.int32, jnp.int8),
798+
(jnp.int32, jnp.int4),
799+
],
800+
index=[0, 1, 3],
801+
shape=[(8, 128), (2, 15, 300)],
802+
)
803+
def test_unpack_elementwise(self, config, index, shape):
804+
unpacked_dtype, packed_dtype = config
805+
if not jtu.is_device_tpu_at_least(version=5):
806+
self.skipTest("Requires TPU v5+")
807+
if not jtu.if_cloud_tpu_at_least(2025, 11, 7):
808+
self.skipTest("Test requires libtpu from 2025/11/7 or later")
809+
810+
bitwidth = dtypes.bit_width(packed_dtype)
811+
packing_factor = 32 // bitwidth
812+
813+
if index >= packing_factor:
814+
self.skipTest(
815+
f"Index {index} out of bounds for packing factor {packing_factor}")
816+
817+
def kernel(x_ref, o_ref):
818+
o_ref[...] = pltpu.unpack_elementwise(
819+
x_ref[...], index=index,
820+
packed_dtype=packed_dtype, unpacked_dtype=unpacked_dtype
821+
)
822+
823+
sources, packed = self._pack_unpack_elementwise_test_data(
824+
shape, unpacked_dtype, packed_dtype
825+
)
826+
expected = sources[index].astype(packed_dtype).astype(unpacked_dtype)
827+
828+
result = self.pallas_call(
829+
kernel,
830+
out_shape=jax.ShapeDtypeStruct(shape, unpacked_dtype),
831+
)(packed)
832+
833+
np.testing.assert_array_equal(result, expected)
834+
733835
if __name__ == "__main__":
734836
absltest.main()

0 commit comments

Comments
 (0)