From 7fb622726e20c38ae6be43c256f85a8889212774 Mon Sep 17 00:00:00 2001 From: cj401-amd Date: Wed, 21 May 2025 16:42:20 +0000 Subject: [PATCH 1/4] add rocprofiler_configure as a global --- jaxlib/tools/gpu_version_script.lds | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/tools/gpu_version_script.lds b/jaxlib/tools/gpu_version_script.lds index 8e46b2c590b2..c40297736cda 100644 --- a/jaxlib/tools/gpu_version_script.lds +++ b/jaxlib/tools/gpu_version_script.lds @@ -4,6 +4,7 @@ VERS_1.0 { GetPjrtApi; MosaicGpuCompile; MosaicGpuUnload; + rocprofiler_configure; }; local: From 4a90d34294769e6868313fada126536b3a31e5a2 Mon Sep 17 00:00:00 2001 From: cj401-amd Date: Thu, 7 Aug 2025 08:59:05 -0500 Subject: [PATCH 2/4] uncomment test causes OOMs --- tests/profiler_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 215e363e446d..c6c48d9e14a0 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -245,7 +245,7 @@ def _check_xspace_pb_exist(self, logdir): self.assertEqual(1, len(glob.glob(path)), 'Expected one path match: ' + path) - @unittest.skip("Test causes OOMs") + # @unittest.skip("Test causes OOMs") @unittest.skipIf(not (portpicker and profiler_client and tf_profiler), "Test requires tensorflow.profiler and portpicker") def testSingleWorkerSamplingMode(self, delay_ms=None): From ec929e0161b299de65812840b82293be621c769a Mon Sep 17 00:00:00 2001 From: cj401-amd Date: Thu, 7 Aug 2025 12:02:39 -0500 Subject: [PATCH 3/4] remove @unittest.skip('Test causes OOMs') as rocprofiler-sdk has no such issues anymore --- tests/profiler_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index c6c48d9e14a0..65bab7acf77a 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -244,8 +244,7 @@ def _check_xspace_pb_exist(self, logdir): path = os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb') self.assertEqual(1, len(glob.glob(path)), 'Expected one path match: ' + path) - - # @unittest.skip("Test causes OOMs") + @unittest.skipIf(not (portpicker and profiler_client and tf_profiler), "Test requires tensorflow.profiler and portpicker") def testSingleWorkerSamplingMode(self, delay_ms=None): From 0c17ba1dfb26b53b866a54f7178d542e77408e5c Mon Sep 17 00:00:00 2001 From: cj401-amd Date: Thu, 7 Aug 2025 15:21:20 -0500 Subject: [PATCH 4/4] give more appropriate warning message when using PGLE with rocprofiler-sdk --- jax/_src/profiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 912c90182977..f1016d262056 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -444,8 +444,7 @@ def trace(cls, runner: PGLEProfiler | None): if runner.fdo_profiles[-1] == b'': warnings.warn( "PGLE collected an empty trace, may be due to contention with " - "another tool that subscribes to CUPTI, such as Nsight Systems - check " - "for CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED from XLA. " + "another tool that subscribes to rocprofiler-sdk - check " "Consider populating a persistent compilation cache with PGLE enabled, " "and then profiling a second run that has the " "JAX_COMPILATION_CACHE_EXPECT_PGLE option enabled.",