From 379f45a0ea2042fc37cc8fccbd28c77e4af485c7 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 18 Nov 2025 15:06:23 -0500 Subject: [PATCH 1/2] Entend gats definition and lowering in the core compiler --- frontend/catalyst/jax_primitives.py | 126 ++++++ mlir/include/Quantum/IR/QuantumOps.td | 261 +++++++++++ mlir/lib/Quantum/IR/QuantumOps.cpp | 42 ++ .../Transforms/ChainedSelfInversePatterns.cpp | 6 +- .../Quantum/Transforms/ConversionPatterns.cpp | 406 ++++++++++++++++++ 5 files changed, 840 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 5d11fbb4b2..2008bd0d50 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -96,6 +96,16 @@ ComputationalBasisOp, CountsOp, CustomOp, + PauliXOp, + PauliYOp, + PauliZOp, + HadamardOp, + SGateOp, + TGateOp, + CNOTOp, + RXOp, + RYOp, + RZOp, DeallocOp, DeallocQubitOp, DeviceInitOp, @@ -1310,6 +1320,122 @@ def _qinst_lowering( name_str = str(name_attr) name_str = name_str.replace('"', "") + if name_str == "PauliX": + assert len(float_params) == 0, "PauliX takes no float parameters" + return PauliXOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "PauliY": + assert len(float_params) == 0, "PauliY takes no float parameters" + return PauliYOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "PauliZ": + assert len(float_params) == 0, "PauliZ takes no float parameters" + return PauliZOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "Hadamard": + assert len(float_params) == 0, "Hadamard takes no float parameters" + return HadamardOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "S": + assert len(float_params) == 0, "S takes no float parameters" + return SGateOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "T": + assert len(float_params) == 0, "T takes no float parameters" + return TGateOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "CNOT": + assert len(float_params) == 0, "CNOT takes no float parameters" + return CNOTOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "RX": + assert len(float_params) == 1, "RX takes one float parameter" + float_param = float_params[0] + return RXOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + theta=float_param, + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "RY": + assert len(float_params) == 1, "RY takes one float parameter" + float_param = float_params[0] + return RYOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + theta=float_param, + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + + if name_str == "RZ": + assert len(float_params) == 1, "RZ takes one float parameter" + float_param = float_params[0] + return RZOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + theta=float_param, + in_qubits=qubits, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results + if name_str == "MultiRZ": assert len(float_params) == 1, "MultiRZ takes one float parameter" float_param = float_params[0] diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index e9b46542e1..384320cfdc 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -564,6 +564,267 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect, let hasVerifier = 1; } +def PauliXOp : UnitaryGate_Op<"x", [NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "PauliX"; + let description = [{ + }]; + + let arguments = (ins + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; +} + +def PauliYOp : UnitaryGate_Op<"y", [NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "PauliY"; + let description = [{ + }]; + + let arguments = (ins + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; +} + +def PauliZOp : UnitaryGate_Op<"z", [NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "PauliX"; + let description = [{ + }]; + + let arguments = (ins + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; +} + +def HadamardOp : UnitaryGate_Op<"h", [NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "Hadamard"; + let description = [{ + }]; + + let arguments = (ins + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; +} + +def TGateOp : UnitaryGate_Op<"t", [NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "T Gate"; + let description = [{ + }]; + + let arguments = (ins + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; +} + +def SGateOp : UnitaryGate_Op<"s", [NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "S Gate"; + let description = [{ + }]; + + let arguments = (ins + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; +} + +def CNOTOp : UnitaryGate_Op<"cnot", [NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "CNOT"; + let description = [{ + }]; + + let arguments = (ins + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; +} + + +def RXOp : UnitaryGate_Op<"rx", [DifferentiableGate, NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "RX"; + let description = [{ + }]; + + let arguments = (ins + F64:$theta, + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + `(` $theta `)` $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getODSOperands(getParamOperandIdx()); + } + }]; + + let hasCanonicalizeMethod = 1; +} + + +def RYOp : UnitaryGate_Op<"ry", [DifferentiableGate, NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "RY"; + let description = [{ + }]; + + let arguments = (ins + F64:$theta, + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + `(` $theta `)` $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getODSOperands(getParamOperandIdx()); + } + }]; + + let hasCanonicalizeMethod = 1; +} + + +def RZOp : UnitaryGate_Op<"rz", [DifferentiableGate, NoMemoryEffect, + AttrSizedOperandSegments, AttrSizedResultSegments]> { + let summary = "RZ"; + let description = [{ + }]; + + let arguments = (ins + F64:$theta, + Variadic:$in_qubits, + UnitAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + `(` $theta `)` $in_qubits (`adj` $adjoint^)? attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) (`ctrls` type($out_ctrl_qubits)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getODSOperands(getParamOperandIdx()); + } + }]; + + let hasCanonicalizeMethod = 1; +} + + def GlobalPhaseOp : UnitaryGate_Op<"gphase", [DifferentiableGate, AttrSizedOperandSegments]> { let summary = "Global Phase."; diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 47686d99e2..96764c2944 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -70,6 +70,48 @@ LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewrite return failure(); } +LogicalResult RXOp::canonicalize(RXOp op, mlir::PatternRewriter &rewriter) +{ + if (op.getAdjoint()) { + auto paramNeg = rewriter.create(op.getLoc(), op.getTheta()); + + rewriter.replaceOpWithNewOp( + op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramNeg, + op.getInQubits(), nullptr, op.getInCtrlQubits(), op.getInCtrlValues()); + + return success(); + }; + return failure(); +} + +LogicalResult RYOp::canonicalize(RYOp op, mlir::PatternRewriter &rewriter) +{ + if (op.getAdjoint()) { + auto paramNeg = rewriter.create(op.getLoc(), op.getTheta()); + + rewriter.replaceOpWithNewOp( + op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramNeg, + op.getInQubits(), nullptr, op.getInCtrlQubits(), op.getInCtrlValues()); + + return success(); + }; + return failure(); +} + +LogicalResult RZOp::canonicalize(RZOp op, mlir::PatternRewriter &rewriter) +{ + if (op.getAdjoint()) { + auto paramNeg = rewriter.create(op.getLoc(), op.getTheta()); + + rewriter.replaceOpWithNewOp( + op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramNeg, + op.getInQubits(), nullptr, op.getInCtrlQubits(), op.getInCtrlValues()); + + return success(); + }; + return failure(); +} + LogicalResult MultiRZOp::canonicalize(MultiRZOp op, mlir::PatternRewriter &rewriter) { if (op.getAdjoint()) { diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index cd1d958395..ff61beaf4f 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -159,8 +159,12 @@ void populateSelfInversePatterns(RewritePatternSet &patterns) // but interfaces cannot be accepted by pattern matchers, since pattern // matchers require the target operations to have concrete names in the IR. patterns.add>(patterns.getContext(), 1); - patterns.add>(patterns.getContext(), 1); + // TODO: add other explicit unitary gate ops here as they are implemented + patterns.add>(patterns.getContext(), 1); + patterns.add>(patterns.getContext(), 1); + patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); + patterns.add>(patterns.getContext(), 1); } } // namespace quantum diff --git a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp index 83ff8b3da3..239d9e15d6 100644 --- a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp @@ -523,6 +523,402 @@ struct CustomOpPattern : public OpConversionPattern { } }; +struct PauliXOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(PauliXOp op, PauliXOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__PauliX"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct PauliYOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(PauliYOp op, PauliYOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__PauliY"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct PauliZOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(PauliZOp op, PauliZOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__PauliZ"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct HadamardOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(HadamardOp op, HadamardOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__Hadamard"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct CNOTOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(CNOTOp op, CNOTOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__CNOT"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct TGateOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(TGateOp op, TGateOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__T"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct SGateOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(SGateOp op, SGateOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__S"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct RXOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RXOp op, RXOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__RX"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), Float64Type::get(ctx)); + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getTheta()); + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct RYOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RYOp op, RYOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__RY"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), Float64Type::get(ctx)); + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getTheta()); + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + +struct RZOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RZOp op, RZOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + const TypeConverter *conv = getTypeConverter(); + auto modifiersPtr = getModifiersPtr(loc, rewriter, conv, op.getAdjointFlag(), + adaptor.getInCtrlQubits(), adaptor.getInCtrlValues()); + + std::string qirName = "__catalyst__qis__RZ"; + + SmallVector argTypes; + argTypes.insert(argTypes.end(), Float64Type::get(ctx)); + argTypes.insert(argTypes.end(), adaptor.getInQubits().getTypes().begin(), + adaptor.getInQubits().getTypes().end()); + argTypes.insert(argTypes.end(), modifiersPtr.getType()); + + Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), argTypes, + /*isVarArg=*/false); + LLVM::LLVMFuncOp fnDecl = + catalyst::ensureFunctionDeclaration(rewriter, op, qirName, qirSignature); + + SmallVector args; + args.insert(args.end(), adaptor.getTheta()); + args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + args.insert(args.end(), modifiersPtr); + + rewriter.create(loc, fnDecl, args); + SmallVector values; + values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); + values.insert(values.end(), adaptor.getInCtrlQubits().begin(), + adaptor.getInCtrlQubits().end()); + rewriter.replaceOp(op, values); + + return success(); + } +}; + struct GlobalPhaseOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1189,6 +1585,16 @@ void populateQIRConversionPatterns(TypeConverter &typeConverter, RewritePatternS patterns.add(typeConverter, patterns.getContext()); } patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); patterns.add(typeConverter, patterns.getContext()); patterns.add(typeConverter, patterns.getContext()); patterns.add(typeConverter, patterns.getContext()); From 6a508c55ecefb14414fd68a75fcc254772cead1d Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 18 Nov 2025 15:24:40 -0500 Subject: [PATCH 2/2] MLIR decompose-lowering works with fine-grained gate ops --- .../Transforms/DecomposeLoweringImpl.hpp | 39 ++++ .../Transforms/DecomposeLoweringPatterns.cpp | 219 ++++++++++++++++++ 2 files changed, 258 insertions(+) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index 6715808aae..c9cc935635 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -431,6 +431,45 @@ class CustomOpSignatureAnalyzer : public BaseSignatureAnalyzer { } }; +class RXOpSignatureAnalyzer : public BaseSignatureAnalyzer { + public: + RXOpSignatureAnalyzer() = delete; + + RXOpSignatureAnalyzer(RXOp op, bool enableQregMode) + : BaseSignatureAnalyzer(op, op.getTheta(), op.getNonCtrlQubitOperands(), + op.getCtrlQubitOperands(), op.getCtrlValueOperands(), + op.getNonCtrlQubitResults(), op.getCtrlQubitResults(), + enableQregMode) + { + } +}; + +class RYOpSignatureAnalyzer : public BaseSignatureAnalyzer { + public: + RYOpSignatureAnalyzer() = delete; + + RYOpSignatureAnalyzer(RYOp op, bool enableQregMode) + : BaseSignatureAnalyzer(op, op.getTheta(), op.getNonCtrlQubitOperands(), + op.getCtrlQubitOperands(), op.getCtrlValueOperands(), + op.getNonCtrlQubitResults(), op.getCtrlQubitResults(), + enableQregMode) + { + } +}; + +class RZOpSignatureAnalyzer : public BaseSignatureAnalyzer { + public: + RZOpSignatureAnalyzer() = delete; + + RZOpSignatureAnalyzer(RZOp op, bool enableQregMode) + : BaseSignatureAnalyzer(op, op.getTheta(), op.getNonCtrlQubitOperands(), + op.getCtrlQubitOperands(), op.getCtrlValueOperands(), + op.getNonCtrlQubitResults(), op.getCtrlQubitResults(), + enableQregMode) + { + } +}; + class MultiRZOpSignatureAnalyzer : public BaseSignatureAnalyzer { public: MultiRZOpSignatureAnalyzer() = delete; diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 3d5ccc8775..c072ed368c 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -91,6 +91,222 @@ struct DLCustomOpPattern : public OpRewritePattern { } }; +struct DLRXOpPattern : public OpRewritePattern { + private: + const llvm::StringMap &decompositionRegistry; + const llvm::StringSet &targetGateSet; + + public: + DLRXOpPattern(MLIRContext *context, const llvm::StringMap ®istry, + const llvm::StringSet &gateSet) + : OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet) + { + } + + LogicalResult matchAndRewrite(RXOp op, PatternRewriter &rewriter) const override + { + std::string gateName = "RX"; + + // Only decompose the op if it is not in the target gate set + if (targetGateSet.contains(gateName)) { + return failure(); + } + + // Find the corresponding decomposition function for the op + auto numQubits = op.getInQubits().size(); + + auto it = decompositionRegistry.find(gateName); + if (it == decompositionRegistry.end()) { + return failure(); + } + + func::FuncOp decompFunc = it->second; + // Here is the assumption that the decomposition function must have + // at least one input and one result + assert(decompFunc.getFunctionType().getNumInputs() > 0 && + "Decomposition function must have at least one input"); + assert(decompFunc.getFunctionType().getNumResults() >= 1 && + "Decomposition function must have at least one result"); + + rewriter.setInsertionPointAfter(op); + + auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); + auto numQbitsAttr = decompFunc->getAttrOfType("num_wires"); + if (!numQbitsAttr) { + op.emitError("Decomposition function missing 'num_wires' attribute"); + return failure(); + } + if (numQubits != static_cast(numQbitsAttr.getInt())) { + op.emitError("Mismatch in number of qubits: expected ") + << numQbitsAttr.getInt() << ", got " << numQubits; + return failure(); + } + + auto analyzer = RXOpSignatureAnalyzer(op, enableQreg); + assert(analyzer && "Analyzer should be valid"); + + auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); + auto callOp = + rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), + decompFunc.getSymName(), callOperands); + + // Replace the op with the call op and adjust the insert ops for the qreg mode + if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { + auto results = analyzer.prepareCallResultForQreg(callOp, rewriter); + rewriter.replaceOp(op, results); + } + else { + rewriter.replaceOp(op, callOp->getResults()); + } + + return success(); + } +}; + +struct DLRYOpPattern : public OpRewritePattern { + private: + const llvm::StringMap &decompositionRegistry; + const llvm::StringSet &targetGateSet; + + public: + DLRYOpPattern(MLIRContext *context, const llvm::StringMap ®istry, + const llvm::StringSet &gateSet) + : OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet) + { + } + + LogicalResult matchAndRewrite(RYOp op, PatternRewriter &rewriter) const override + { + std::string gateName = "RY"; + + // Only decompose the op if it is not in the target gate set + if (targetGateSet.contains(gateName)) { + return failure(); + } + + // Find the corresponding decomposition function for the op + auto numQubits = op.getInQubits().size(); + + auto it = decompositionRegistry.find(gateName); + if (it == decompositionRegistry.end()) { + return failure(); + } + + func::FuncOp decompFunc = it->second; + // Here is the assumption that the decomposition function must have + // at least one input and one result + assert(decompFunc.getFunctionType().getNumInputs() > 0 && + "Decomposition function must have at least one input"); + assert(decompFunc.getFunctionType().getNumResults() >= 1 && + "Decomposition function must have at least one result"); + + rewriter.setInsertionPointAfter(op); + + auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); + auto numQbitsAttr = decompFunc->getAttrOfType("num_wires"); + if (!numQbitsAttr) { + op.emitError("Decomposition function missing 'num_wires' attribute"); + return failure(); + } + if (numQubits != static_cast(numQbitsAttr.getInt())) { + op.emitError("Mismatch in number of qubits: expected ") + << numQbitsAttr.getInt() << ", got " << numQubits; + return failure(); + } + + auto analyzer = RYOpSignatureAnalyzer(op, enableQreg); + assert(analyzer && "Analyzer should be valid"); + + auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); + auto callOp = + rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), + decompFunc.getSymName(), callOperands); + + // Replace the op with the call op and adjust the insert ops for the qreg mode + if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { + auto results = analyzer.prepareCallResultForQreg(callOp, rewriter); + rewriter.replaceOp(op, results); + } + else { + rewriter.replaceOp(op, callOp->getResults()); + } + + return success(); + } +}; + +struct DLRZOpPattern : public OpRewritePattern { + private: + const llvm::StringMap &decompositionRegistry; + const llvm::StringSet &targetGateSet; + + public: + DLRZOpPattern(MLIRContext *context, const llvm::StringMap ®istry, + const llvm::StringSet &gateSet) + : OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet) + { + } + + LogicalResult matchAndRewrite(RZOp op, PatternRewriter &rewriter) const override + { + std::string gateName = "RZ"; + + // Only decompose the op if it is not in the target gate set + if (targetGateSet.contains(gateName)) { + return failure(); + } + + // Find the corresponding decomposition function for the op + auto numQubits = op.getInQubits().size(); + + auto it = decompositionRegistry.find(gateName); + if (it == decompositionRegistry.end()) { + return failure(); + } + + func::FuncOp decompFunc = it->second; + // Here is the assumption that the decomposition function must have + // at least one input and one result + assert(decompFunc.getFunctionType().getNumInputs() > 0 && + "Decomposition function must have at least one input"); + assert(decompFunc.getFunctionType().getNumResults() >= 1 && + "Decomposition function must have at least one result"); + + rewriter.setInsertionPointAfter(op); + + auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); + auto numQbitsAttr = decompFunc->getAttrOfType("num_wires"); + if (!numQbitsAttr) { + op.emitError("Decomposition function missing 'num_wires' attribute"); + return failure(); + } + if (numQubits != static_cast(numQbitsAttr.getInt())) { + op.emitError("Mismatch in number of qubits: expected ") + << numQbitsAttr.getInt() << ", got " << numQubits; + return failure(); + } + + auto analyzer = RZOpSignatureAnalyzer(op, enableQreg); + assert(analyzer && "Analyzer should be valid"); + + auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); + auto callOp = + rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), + decompFunc.getSymName(), callOperands); + + // Replace the op with the call op and adjust the insert ops for the qreg mode + if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { + auto results = analyzer.prepareCallResultForQreg(callOp, rewriter); + rewriter.replaceOp(op, results); + } + else { + rewriter.replaceOp(op, callOp->getResults()); + } + + return success(); + } +}; + struct DLMultiRZOpPattern : public OpRewritePattern { private: const llvm::StringMap &decompositionRegistry; @@ -170,6 +386,9 @@ void populateDecomposeLoweringPatterns(RewritePatternSet &patterns, const llvm::StringSet &targetGateSet) { patterns.add(patterns.getContext(), decompositionRegistry, targetGateSet); + patterns.add(patterns.getContext(), decompositionRegistry, targetGateSet); + patterns.add(patterns.getContext(), decompositionRegistry, targetGateSet); + patterns.add(patterns.getContext(), decompositionRegistry, targetGateSet); patterns.add(patterns.getContext(), decompositionRegistry, targetGateSet); }