diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 1fc705f..60b6303 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -14,12 +14,12 @@ """Profiling utilites.""" import dataclasses +import json import logging import os -import pathlib -import tempfile import threading import time +from typing import Any import urllib.parse import fastapi @@ -45,6 +45,7 @@ def reset(self): self.executable = None +_first_profile_start = True _profile_state = _ProfileState() _original_start_trace = jax.profiler.start_trace _original_stop_trace = jax.profiler.stop_trace @@ -56,11 +57,44 @@ def toy_computation(): x.block_until_ready() -def start_trace(gcs_bucket: str): - """Starts a profiler trace.""" +def _create_profile_request( + log_dir: os.PathLike[str] | str, +) -> dict[str, Any]: + """Creates a profile request dictionary from the given options.""" + profile_request = {} + profile_request["traceLocation"] = str(log_dir) + + return profile_request + + +def _start_trace_from_profile_request( + profile_request: dict[str, Any], + *, + create_perfetto_link: bool = False, + create_perfetto_trace: bool = False, +) -> None: + """Starts a profiler trace from a profile request dictionary. + + Args: + profile_request: A dictionary containing the profile request options. + create_perfetto_link: A boolean which, if true, creates and prints link to + the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will + block until the link is opened and Perfetto loads the trace. This feature + is experimental for Pathways on Cloud and may not be fully supported. + create_perfetto_trace: A boolean which, if true, additionally dumps a + ``perfetto_trace.json.gz`` file that is compatible for upload with the + Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be + generated if ``create_perfetto_link`` is true. This could be useful if you + want to generate a Perfetto-compatible trace without blocking the process. + This feature is experimental for Pathways on Cloud and may not be fully + supported. + """ + log_dir = profile_request["traceLocation"] + with _profile_state.lock: - if start_trace._first_profile_start: # pylint: disable=protected-access, attribute-error - start_trace._first_profile_start = False # pylint: disable=protected-access + global _first_profile_start + if _first_profile_start: + _first_profile_start = False toy_computation() if _profile_state.executable is not None: @@ -68,18 +102,70 @@ def start_trace(gcs_bucket: str): "start_trace called while a trace is already being taken!" ) _profile_state.executable = plugin_executable.PluginExecutable( - f"{{profileRequest: {{traceLocation: '{gcs_bucket}'}}}}" + json.dumps({"profileRequest": profile_request}) ) try: - _profile_state.executable.call()[1].result() - except: + _, result_future = _profile_state.executable.call() + result_future.result() + except Exception as e: # pylint: disable=broad-except + _logger.exception("Failed to start trace") _profile_state.reset() raise - _original_start_trace(gcs_bucket) + _original_start_trace( + log_dir=log_dir, + create_perfetto_link=create_perfetto_link, + create_perfetto_trace=create_perfetto_trace, + ) + + +def start_trace( + log_dir: os.PathLike[str] | str, + *, + create_perfetto_link: bool = False, + create_perfetto_trace: bool = False, +) -> None: + """Starts a profiler trace. + The trace will capture CPU and TPU activity, including Python + functions and JAX on-device operations. Use :func:`stop_trace` to end the + trace and save the results to ``log_dir``. -start_trace._first_profile_start = True # pylint: disable=protected-access + The resulting trace can be viewed with TensorBoard. Note that TensorBoard + doesn't need to be running when collecting the trace. + + Only one trace may be collected at a time. A RuntimeError will be raised if + :func:`start_trace` is called while another trace is running. + + Args: + log_dir: The GCS directory to save the profiler trace to (usually the + TensorBoard log directory), e.g., "gs://my_bucket/profiles". + create_perfetto_link: A boolean which, if true, creates and prints link to + the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will + block until the link is opened and Perfetto loads the trace. This feature + is experimental for Pathways on Cloud and may not be fully supported. + create_perfetto_trace: A boolean which, if true, additionally dumps a + ``perfetto_trace.json.gz`` file that is compatible for upload with the + Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be + generated if ``create_perfetto_link`` is true. This could be useful if you + want to generate a Perfetto-compatible trace without blocking the process. + This feature is experimental for Pathways on Cloud and may not be fully + supported. + """ + if not str(log_dir).startswith("gs://"): + raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}") + + if create_perfetto_link or create_perfetto_trace: + _logger.warning( + "create_perfetto_link and create_perfetto_trace are experimental " + "features for Pathways on Cloud and may not be fully supported." + ) + + _start_trace_from_profile_request( + _create_profile_request(log_dir), + create_perfetto_link=create_perfetto_link, + create_perfetto_trace=create_perfetto_trace, + ) def stop_trace(): @@ -88,11 +174,10 @@ def stop_trace(): if _profile_state.executable is None: raise ValueError("stop_trace called before a trace is being taken!") try: - _profile_state.executable.call()[1].result() - except: + _, result_future = _profile_state.executable.call() + result_future.result() + finally: _profile_state.reset() - raise - _profile_state.reset() _original_stop_trace() @@ -151,7 +236,7 @@ def collect_profile( port: int, duration_ms: int, host: str, - log_dir: str, + log_dir: os.PathLike[str] | str, ) -> bool: """Collects a JAX profile and saves it to the specified directory. @@ -167,19 +252,19 @@ def collect_profile( Raises: ValueError: If the log_dir is not a GCS path. """ - if not log_dir.startswith("gs://"): - raise ValueError("log_dir must be a GCS path.") + if not str(log_dir).startswith("gs://"): + raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}") - json = { + request_json = { "duration_ms": duration_ms, "repository_path": log_dir, } address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling") try: - response = requests.post(address, json=json) + response = requests.post(address, json=request_json) response.raise_for_status() - except requests.exceptions.RequestException as e: - _logger.error("Failed to collect profiling data: %s", e) + except requests.exceptions.RequestException: + _logger.exception("Failed to collect profiling data") return False return True diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index e7f3bba..d5a60db 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging from unittest import mock +import jax from pathwaysutils import profiling import requests @@ -30,6 +32,27 @@ def setUp(self): self.mock_post = self.enter_context( mock.patch.object(requests, "post", autospec=True) ) + profiling._profile_state.reset() + profiling._first_profile_start = True + profiling._profiler_thread = None + self.mock_plugin_executable_cls = self.enter_context( + mock.patch.object( + profiling.plugin_executable, "PluginExecutable", autospec=True + ) + ) + self.mock_plugin_executable_cls.return_value.call.return_value = ( + mock.MagicMock(), + mock.MagicMock(), + ) + self.mock_toy_computation = self.enter_context( + mock.patch.object(profiling, "toy_computation", autospec=True) + ) + self.mock_original_start_trace = self.enter_context( + mock.patch.object(profiling, "_original_start_trace", autospec=True) + ) + self.mock_original_stop_trace = self.enter_context( + mock.patch.object(profiling, "_original_stop_trace", autospec=True) + ) @parameterized.parameters(8000, 1234) def test_collect_profile_port(self, port): @@ -130,9 +153,8 @@ def test_collect_profile_request_error(self, exception): ) self.assertLen(logs.output, 1) - self.assertIn( - f"Failed to collect profiling data: {exception}", logs.output[0] - ) + self.assertIn("Failed to collect profiling data", logs.output[0]) + self.assertIn(str(exception), logs.output[0]) self.assertFalse(result) self.mock_post.assert_called_once() @@ -152,6 +174,256 @@ def test_collect_profile_success(self): self.mock_post.assert_called_once() mock_response.raise_for_status.assert_called_once() + @parameterized.parameters( + "/logs/test_log_dir", + "relative_path/my_log_dir", + "cns://test_bucket/test_dir", + "not_a_gcs_path", + ) + def test_start_trace_log_dir_error(self, log_dir): + with self.assertRaisesRegex( + ValueError, "log_dir must be a GCS bucket path" + ): + profiling.start_trace(log_dir) + + def test_lock_released_on_success(self): + """Tests that the lock is released after successful start_trace and stop_trace.""" + profiling.start_trace("gs://test_bucket/test_dir") + self.assertFalse(profiling._profile_state.lock.locked()) + profiling.stop_trace() + self.assertFalse(profiling._profile_state.lock.locked()) + + def test_lock_released_on_start_failure(self): + """Tests that the lock is released if start_trace fails.""" + mock_result = ( + self.mock_plugin_executable_cls.return_value.call.return_value[1] + ) + mock_result.result.side_effect = RuntimeError("start failed") + with self.assertRaisesRegex(RuntimeError, "start failed"): + profiling.start_trace("gs://test_bucket/test_dir2") + self.assertFalse(profiling._profile_state.lock.locked()) + + def test_lock_released_on_stop_failure(self): + """Tests that the lock is released if stop_trace fails.""" + profiling.start_trace("gs://test_bucket/test_dir3") + self.assertFalse(profiling._profile_state.lock.locked()) + mock_result = ( + self.mock_plugin_executable_cls.return_value.call.return_value[1] + ) + mock_result.result.side_effect = RuntimeError("stop failed") + with self.assertRaisesRegex(RuntimeError, "stop failed"): + profiling.stop_trace() + self.assertFalse(profiling._profile_state.lock.locked()) + + def test_start_trace_success(self): + profiling.start_trace("gs://test_bucket/test_dir") + + self.mock_toy_computation.assert_called_once() + self.mock_plugin_executable_cls.assert_called_once_with( + json.dumps( + {"profileRequest": {"traceLocation": "gs://test_bucket/test_dir"}} + ) + ) + self.mock_plugin_executable_cls.return_value.call.assert_called_once() + self.mock_original_start_trace.assert_called_once_with( + log_dir="gs://test_bucket/test_dir", + create_perfetto_link=False, + create_perfetto_trace=False, + ) + self.assertIsNotNone(profiling._profile_state.executable) + + def test_start_trace_no_toy_computation_second_time(self): + profiling.start_trace("gs://test_bucket/test_dir") + profiling.stop_trace() + + self.mock_toy_computation.assert_called_once() + self.mock_original_start_trace.assert_called_once() + + # Reset mock and call again + self.mock_toy_computation.reset_mock() + self.mock_original_start_trace.reset_mock() + profiling.start_trace("gs://test_bucket/test_dir2") + + self.mock_toy_computation.assert_not_called() + self.mock_original_start_trace.assert_called_once() + + def test_start_trace_while_running_error(self): + profiling.start_trace("gs://test_bucket/test_dir") + with self.assertRaisesRegex(ValueError, "trace is already being taken"): + profiling.start_trace("gs://test_bucket/test_dir2") + + def test_stop_trace_success(self): + profiling.start_trace("gs://test_bucket/test_dir") + # call() is called once in start_trace, and once in stop_trace. + with self.subTest("call_in_start_trace"): + self.mock_plugin_executable_cls.return_value.call.assert_called_once() + + profiling.stop_trace() + + with self.subTest("call_count_after_stop_trace"): + self.assertEqual( + self.mock_plugin_executable_cls.return_value.call.call_count, 2 + ) + with self.subTest("original_stop_trace_called"): + self.mock_original_stop_trace.assert_called_once() + with self.subTest("executable_is_none"): + self.assertIsNone(profiling._profile_state.executable) + + def test_stop_trace_before_start_error(self): + with self.assertRaisesRegex( + ValueError, "stop_trace called before a trace is being taken!" + ): + profiling.stop_trace() + + def test_start_server_starts_thread(self): + mock_thread = self.enter_context( + mock.patch.object(profiling.threading, "Thread", autospec=True) + ) + profiling.start_server(9000) + mock_thread.assert_called_once_with(target=mock.ANY, args=(9000,)) + mock_thread.return_value.start.assert_called_once() + self.assertIsNotNone(profiling._profiler_thread) + + def test_start_server_twice_raises_error(self): + self.enter_context( + mock.patch.object(profiling.threading, "Thread", autospec=True) + ) + profiling.start_server(9000) + with self.assertRaisesRegex( + ValueError, "Only one profiler server can be active" + ): + profiling.start_server(9001) + + def test_stop_server_no_server_raises_error(self): + with self.assertRaisesRegex(ValueError, "No active profiler server"): + profiling.stop_server() + + def test_stop_server_does_nothing_if_server_exists(self): + self.enter_context( + mock.patch.object(profiling.threading, "Thread", autospec=True) + ) + profiling.start_server(9000) + profiling.stop_server() # Should not raise + + def test_monkey_patch_jax(self): + original_jax_start_trace = jax.profiler.start_trace + original_jax_stop_trace = jax.profiler.stop_trace + original_jax_start_server = jax.profiler.start_server + original_jax_stop_server = jax.profiler.stop_server + + profiling.monkey_patch_jax() + + self.assertNotEqual(jax.profiler.start_trace, original_jax_start_trace) + self.assertNotEqual(jax.profiler.stop_trace, original_jax_stop_trace) + self.assertNotEqual(jax.profiler.start_server, original_jax_start_server) + self.assertNotEqual(jax.profiler.stop_trace, original_jax_stop_trace) + + with mock.patch.object( + profiling, "start_trace", autospec=True + ) as mock_pw_start_trace: + jax.profiler.start_trace("gs://bucket/dir") + mock_pw_start_trace.assert_called_once_with("gs://bucket/dir") + + with mock.patch.object( + profiling, "stop_trace", autospec=True + ) as mock_pw_stop_trace: + jax.profiler.stop_trace() + mock_pw_stop_trace.assert_called_once() + + with mock.patch.object( + profiling, "start_server", autospec=True + ) as mock_pw_start_server: + jax.profiler.start_server(1234) + mock_pw_start_server.assert_called_once_with(1234) + + with mock.patch.object( + profiling, "stop_server", autospec=True + ) as mock_pw_stop_server: + jax.profiler.stop_server() + mock_pw_stop_server.assert_called_once() + + # Restore original jax functions + jax.profiler.start_trace = original_jax_start_trace + jax.profiler.stop_trace = original_jax_stop_trace + jax.profiler.start_server = original_jax_start_server + jax.profiler.stop_server = original_jax_stop_server + + def test_create_profile_request_no_options(self): + request = profiling._create_profile_request("gs://bucket/dir") + self.assertEqual(request, {"traceLocation": "gs://bucket/dir"}) + + @parameterized.parameters( + ({"traceLocation": "gs://test_bucket/test_dir"},), + ({ + "traceLocation": "gs://test_bucket/test_dir", + "blockUntilStart": True, + "maxDurationSecs": 10.0, + "devices": {"deviceIds": [1, 2]}, + "includeResourceManagers": True, + "maxNumHosts": 5, + "xprofTraceOptions": { + "blockUntilStart": True, + "traceDirectory": "gs://test_bucket/test_dir", + }, + },), + ({ + "traceLocation": "gs://bucket/dir", + "xprofTraceOptions": { + "hostTraceLevel": 0, + "traceOptions": { + "traceMode": "TRACE_COMPUTE", + "numSparseCoresToTrace": 1, + "numSparseCoreTilesToTrace": 2, + "numChipsToProfilePerTask": 3, + "powerTraceLevel": 4, + "enableFwThrottleEvent": True, + "enableFwPowerLevelEvent": True, + "enableFwThermalEvent": True, + }, + "traceDirectory": "gs://bucket/dir", + }, + },), + ) + def test_start_trace_from_request(self, profile_request): + profiling._start_trace_from_profile_request(profile_request) + + self.mock_toy_computation.assert_called_once() + self.mock_plugin_executable_cls.assert_called_once_with( + json.dumps({"profileRequest": profile_request}) + ) + self.mock_plugin_executable_cls.return_value.call.assert_called_once() + self.mock_original_start_trace.assert_called_once_with( + log_dir=profile_request["traceLocation"], + create_perfetto_link=False, + create_perfetto_trace=False, + ) + self.assertIsNotNone(profiling._profile_state.executable) + + @parameterized.product( + create_perfetto_link=[True, False], + create_perfetto_trace=[True, False], + ) + def test_start_trace_with_perfetto( + self, create_perfetto_link, create_perfetto_trace + ): + profile_request = {"traceLocation": "gs://test_bucket/test_dir"} + profiling._start_trace_from_profile_request( + profile_request, + create_perfetto_link=create_perfetto_link, + create_perfetto_trace=create_perfetto_trace, + ) + + self.mock_toy_computation.assert_called_once() + self.mock_plugin_executable_cls.assert_called_once_with( + json.dumps({"profileRequest": profile_request}) + ) + self.mock_plugin_executable_cls.return_value.call.assert_called_once() + self.mock_original_start_trace.assert_called_once_with( + log_dir="gs://test_bucket/test_dir", + create_perfetto_link=create_perfetto_link, + create_perfetto_trace=create_perfetto_trace, + ) + self.assertIsNotNone(profiling._profile_state.executable) if __name__ == "__main__": absltest.main()