|
20 | 20 | from absl.testing import parameterized |
21 | 21 | from tunix.perf.experimental import export |
22 | 22 | from tunix.perf.experimental import tracer |
| 23 | +import jax |
| 24 | +from tunix.rl import rl_cluster |
23 | 25 |
|
24 | 26 |
|
25 | 27 | class ExportTest(parameterized.TestCase): |
@@ -63,19 +65,24 @@ def test_basic_metrics_export(self): |
63 | 65 | ), |
64 | 66 | ) |
65 | 67 | @mock.patch.object( |
66 | | - export.trace_writer_lib, "PerfettoTraceWriter", autospec=True |
| 68 | + export.trace_writer_lib, |
| 69 | + "PerfettoTraceWriter", |
| 70 | + autospec=True, |
| 71 | + spec_set=True, |
67 | 72 | ) |
68 | 73 | def test_perf_metrics_export_initialization_with_trace_writer_enabled( |
69 | 74 | self, mock_writer_cls, trace_dir, expected_dir |
70 | 75 | ): |
71 | 76 | with export.PerfMetricsExport( |
72 | 77 | enable_trace_writer=True, trace_dir=trace_dir |
73 | 78 | ) as exporter: |
74 | | - mock_writer_cls.assert_called_once_with(expected_dir) |
| 79 | + mock_writer_cls.assert_called_once_with(expected_dir, role_to_devices=None) |
75 | 80 | # export_metrics shouldn't crash |
76 | 81 | exporter.export_metrics({}) |
77 | 82 |
|
78 | | - @mock.patch.object(export.trace_writer_lib, "NoopTraceWriter", autospec=True) |
| 83 | + @mock.patch.object( |
| 84 | + export.trace_writer_lib, "NoopTraceWriter", autospec=True, spec_set=True |
| 85 | + ) |
79 | 86 | def test_perf_metrics_export_initialization_with_trace_writer_disabled( |
80 | 87 | self, mock_noop_cls |
81 | 88 | ): |
@@ -109,6 +116,62 @@ def test_perf_metrics_export_shutdown_can_be_called_manually( |
109 | 116 | mock_executor_instance.shutdown.assert_called_once_with(wait=False) |
110 | 117 | self.assertIsNone(exporter._executor) |
111 | 118 |
|
| 119 | + @mock.patch.object( |
| 120 | + export.trace_writer_lib, |
| 121 | + "PerfettoTraceWriter", |
| 122 | + autospec=True, |
| 123 | + ) |
| 124 | + def test_from_cluster_config(self, mock_writer_cls): |
| 125 | + import numpy as np |
| 126 | + mock_mesh_1 = mock.create_autospec(jax.sharding.Mesh, instance=True) |
| 127 | + mock_mesh_1.devices = np.array([["tpu0", "tpu1"], ["tpu2", "tpu3"]]) |
| 128 | + |
| 129 | + mock_mesh_2 = mock.create_autospec(jax.sharding.Mesh, instance=True) |
| 130 | + mock_mesh_2.devices = np.array([["tpu4", "tpu5"], ["tpu6", "tpu7"]]) |
| 131 | + |
| 132 | + mock_cluster_config = mock.create_autospec( |
| 133 | + rl_cluster.ClusterConfig, instance=True |
| 134 | + ) |
| 135 | + mock_cluster_config.role_to_mesh = { |
| 136 | + rl_cluster.Role.ACTOR: mock_mesh_1, |
| 137 | + rl_cluster.Role.ROLLOUT: mock_mesh_2, |
| 138 | + } |
| 139 | + |
| 140 | + exporter = export.PerfMetricsExport.from_cluster_config( |
| 141 | + mock_cluster_config, |
| 142 | + enable_trace_writer=True, |
| 143 | + trace_dir="/test/dir", |
| 144 | + ) |
| 145 | + |
| 146 | + expected_role_to_devices = { |
| 147 | + "actor": ["tpu0", "tpu1", "tpu2", "tpu3"], |
| 148 | + "rollout": ["tpu4", "tpu5", "tpu6", "tpu7"], |
| 149 | + } |
| 150 | + mock_writer_cls.assert_called_once_with( |
| 151 | + "/test/dir", role_to_devices=expected_role_to_devices |
| 152 | + ) |
| 153 | + self.assertIs(exporter._writer, mock_writer_cls.return_value) |
| 154 | + |
| 155 | + @mock.patch.object( |
| 156 | + export.trace_writer_lib, |
| 157 | + "PerfettoTraceWriter", |
| 158 | + autospec=True, |
| 159 | + spec_set=True, |
| 160 | + ) |
| 161 | + def test_from_cluster_config_no_role_to_mesh(self, mock_writer_cls): |
| 162 | + mock_cluster_config = mock.create_autospec( |
| 163 | + rl_cluster.ClusterConfig, instance=True, spec_set=True |
| 164 | + ) |
| 165 | + del mock_cluster_config.role_to_mesh |
| 166 | + |
| 167 | + exporter = export.PerfMetricsExport.from_cluster_config( |
| 168 | + mock_cluster_config, |
| 169 | + enable_trace_writer=True, |
| 170 | + trace_dir="/test/dir", |
| 171 | + ) |
| 172 | + |
| 173 | + mock_writer_cls.assert_called_once_with("/test/dir", role_to_devices={}) |
| 174 | + self.assertIs(exporter._writer, mock_writer_cls.return_value) |
112 | 175 |
|
113 | 176 | if __name__ == "__main__": |
114 | 177 | absltest.main() |
0 commit comments