Skip to content

Commit 92e18e6

Browse files
[AutoPGLE] Fix pgle test after removing pjit cache.
PiperOrigin-RevId: 700359385
1 parent dc11d40 commit 92e18e6

File tree

1 file changed

+102
-87
lines changed

1 file changed

+102
-87
lines changed

tests/pgle_test.py

Lines changed: 102 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -112,47 +112,73 @@ def f(x):
112112
fdo_profile = pgle_profiler.consume_fdo_profile()
113113
self.assertEqual(fdo_profile.count(b'custom'), its)
114114

115+
def get_fdo_profiles(self, dump_dir):
116+
jit_f_fdo_profiles = [
117+
x
118+
for x in os.listdir(dump_dir)
119+
if 'jit_f' in x and x.endswith('.fdo_profile')
120+
]
121+
return jit_f_fdo_profiles
122+
115123
def testAutoPgle(self):
116124
mesh = jtu.create_mesh((2,), ('x',))
117125

118-
@partial(
119-
jax.jit,
120-
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
121-
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
122-
compiler_options={
123-
'xla_gpu_enable_latency_hiding_scheduler': 'True',
124-
# TODO(patrios): Remove this flag once b/376647494 is fixed.
125-
'xla_gpu_graph_min_graph_size': '100000',
126-
},
127-
)
128-
def f(x):
129-
return x * 2
130-
131-
shape = (16, 16)
132-
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
133-
expected = x * 2
126+
with tempfile.TemporaryDirectory() as dump_dir:
127+
@partial(
128+
jax.jit,
129+
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
130+
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
131+
compiler_options={
132+
'xla_gpu_enable_latency_hiding_scheduler': 'True',
133+
# TODO(patrios): Remove this flag once b/376647494 is fixed.
134+
'xla_gpu_graph_min_graph_size': '100000',
135+
'xla_dump_to': dump_dir,
136+
'xla_gpu_experimental_dump_fdo_profiles': 'True'
137+
},
138+
)
139+
def f(x):
140+
return x * 2
134141

135-
with config.pgle_profiling_runs(2), config.enable_pgle(True):
136-
# Run 1: Module should be compiled without FDO. Two modules are expected
137-
# One is the funtion f, the other one is multi slice module
138-
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
139-
self.assertArraysEqual(f(x), expected)
140-
self.assertEqual(cache_miss_count[0], 2)
142+
shape = (16, 16)
143+
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
144+
expected = x * 2
141145

142-
# Run 2: Second PGLE run should not recompile the module
143-
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
144-
self.assertArraysEqual(f(x), expected)
145-
self.assertLess(cache_miss_count[0], 2)
146+
with config.pgle_profiling_runs(2), config.enable_pgle(True):
147+
# Run 1: Module should be compiled without FDO. Two modules are expected
148+
# One is the funtion f, the other one is multi slice module
149+
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
150+
self.assertArraysEqual(f(x), expected)
151+
self.assertEqual(cache_miss_count[0], 2)
146152

147-
# Run 3: The module should be recompiled with FDO profiles
148-
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
149-
self.assertArraysEqual(f(x), expected)
150-
self.assertEqual(cache_miss_count[0], 2)
153+
# Run 2: Second PGLE run. Profile should be empty.
154+
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
155+
self.assertArraysEqual(f(x), expected)
156+
self.assertEqual(cache_miss_count[0], 2)
157+
fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
158+
# One for before and one for after optimization.
159+
self.assertLen(fdo_profiles_before_pgle, 2)
160+
# The FDO profile file should be empty.
161+
self.assertEqual(
162+
os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0)
163+
164+
# Run 3: The module should be recompiled with FDO profiles
165+
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
166+
self.assertArraysEqual(f(x), expected)
167+
self.assertEqual(cache_miss_count[0], 2)
168+
fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir)
169+
# One for before and one for after optimization.
170+
self.assertLen(fdo_profiles_after_pgle, 4)
171+
172+
for fdo_profile in fdo_profiles_after_pgle:
173+
if fdo_profile not in fdo_profiles_before_pgle:
174+
self.assertGreater(
175+
os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0
176+
)
151177

152-
# Run 4: Fast-path should be used after PGLE is done
153-
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
154-
self.assertArraysEqual(f(x), expected)
155-
self.assertLess(cache_miss_count[0], 2)
178+
# Run 4: Fast-path should be used after PGLE is done
179+
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
180+
self.assertArraysEqual(f(x), expected)
181+
self.assertLess(cache_miss_count[0], 2)
156182

