@@ -395,18 +395,18 @@ def test_linking_patch(listdir_mock, platform):
395395 ]
396396
397397
398- def test_cache_race_condition ():
399- with tempfile .TemporaryDirectory () as dir_name :
398+ @config .change_flags (on_opt_error = "raise" , on_shape_error = "raise" )
399+ def _f_build_cache_race_condition (factor ):
400+ # Some of the caching issues arise during constant folding within the
401+ # optimization passes, so we need these config changes to prevent the
402+ # exceptions from being caught
403+ a = pt .vector ()
404+ f = pytensor .function ([a ], factor * a )
405+ return f (np .array ([1 ], dtype = config .floatX ))
400406
401- @config .change_flags (on_opt_error = "raise" , on_shape_error = "raise" )
402- def f_build (factor ):
403- # Some of the caching issues arise during constant folding within the
404- # optimization passes, so we need these config changes to prevent the
405- # exceptions from being caught
406- a = pt .vector ()
407- f = pytensor .function ([a ], factor * a )
408- return f (np .array ([1 ], dtype = config .floatX ))
409407
408+ def test_cache_race_condition ():
409+ with tempfile .TemporaryDirectory () as dir_name :
410410 ctx = multiprocessing .get_context ()
411411 compiledir_prop = pytensor .config ._config_var_dict ["compiledir" ]
412412
@@ -425,7 +425,7 @@ def f_build(factor):
425425 # A random, constant input to prevent caching between runs
426426 factor = rng .random ()
427427 procs = [
428- ctx .Process (target = f_build , args = (factor ,))
428+ ctx .Process (target = _f_build_cache_race_condition , args = (factor ,))
429429 for i in range (num_procs )
430430 ]
431431 for proc in procs :
0 commit comments