@@ -56,6 +56,7 @@
 from axlearn.common.trainer import SpmdTrainer
 from axlearn.common.utils import (
     PHYSICAL_TO_LOGICAL_DISPATCH_KEY,
+    DataPartitionType,
     HybridMeshShape,
     MeshShape,
     NestedTensor,
@@ -74,6 +75,7 @@
     copy_recursively,
     count_model_params,
     create_device_mesh,
+    data_partition_type_to_spec,
     dispatch_input_batch,
     expand_vdicts,
     find_cycles,
@@ -1970,6 +1972,53 @@ def test_one_per_process(self):
         # Check that contents are as expected.
         self.assertNestedEqual(global_array, replicate_to_local_data(batch))

+    @pytest.mark.for_8_devices
+    def test_one_per_process_two_arrays(self):
+        """Test a case where every process produces a slice.
+
+        This is recommended to run on 2 processes, e.g. v5e-16.
+        """
+        # NOTE: the following can be used for local testing:
+        # XLA_FLAGS=--xla_force_host_platform_device_count=8
+
+        device_count = jax.device_count()
+        process_count = jax.process_count()
+        print(f"{device_count=}, {process_count=}")
+        assert device_count > 1
+        assert process_count <= 2
+
+        # Build an array that has dim=0 smaller than num devices, but still >= num processes.
+        global_shape = (device_count // 2, 2)
+        assert global_shape[0] % process_count == 0
+        process_shape = global_shape[0] // process_count
+
+        feed_index = jax.process_index()
+        global_a = jax.random.uniform(jax.random.PRNGKey(123), shape=global_shape)
+        global_b = jax.random.uniform(jax.random.PRNGKey(124), shape=global_shape)
+        expected_batch = {"a": global_a, "b": {"nested_value": global_b}}
+
+        with jax.sharding.Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y")):
+            # Shard dim=0 of "a" along mesh axis "x" and dim=0 of "b" along "y".
+            logical_sharding = {"a": PartitionSpec("x"), "b": PartitionSpec("y")}
+
+            # Each process has a slice.
+            local_batch = {
+                "a": global_a[feed_index * process_shape : (feed_index + 1) * process_shape],
+                "b": {
+                    "nested_value": global_b[
+                        feed_index * process_shape : (feed_index + 1) * process_shape
+                    ]
+                },
+            }
+            batch = host_to_global_device_array(local_batch, partition=logical_sharding)
+
+            # Check that sharding is as expected.
+            self.assertEqual(logical_sharding["a"], batch["a"].sharding.spec)
+            self.assertEqual(logical_sharding["b"], batch["b"]["nested_value"].sharding.spec)
+
+            # Check that contents are as expected.
+            self.assertNestedEqual(expected_batch, replicate_to_local_data(batch))
+
     # Test process_count // 1, process_count // 2, process_count // 4.
     # On v5e-16, this exercises 4, 2, and 1 reading hosts out of 4.
     @parameterized.parameters(1, 2, 4)
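For poking at this locally without a TPU slice, a single-process approximation of what the new test exercises is sketched below. This is a sketch under assumptions, not the test itself: with one process, `host_to_global_device_array` behaves roughly like `jax.device_put` with a per-leaf `NamedSharding`, and the `--xla_force_host_platform_device_count=8` flag from the test's NOTE comment emulates eight CPU devices. Note also that the `PartitionSpec` for `"b"` in the test acts as a tree prefix covering the nested `"nested_value"` leaf.

```python
# Minimal single-process sketch (an approximation, not the test itself):
# emulate 8 CPU devices via the XLA flag from the test's NOTE, then check that
# device_put with a NamedSharding yields the expected sharding spec on dim=0.
import os

os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

device_count = jax.device_count()
mesh = Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y"))
a = jax.random.uniform(jax.random.PRNGKey(123), shape=(device_count // 2, 2))
global_a = jax.device_put(a, NamedSharding(mesh, PartitionSpec("x")))
assert global_a.sharding.spec == PartitionSpec("x")
```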
@@ -2183,5 +2232,36 @@ class ConfigChild(ConfigParent):
         self.assertSameElements(("child_field1", "child_field2"), own_fields(ConfigChild()))


+class DataPartitionTypeToSpecTest(TestCase):
+    @mock.patch("axlearn.common.utils.input_partition_spec")
+    def test_full_partition(self, mock_input_partition_spec):
+        # Mock input_partition_spec to return a predictable value.
+        mock_input_partition_spec.return_value = PartitionSpec("full_spec")
+        result = data_partition_type_to_spec(DataPartitionType.FULL)
+        self.assertEqual(result, PartitionSpec("full_spec"))
+        mock_input_partition_spec.assert_called_once()
+
+    def test_replicated_partition(self):
+        result = data_partition_type_to_spec(DataPartitionType.REPLICATED)
+        self.assertEqual(result, PartitionSpec(None))
+
+    def test_partition_spec_input(self):
+        custom_spec = PartitionSpec(("data", "model"))
+        result = data_partition_type_to_spec(custom_spec)
+        self.assertEqual(result, custom_spec)
+
+    def test_dict_input(self):
+        dict_spec = {"a": PartitionSpec("b"), "c": {"d": PartitionSpec("d")}}
+        result = data_partition_type_to_spec(dict_spec)
+        self.assertEqual(result, dict_spec)
+
+    def test_unsupported_partition_type(self):
+        with self.assertRaisesRegex(NotImplementedError, "Unsupported partition: unsupported_type"):
+            data_partition_type_to_spec("unsupported_type")
+
+        with self.assertRaisesRegex(NotImplementedError, "Unsupported partition: 123"):
+            data_partition_type_to_spec(123)
+
+
 if __name__ == "__main__":
     absltest.main()
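Taken together, the new tests pin down a contract for `data_partition_type_to_spec`: `FULL` defers to `input_partition_spec()`, `REPLICATED` maps to `PartitionSpec(None)`, explicit `PartitionSpec` or nested-dict specs pass through unchanged, and anything else raises `NotImplementedError`. A minimal sketch consistent with those assertions follows; it is inferred from the tests, not copied from `axlearn/common/utils.py`, and the enum values are assumptions.

```python
# Sketch inferred from the tests above; the real helper lives in axlearn.common.utils.
from enum import Enum
from typing import Union

from jax.sharding import PartitionSpec

from axlearn.common.utils import input_partition_spec  # Patched in test_full_partition.


class DataPartitionType(Enum):
    # Inputs are fully partitioned across all devices.
    FULL = "full"
    # Inputs are fully replicated across all devices.
    REPLICATED = "replicated"


def data_partition_type_to_spec(partition) -> Union[PartitionSpec, dict]:
    """Returns a PartitionSpec (or a nested dict of them) for the given partition type."""
    if partition == DataPartitionType.FULL:
        return input_partition_spec()
    elif partition == DataPartitionType.REPLICATED:
        return PartitionSpec(None)
    elif isinstance(partition, (PartitionSpec, dict)):
        # Explicit per-leaf specs (possibly a nested dict) pass through unchanged.
        return partition
    else:
        raise NotImplementedError(f"Unsupported partition: {partition}")
```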