@@ -484,45 +484,49 @@ def f(x: float):
484484 def test_modify_ir (self , pass_name , target , replacement ):
485485 """Turn a square function in IRs into a cubic one."""
486486
487- @qjit (keep_intermediate = True )
488487 def f (x ):
489488 """Square function."""
490489 return x ** 2
491490
491+ f .__name__ = f .__name__ + pass_name
492+
493+ jit_f = qjit (f , keep_intermediate = True )
492494 data = 2.0
493- old_result = f (data )
494- old_ir = get_compilation_stage (f , pass_name )
495- old_workspace = str (f .workspace )
495+ old_result = jit_f (data )
496+ old_ir = get_compilation_stage (jit_f , pass_name )
497+ old_workspace = str (jit_f .workspace )
496498
497499 new_ir = old_ir .replace (target , replacement )
498- replace_ir (f , pass_name , new_ir )
499- new_result = f (data )
500+ replace_ir (jit_f , pass_name , new_ir )
501+ new_result = jit_f (data )
500502
501503 shutil .rmtree (old_workspace , ignore_errors = True )
502- shutil .rmtree (str (f .workspace ), ignore_errors = True )
504+ shutil .rmtree (str (jit_f .workspace ), ignore_errors = True )
503505 assert old_result * data == new_result
504506
505507 @pytest .mark .parametrize ("pass_name" , ["HLOLoweringPass" , "O2Opt" , "Enzyme" ])
506508 def test_modify_ir_file_generation (self , pass_name ):
507509 """Test if recompilation rerun the same pass."""
508510
509- @qjit
510511 def f (x : float ):
511512 """Square function."""
512513 return x ** 2
513514
514- grad_f = qjit (value_and_grad (f ), keep_intermediate = True )
515- grad_f (3.0 )
516- ir = get_compilation_stage (grad_f , pass_name )
517- old_workspace = str (grad_f .workspace )
515+ f .__name__ = f .__name__ + pass_name
516+
517+ jit_f = qjit (f )
518+ jit_grad_f = qjit (value_and_grad (jit_f ), keep_intermediate = True )
519+ jit_grad_f (3.0 )
520+ ir = get_compilation_stage (jit_grad_f , pass_name )
521+ old_workspace = str (jit_grad_f .workspace )
518522
519- replace_ir (grad_f , pass_name , ir )
520- grad_f (3.0 )
521- file_list = os .listdir (str (grad_f .workspace ))
523+ replace_ir (jit_grad_f , pass_name , ir )
524+ jit_grad_f (3.0 )
525+ file_list = os .listdir (str (jit_grad_f .workspace ))
522526 res = [i for i in file_list if pass_name in i ]
523527
524528 shutil .rmtree (old_workspace , ignore_errors = True )
525- shutil .rmtree (str (grad_f .workspace ), ignore_errors = True )
529+ shutil .rmtree (str (jit_grad_f .workspace ), ignore_errors = True )
526530 assert len (res ) == 0
527531
528532 def test_get_compilation_stage_without_keep_intermediate (self ):
0 commit comments