157183
def testAutoPgleWithAot(self):
158184
@jax.jit
@@ -225,38 +251,27 @@ def f(x):
225251
# Run 2: Compilation should not be called
226252
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
227253
f(x)
228-
self.assertLess(cache_miss_count[0], 2)
254+
self.assertGreater(cache_miss_count[0], 0)
229255

230-
module_before_pgle = os.listdir(dump_dir)
231-
self.assertNotEmpty(module_before_pgle)
256+
fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
232257
# Run 3: Module should be compiled with FDO and stored to persistent cache
233258
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
234259
f(x)
235260
self.assertGreater(cache_miss_count[0], 0)
236261

237262
# Check if FDO profile file of the biggest module is not empty
238-
module_after_pgle = [
263+
fdo_profiles_after_pgle = [
239264
x
240-
for x in os.listdir(dump_dir)
241-
if x not in module_before_pgle
265+
for x in self.get_fdo_profiles(dump_dir)
266+
if x not in fdo_profiles_before_pgle
242267
]
243-
self.assertNotEmpty(module_after_pgle)
244-
biggest_module_after_pgle = max(
245-
module_after_pgle,
246-
key=lambda x: os.path.getsize(
247-
os.path.join(dump_dir, x)
248-
),
249-
)
250-
base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1])
268+
self.assertNotEmpty(fdo_profiles_after_pgle)
251269

252270
# Check if FDO profile file in dump directory is not empty
253-
for module in module_after_pgle:
254-
if module.startswith(base_module_name) and module.endswith(
255-
'.fdo_profile'
256-
):
257-
self.assertGreater(
258-
os.path.getsize(os.path.join(dump_dir, module)), 0
259-
)
271+
for fdo_profile in fdo_profiles_after_pgle:
272+
self.assertGreater(
273+
os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0
274+
)
260275

261276
for pgle_profiler in pjit._pgle_profiler_dict.values():
262277
self.assertTrue(pgle_profiler.is_enabled())
@@ -293,42 +308,42 @@ def check_if_cache_hit(event):
293308

294309
self.assertGreater(cache_hit, 0)
295310

296-
def testPassingFDOProfile(self):
297-
mesh = jtu.create_mesh((2,), ('x',))
311+
def testPassingFDOProfile(self):
312+
mesh = jtu.create_mesh((2,), ('x',))
298313

299-
@partial(
300-
jax.jit,
301-
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
302-
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
303-
compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'},
304-
)
305-
def f(x, y):
306-
return x @ y
314+
@partial(
315+
jax.jit,
316+
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
317+
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
318+
compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'},
319+
)
320+
def f(x, y):
321+
return x @ y
307322

308-
shape = (16, 16)
309-
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
310-
y = x + 1
323+
shape = (16, 16)
324+
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
325+
y = x + 1
311326

312-
with config.pgle_profiling_runs(0):
313-
f_lowered = f.lower(x, y)
314-
compiled = f_lowered.compile()
327+
with config.pgle_profiling_runs(0):
328+
f_lowered = f.lower(x, y)
329+
compiled = f_lowered.compile()
315330

316-
with tempfile.TemporaryDirectory() as cache_dir:
317-
jax.profiler.start_trace(cache_dir)
318-
compiled(x, y)
319-
jax.profiler.stop_trace()
320-
directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/'))
321-
directories = [d for d in directories if os.path.isdir(d)]
322-
rundir = directories[-1]
323-
logging.info('rundir: %s', rundir)
324-
fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)
325-
326-
if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
327-
self.assertIn(b'custom', fdo_profile)
328-
329-
logging.info('fdo_profile: %s', fdo_profile)
330-
# Test pass fdo_profile as compiler_options API works.
331-
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})
331+
with tempfile.TemporaryDirectory() as cache_dir:
332+
jax.profiler.start_trace(cache_dir)
333+
compiled(x, y)
334+
jax.profiler.stop_trace()
335+
directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/'))
336+
directories = [d for d in directories if os.path.isdir(d)]
337+
rundir = directories[-1]
338+
logging.info('rundir: %s', rundir)
339+
fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)
340+
341+
if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
342+
self.assertIn(b'custom', fdo_profile)
343+
344+
logging.info('fdo_profile: %s', fdo_profile)
345+
# Test pass fdo_profile as compiler_options API works.
346+
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})
332347

333348

334349
if __name__ == '__main__':

0 commit comments

Comments
 (0)