Skip to content

Commit b6ab98e

Browse files
altimofeev authored and changlan committed
Allow Nested[PartitionSpec] in host_to_global_device_array
* Support nested PartitionSpec * Add unittest * Add unittest * Recursion * Revert original GitOrigin-RevId: 438d288
1 parent 78833d4 commit b6ab98e

File tree

2 files changed

+86
-4
lines changed

2 files changed

+86
-4
lines changed

axlearn/common/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -804,23 +804,25 @@ class DataPartitionType(Enum):
804804

805805

806806
def data_partition_type_to_spec(
807-
partition: Union[DataPartitionType, PartitionSpec],
808-
) -> PartitionSpec:
807+
partition: Union[DataPartitionType, Nested[PartitionSpec]],
808+
) -> Nested[PartitionSpec]:
809809
"""Returns a PartitionSpec for the given partition type."""
810810
if partition == DataPartitionType.FULL:
811811
return input_partition_spec()
812812
elif partition == DataPartitionType.REPLICATED:
813813
return PartitionSpec(None)
814814
elif isinstance(partition, PartitionSpec):
815815
return partition
816+
elif isinstance(partition, dict):
817+
return {k: data_partition_type_to_spec(v) for k, v in partition.items()}
816818
else:
817819
raise NotImplementedError(f"Unsupported partition: {partition}")
818820

819821

820822
def host_to_global_array(
821823
host_arrays: Nested[Union[np.ndarray, Tensor]],
822824
*,
823-
partition: Union[PartitionSpec, DataPartitionType] = DataPartitionType.FULL,
825+
partition: Union[Nested[PartitionSpec], DataPartitionType] = DataPartitionType.FULL,
824826
) -> Nested[Tensor]:
825827
"""Converts the given host device arrays to global device arrays.
826828
@@ -858,7 +860,7 @@ def make_array(x: np.ndarray, partition_spec: PartitionSpec):
858860
global_shape = (x.shape[0] * process_count, *x.shape[1:])
859861
elif partition == DataPartitionType.REPLICATED:
860862
global_shape = (x.shape[0], *x.shape[1:])
861-
elif isinstance(partition, PartitionSpec):
863+
elif isinstance(partition, (PartitionSpec, dict)):
862864
global_shape = None # Allow jax to infer.
863865
else:
864866
raise NotImplementedError(f"Unsupported partition: {partition}")

axlearn/common/utils_test.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from axlearn.common.trainer import SpmdTrainer
5757
from axlearn.common.utils import (
5858
PHYSICAL_TO_LOGICAL_DISPATCH_KEY,
59+
DataPartitionType,
5960
HybridMeshShape,
6061
MeshShape,
6162
NestedTensor,
@@ -74,6 +75,7 @@
7475
copy_recursively,
7576
count_model_params,
7677
create_device_mesh,
78+
data_partition_type_to_spec,
7779
dispatch_input_batch,
7880
expand_vdicts,
7981
find_cycles,
@@ -1970,6 +1972,53 @@ def test_one_per_process(self):
19701972
# Check that contents are as expected.
19711973
self.assertNestedEqual(global_array, replicate_to_local_data(batch))
19721974

1975+
@pytest.mark.for_8_devices
1976+
def test_one_per_process_two_arrays(self):
1977+
"""Test a case where every process produces a slice.
1978+
1979+
This is recommended to run on 2 process, e.g. v5e-16.
1980+
"""
1981+
# NOTE: the following can be used for local testing
1982+
# XLA_FLAGS=--xla_force_host_platform_device_count=8
1983+
1984+
device_count = jax.device_count()
1985+
process_count = jax.process_count()
1986+
print(f"{device_count=}, {process_count=}")
1987+
assert device_count > 1
1988+
assert process_count <= 2
1989+
1990+
# Build an array that has dim=0 smaller than num devices, but still >= num processes.
1991+
global_shape = (device_count // 2, 2)
1992+
assert global_shape[0] % process_count == 0
1993+
process_shape = global_shape[0] // process_count
1994+
1995+
feed_index = jax.process_index()
1996+
global_a = jax.random.uniform(jax.random.PRNGKey(123), shape=global_shape)
1997+
global_b = jax.random.uniform(jax.random.PRNGKey(124), shape=global_shape)
1998+
expected_batch = {"a": global_a, "b": {"nested_value": global_b}}
1999+
2000+
with jax.sharding.Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y")):
2001+
# Shard dim=0 only along data.
2002+
logical_sharding = {"a": PartitionSpec("x"), "b": PartitionSpec("y")}
2003+
2004+
# Each process has a slice.
2005+
local_batch = {
2006+
"a": global_a[feed_index * process_shape : (feed_index + 1) * process_shape],
2007+
"b": {
2008+
"nested_value": global_b[
2009+
feed_index * process_shape : (feed_index + 1) * process_shape
2010+
]
2011+
},
2012+
}
2013+
batch = host_to_global_device_array(local_batch, partition=logical_sharding)
2014+
2015+
# Check that sharding is as expected.
2016+
self.assertEqual(logical_sharding["a"], batch["a"].sharding.spec)
2017+
self.assertEqual(logical_sharding["b"], batch["b"]["nested_value"].sharding.spec)
2018+
2019+
# Check that contents are as expected.
2020+
self.assertNestedEqual(expected_batch, replicate_to_local_data(batch))
2021+
19732022
# Test process_count // 1, process_count // 2, process_count // 4.
19742023
# On v5e-16, this exercises 4, 2, and 1 reading hosts out of 4.
19752024
@parameterized.parameters(1, 2, 4)
@@ -2183,5 +2232,36 @@ class ConfigChild(ConfigParent):
21832232
self.assertSameElements(("child_field1", "child_field2"), own_fields(ConfigChild()))
21842233

21852234

2235+
class DataPartitionTypeToSpecTest(TestCase):
2236+
@mock.patch("axlearn.common.utils.input_partition_spec")
2237+
def test_full_partition(self, mock_input_partition_spec):
2238+
# Mocks input_partition_spec to return a predictable value
2239+
mock_input_partition_spec.return_value = PartitionSpec("full_spec")
2240+
result = data_partition_type_to_spec(DataPartitionType.FULL)
2241+
self.assertEqual(result, PartitionSpec("full_spec"))
2242+
mock_input_partition_spec.assert_called_once()
2243+
2244+
def test_replicated_partition(self):
2245+
result = data_partition_type_to_spec(DataPartitionType.REPLICATED)
2246+
self.assertEqual(result, PartitionSpec(None))
2247+
2248+
def test_partition_spec_input(self):
2249+
custom_spec = PartitionSpec((("data", 0), ("model", 1)))
2250+
result = data_partition_type_to_spec(custom_spec)
2251+
self.assertEqual(result, custom_spec)
2252+
2253+
def test_dict_input(self):
2254+
dict_spec = {"a": PartitionSpec("b"), "c": {"d": PartitionSpec("d")}}
2255+
result = data_partition_type_to_spec(dict_spec)
2256+
self.assertEqual(result, dict_spec)
2257+
2258+
def test_unsupported_partition_type(self):
2259+
with self.assertRaisesRegex(NotImplementedError, "Unsupported partition: unsupported_type"):
2260+
data_partition_type_to_spec("unsupported_type")
2261+
2262+
with self.assertRaisesRegex(NotImplementedError, "Unsupported partition: 123"):
2263+
data_partition_type_to_spec(123)
2264+
2265+
21862266
if __name__ == "__main__":
21872267
absltest.main()

0 commit comments

Comments (0)