Skip to content

Commit 62166c5

Browse files
Tzung-Han Juangrmoyard
andauthored
Fix enzyme skipping if input ir is optimized by O2Opt (#1024)
**Context:** Run enzyme pass if checkpoint stage is set to `O2Opt`. Also rename the functions in the related tests to avoid race condition. --------- Co-authored-by: Romain Moyard <[email protected]>
1 parent 0422dca commit 62166c5

File tree

2 files changed

+23
-17
lines changed

2 files changed

+23
-17
lines changed

frontend/test/pytest/test_debug.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -484,45 +484,49 @@ def f(x: float):
484484
def test_modify_ir(self, pass_name, target, replacement):
485485
"""Turn a square function in IRs into a cubic one."""
486486

487-
@qjit(keep_intermediate=True)
488487
def f(x):
489488
"""Square function."""
490489
return x**2
491490

491+
f.__name__ = f.__name__ + pass_name
492+
493+
jit_f = qjit(f, keep_intermediate=True)
492494
data = 2.0
493-
old_result = f(data)
494-
old_ir = get_compilation_stage(f, pass_name)
495-
old_workspace = str(f.workspace)
495+
old_result = jit_f(data)
496+
old_ir = get_compilation_stage(jit_f, pass_name)
497+
old_workspace = str(jit_f.workspace)
496498

497499
new_ir = old_ir.replace(target, replacement)
498-
replace_ir(f, pass_name, new_ir)
499-
new_result = f(data)
500+
replace_ir(jit_f, pass_name, new_ir)
501+
new_result = jit_f(data)
500502

501503
shutil.rmtree(old_workspace, ignore_errors=True)
502-
shutil.rmtree(str(f.workspace), ignore_errors=True)
504+
shutil.rmtree(str(jit_f.workspace), ignore_errors=True)
503505
assert old_result * data == new_result
504506

505507
@pytest.mark.parametrize("pass_name", ["HLOLoweringPass", "O2Opt", "Enzyme"])
506508
def test_modify_ir_file_generation(self, pass_name):
507509
"""Test if recompilation rerun the same pass."""
508510

509-
@qjit
510511
def f(x: float):
511512
"""Square function."""
512513
return x**2
513514

514-
grad_f = qjit(value_and_grad(f), keep_intermediate=True)
515-
grad_f(3.0)
516-
ir = get_compilation_stage(grad_f, pass_name)
517-
old_workspace = str(grad_f.workspace)
515+
f.__name__ = f.__name__ + pass_name
516+
517+
jit_f = qjit(f)
518+
jit_grad_f = qjit(value_and_grad(jit_f), keep_intermediate=True)
519+
jit_grad_f(3.0)
520+
ir = get_compilation_stage(jit_grad_f, pass_name)
521+
old_workspace = str(jit_grad_f.workspace)
518522

519-
replace_ir(grad_f, pass_name, ir)
520-
grad_f(3.0)
521-
file_list = os.listdir(str(grad_f.workspace))
523+
replace_ir(jit_grad_f, pass_name, ir)
524+
jit_grad_f(3.0)
525+
file_list = os.listdir(str(jit_grad_f.workspace))
522526
res = [i for i in file_list if pass_name in i]
523527

524528
shutil.rmtree(old_workspace, ignore_errors=True)
525-
shutil.rmtree(str(grad_f.workspace), ignore_errors=True)
529+
shutil.rmtree(str(jit_grad_f.workspace), ignore_errors=True)
526530
assert len(res) == 0
527531

528532
def test_get_compilation_stage_without_keep_intermediate(self):

mlir/lib/Driver/CompilerDriver.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,9 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
648648
catalyst::utils::LinesCount::ModuleOp(*op);
649649
output.isCheckpointFound = options.checkpointStage == "mlir";
650650

651-
bool enzymeRun = false;
651+
// Enzyme always happens after O2Opt. If the checkpoint is O2Opt, enzymeRun must be set to
652+
// true so that the enzyme pass can be executed.
653+
bool enzymeRun = options.checkpointStage == "O2Opt";
652654
if (op) {
653655
enzymeRun = containsGradients(*op);
654656
if (failed(runLowering(options, &ctx, *op, output))) {

0 commit comments

Comments
 (0)