@@ -112,47 +112,73 @@ def f(x):
112112 fdo_profile = pgle_profiler .consume_fdo_profile ()
113113 self .assertEqual (fdo_profile .count (b'custom' ), its )
114114
115+ def get_fdo_profiles (self , dump_dir ):
116+ jit_f_fdo_profiles = [
117+ x
118+ for x in os .listdir (dump_dir )
119+ if 'jit_f' in x and x .endswith ('.fdo_profile' )
120+ ]
121+ return jit_f_fdo_profiles
122+
115123 def testAutoPgle (self ):
116124 mesh = jtu .create_mesh ((2 ,), ('x' ,))
117125
118- @partial (
119- jax .jit ,
120- in_shardings = NamedSharding (mesh , PartitionSpec ('x' )),
121- out_shardings = NamedSharding (mesh , PartitionSpec ('x' )),
122- compiler_options = {
123- 'xla_gpu_enable_latency_hiding_scheduler' : 'True' ,
124- # TODO(patrios): Remove this flag once b/376647494 is fixed.
125- 'xla_gpu_graph_min_graph_size' : '100000' ,
126- },
127- )
128- def f (x ):
129- return x * 2
130-
131- shape = (16 , 16 )
132- x = jnp .arange (math .prod (shape )).reshape (shape ).astype (np .float32 )
133- expected = x * 2
126+ with tempfile .TemporaryDirectory () as dump_dir :
127+ @partial (
128+ jax .jit ,
129+ in_shardings = NamedSharding (mesh , PartitionSpec ('x' )),
130+ out_shardings = NamedSharding (mesh , PartitionSpec ('x' )),
131+ compiler_options = {
132+ 'xla_gpu_enable_latency_hiding_scheduler' : 'True' ,
133+ # TODO(patrios): Remove this flag once b/376647494 is fixed.
134+ 'xla_gpu_graph_min_graph_size' : '100000' ,
135+ 'xla_dump_to' : dump_dir ,
136+ 'xla_gpu_experimental_dump_fdo_profiles' : 'True'
137+ },
138+ )
139+ def f (x ):
140+ return x * 2
134141
135- with config .pgle_profiling_runs (2 ), config .enable_pgle (True ):
136- # Run 1: Module should be compiled without FDO. Two modules are expected
137- # One is the funtion f, the other one is multi slice module
138- with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
139- self .assertArraysEqual (f (x ), expected )
140- self .assertEqual (cache_miss_count [0 ], 2 )
142+ shape = (16 , 16 )
143+ x = jnp .arange (math .prod (shape )).reshape (shape ).astype (np .float32 )
144+ expected = x * 2
141145
142- # Run 2: Second PGLE run should not recompile the module
143- with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
144- self .assertArraysEqual (f (x ), expected )
145- self .assertLess (cache_miss_count [0 ], 2 )
146+ with config .pgle_profiling_runs (2 ), config .enable_pgle (True ):
147+ # Run 1: Module should be compiled without FDO. Two modules are expected
148+ # One is the funtion f, the other one is multi slice module
149+ with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
150+ self .assertArraysEqual (f (x ), expected )
151+ self .assertEqual (cache_miss_count [0 ], 2 )
146152
147- # Run 3: The module should be recompiled with FDO profiles
148- with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
149- self .assertArraysEqual (f (x ), expected )
150- self .assertEqual (cache_miss_count [0 ], 2 )
153+ # Run 2: Second PGLE run. Profile should be empty.
154+ with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
155+ self .assertArraysEqual (f (x ), expected )
156+ self .assertEqual (cache_miss_count [0 ], 2 )
157+ fdo_profiles_before_pgle = self .get_fdo_profiles (dump_dir )
158+ # One for before and one for after optimization.
159+ self .assertLen (fdo_profiles_before_pgle , 2 )
160+ # The FDO profile file should be empty.
161+ self .assertEqual (
162+ os .path .getsize (os .path .join (dump_dir , fdo_profiles_before_pgle [0 ])), 0 )
163+
164+ # Run 3: The module should be recompiled with FDO profiles
165+ with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
166+ self .assertArraysEqual (f (x ), expected )
167+ self .assertEqual (cache_miss_count [0 ], 2 )
168+ fdo_profiles_after_pgle = self .get_fdo_profiles (dump_dir )
169+ # One for before and one for after optimization.
170+ self .assertLen (fdo_profiles_after_pgle , 4 )
171+
172+ for fdo_profile in fdo_profiles_after_pgle :
173+ if fdo_profile not in fdo_profiles_before_pgle :
174+ self .assertGreater (
175+ os .path .getsize (os .path .join (dump_dir , fdo_profile )), 0
176+ )
151177
152- # Run 4: Fast-path should be used after PGLE is done
153- with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
154- self .assertArraysEqual (f (x ), expected )
155- self .assertLess (cache_miss_count [0 ], 2 )
178+ # Run 4: Fast-path should be used after PGLE is done
179+ with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
180+ self .assertArraysEqual (f (x ), expected )
181+ self .assertLess (cache_miss_count [0 ], 2 )
156182
157183 def testAutoPgleWithAot (self ):
158184 @jax .jit
@@ -225,38 +251,27 @@ def f(x):
225251 # Run 2: Compilation should not be called
226252 with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
227253 f (x )
228- self .assertLess (cache_miss_count [0 ], 2 )
254+ self .assertGreater (cache_miss_count [0 ], 0 )
229255
230- module_before_pgle = os .listdir (dump_dir )
231- self .assertNotEmpty (module_before_pgle )
256+ fdo_profiles_before_pgle = self .get_fdo_profiles (dump_dir )
232257 # Run 3: Module should be compiled with FDO and stored to persistent cache
233258 with jtu .count_cached_compilation_cache_miss () as cache_miss_count :
234259 f (x )
235260 self .assertGreater (cache_miss_count [0 ], 0 )
236261
237262 # Check if FDO profile file of the biggest module is not empty
238- module_after_pgle = [
263+ fdo_profiles_after_pgle = [
239264 x
240- for x in os . listdir (dump_dir )
241- if x not in module_before_pgle
265+ for x in self . get_fdo_profiles (dump_dir )
266+ if x not in fdo_profiles_before_pgle
242267 ]
243- self .assertNotEmpty (module_after_pgle )
244- biggest_module_after_pgle = max (
245- module_after_pgle ,
246- key = lambda x : os .path .getsize (
247- os .path .join (dump_dir , x )
248- ),
249- )
250- base_module_name = '.' .join (biggest_module_after_pgle .split ('.' )[0 :1 ])
268+ self .assertNotEmpty (fdo_profiles_after_pgle )
251269
252270 # Check if FDO profile file in dump directory is not empty
253- for module in module_after_pgle :
254- if module .startswith (base_module_name ) and module .endswith (
255- '.fdo_profile'
256- ):
257- self .assertGreater (
258- os .path .getsize (os .path .join (dump_dir , module )), 0
259- )
271+ for fdo_profile in fdo_profiles_after_pgle :
272+ self .assertGreater (
273+ os .path .getsize (os .path .join (dump_dir , fdo_profile )), 0
274+ )
260275
261276 for pgle_profiler in pjit ._pgle_profiler_dict .values ():
262277 self .assertTrue (pgle_profiler .is_enabled ())
@@ -293,42 +308,42 @@ def check_if_cache_hit(event):
293308
294309 self .assertGreater (cache_hit , 0 )
295310
296- def testPassingFDOProfile (self ):
297- mesh = jtu .create_mesh ((2 ,), ('x' ,))
311+ def testPassingFDOProfile (self ):
312+ mesh = jtu .create_mesh ((2 ,), ('x' ,))
298313
299- @partial (
300- jax .jit ,
301- in_shardings = NamedSharding (mesh , PartitionSpec ('x' )),
302- out_shardings = NamedSharding (mesh , PartitionSpec ('x' )),
303- compiler_options = {'xla_gpu_enable_latency_hiding_scheduler' : 'True' },
304- )
305- def f (x , y ):
306- return x @ y
314+ @partial (
315+ jax .jit ,
316+ in_shardings = NamedSharding (mesh , PartitionSpec ('x' )),
317+ out_shardings = NamedSharding (mesh , PartitionSpec ('x' )),
318+ compiler_options = {'xla_gpu_enable_latency_hiding_scheduler' : 'True' },
319+ )
320+ def f (x , y ):
321+ return x @ y
307322
308- shape = (16 , 16 )
309- x = jnp .arange (math .prod (shape )).reshape (shape ).astype (np .float32 )
310- y = x + 1
323+ shape = (16 , 16 )
324+ x = jnp .arange (math .prod (shape )).reshape (shape ).astype (np .float32 )
325+ y = x + 1
311326
312- with config .pgle_profiling_runs (0 ):
313- f_lowered = f .lower (x , y )
314- compiled = f_lowered .compile ()
327+ with config .pgle_profiling_runs (0 ):
328+ f_lowered = f .lower (x , y )
329+ compiled = f_lowered .compile ()
315330
316- with tempfile .TemporaryDirectory () as cache_dir :
317- jax .profiler .start_trace (cache_dir )
318- compiled (x , y )
319- jax .profiler .stop_trace ()
320- directories = glob .glob (os .path .join (cache_dir , 'plugins/profile/**/' ))
321- directories = [d for d in directories if os .path .isdir (d )]
322- rundir = directories [- 1 ]
323- logging .info ('rundir: %s' , rundir )
324- fdo_profile = exp_profiler .get_profiled_instructions_proto (rundir )
325-
326- if jtu .test_device_matches (['gpu' ]) and jtu .is_device_cuda ():
327- self .assertIn (b'custom' , fdo_profile )
328-
329- logging .info ('fdo_profile: %s' , fdo_profile )
330- # Test pass fdo_profile as compiler_options API works.
331- f_lowered .compile (compiler_options = {'fdo_profile' : fdo_profile })
331+ with tempfile .TemporaryDirectory () as cache_dir :
332+ jax .profiler .start_trace (cache_dir )
333+ compiled (x , y )
334+ jax .profiler .stop_trace ()
335+ directories = glob .glob (os .path .join (cache_dir , 'plugins/profile/**/' ))
336+ directories = [d for d in directories if os .path .isdir (d )]
337+ rundir = directories [- 1 ]
338+ logging .info ('rundir: %s' , rundir )
339+ fdo_profile = exp_profiler .get_profiled_instructions_proto (rundir )
340+
341+ if jtu .test_device_matches (['gpu' ]) and jtu .is_device_cuda ():
342+ self .assertIn (b'custom' , fdo_profile )
343+
344+ logging .info ('fdo_profile: %s' , fdo_profile )
345+ # Test pass fdo_profile as compiler_options API works.
346+ f_lowered .compile (compiler_options = {'fdo_profile' : fdo_profile })
332347
333348
334349if __name__ == '__main__' :
0 commit comments