From 46c39de8db577699fc0c0049d3c9300a6bd6ca47 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 10 Nov 2025 17:09:24 -0500 Subject: [PATCH 01/51] RTIO dialect is added to Catalyst --- mlir/include/CMakeLists.txt | 1 + mlir/include/RTIO/CMakeLists.txt | 1 + mlir/include/RTIO/IR/CMakeLists.txt | 8 + mlir/include/RTIO/IR/RTIODialect.h | 25 ++ mlir/include/RTIO/IR/RTIODialect.td | 164 ++++++++++++ mlir/include/RTIO/IR/RTIOOps.h | 26 ++ mlir/include/RTIO/IR/RTIOOps.td | 357 +++++++++++++++++++++++++ mlir/lib/CMakeLists.txt | 1 + mlir/lib/RTIO/CMakeLists.txt | 1 + mlir/lib/RTIO/IR/CMakeLists.txt | 11 + mlir/lib/RTIO/IR/RTIODialect.cpp | 150 +++++++++++ mlir/lib/RTIO/IR/RTIOOps.cpp | 36 +++ mlir/test/CMakeLists.txt | 1 + mlir/test/RTIO/VerifierTest.mlir | 192 +++++++++++++ mlir/tools/quantum-opt/CMakeLists.txt | 1 + mlir/tools/quantum-opt/quantum-opt.cpp | 2 + 16 files changed, 977 insertions(+) create mode 100644 mlir/include/RTIO/CMakeLists.txt create mode 100644 mlir/include/RTIO/IR/CMakeLists.txt create mode 100644 mlir/include/RTIO/IR/RTIODialect.h create mode 100644 mlir/include/RTIO/IR/RTIODialect.td create mode 100644 mlir/include/RTIO/IR/RTIOOps.h create mode 100644 mlir/include/RTIO/IR/RTIOOps.td create mode 100644 mlir/lib/RTIO/CMakeLists.txt create mode 100644 mlir/lib/RTIO/IR/CMakeLists.txt create mode 100644 mlir/lib/RTIO/IR/RTIODialect.cpp create mode 100644 mlir/lib/RTIO/IR/RTIOOps.cpp create mode 100644 mlir/test/RTIO/VerifierTest.mlir diff --git a/mlir/include/CMakeLists.txt b/mlir/include/CMakeLists.txt index 85180214a9..1fccb38f41 100644 --- a/mlir/include/CMakeLists.txt +++ b/mlir/include/CMakeLists.txt @@ -6,4 +6,5 @@ add_subdirectory(MBQC) add_subdirectory(Mitigation) add_subdirectory(QEC) add_subdirectory(Quantum) +add_subdirectory(RTIO) add_subdirectory(Test) diff --git a/mlir/include/RTIO/CMakeLists.txt b/mlir/include/RTIO/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/mlir/include/RTIO/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/RTIO/IR/CMakeLists.txt b/mlir/include/RTIO/IR/CMakeLists.txt new file mode 100644 index 0000000000..06b6d47a35 --- /dev/null +++ b/mlir/include/RTIO/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_dialect(RTIOOps rtio) +add_mlir_doc(RTIODialect RTIODialect RTIO/ -gen-dialect-doc) +add_mlir_doc(RTIOOps RTIOOps RTIO/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS RTIOOps.td) +mlir_tablegen(RTIOAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=rtio) +mlir_tablegen(RTIOAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=rtio) +add_public_tablegen_target(MLIRRTIOAttributesIncGen) diff --git a/mlir/include/RTIO/IR/RTIODialect.h b/mlir/include/RTIO/IR/RTIODialect.h new file mode 100644 index 0000000000..38ca8c694b --- /dev/null +++ b/mlir/include/RTIO/IR/RTIODialect.h @@ -0,0 +1,25 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +#include "RTIO/IR/RTIOOpsDialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "RTIO/IR/RTIOOpsTypes.h.inc" + diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td new file mode 100644 index 0000000000..0de3669c13 --- /dev/null +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -0,0 +1,164 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RTIO_DIALECT +#define RTIO_DIALECT + +include "mlir/IR/OpBase.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/AttrTypeBase.td" + +//===----------------------------------------------------------------------===// +// RTIO Dialect Definition +//===----------------------------------------------------------------------===// + +def RTIO_Dialect : Dialect { + let summary = "Real-Time I/O dialect for FPGA quantum control"; + let description = [{ + The RTIO dialect provides operations for precise timing control + and hardware signal generation on FPGAs for quantum computing. + + This dialect supports two levels of abstraction: + 1. Event-Based IR (high-level): Declarative operations with explicit event dependencies + + // TODO: Do we need the separate Timeline IR for artiq family? + 2. Timeline IR (low-level): Stateful operations with implicit time cursor + ``` + }]; + + let name = "rtio"; + let cppNamespace = "::catalyst::rtio"; + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 0; + + let extraClassDeclaration = [{ + mlir::Attribute parseAttribute(mlir::DialectAsmParser &parser, + mlir::Type type) const override; + + void printAttribute(mlir::Attribute attr, + mlir::DialectAsmPrinter &printer) const override; + }]; +} + +//===----------------------------------------------------------------------===// +// RTIO dialect types. +//===----------------------------------------------------------------------===// + +class RTIO_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +// Channel type with variadic parameters +def RTIOChannelType : RTIO_Type<"Channel", "channel"> { + let summary = "A hardware I/O channel with logical and physical identification"; + let description = [{ + Represents a virtual hardware channel for RTIO operations + + Syntax: + ``` + !rtio.channel // Dynamic: kind only + !rtio.channel // Static: kind + channel N + !rtio.channel // Dynamic: with qualifiers + !rtio.channel // Static: with qualifiers + channel N + ``` + + Examples: + ```mlir + + // Simple DDS, channel TBD + !rtio.channel<"dds", ?> + + // DDS for one two qualifiers 0 and "t0", channel TBD + !rtio.channel<"dds", [0, "transition_0"], ?> + + // DDS on hardware channel 0 + !rtio.channel<"dds", 0> + + // === Channel Resolution During Compilation === + + // Before channel resolution: Dynamic channels + !rtio.channel<"dds", [0], ?> + + // After channel resolution: Resolved to hardware channel 1 + !rtio.channel<"dds", [0], 1> + + // Note: The qualifiers are used provided addtional information to the channel type. + // And it will used to distinguish different channels with the same kind. + ``` + }]; + + let parameters = (ins + StringRefParameter<"channel kind">:$kind, + OptionalParameter<"mlir::ArrayAttr">:$qualifiers, + OptionalParameter<"mlir::IntegerAttr">:$channelId + ); + + let assemblyFormat = [{ + `<` custom($kind, $qualifiers, $channelId) `>` + }]; + + let extraClassDeclaration = [{ + bool hasQualifiers() const { + return getQualifiers() && !getQualifiers().empty(); + } + + size_t getNumQualifiers() const { + return getQualifiers() ? getQualifiers().size() : 0; + } + + mlir::Attribute getQualifier(size_t index) const { + if (!getQualifiers() || index >= getQualifiers().size()) + return nullptr; + return getQualifiers()[index]; + } + + bool isStatic() const { + auto channelId = getChannelId(); + if (!channelId) { + return false; + } + return channelId.getInt() >= 0; + } + + bool isDynamic() const { + return !isStatic(); + } + }]; +} + +// Event handle type +def RTIOEventType : RTIO_Type<"Event", "event"> { + let summary = "A handle to a pending RTIO event"; + let description = [{ + Represents a handle to a pending RTIO event (e.g., pulse, sync). + + Example: + ```mlir + %event0 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) : !rtio.event + %event1 = rtio.pulse %ch1 duration(%dur) frequency(%freq) phase(%phase) : !rtio.event + %sync = rtio.sync %event0, %event1 : !rtio.event + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// RTIO Operation Base +//===----------------------------------------------------------------------===// + +class RTIO_Op traits = []> : + Op; + +#endif // RTIO_DIALECT + diff --git a/mlir/include/RTIO/IR/RTIOOps.h b/mlir/include/RTIO/IR/RTIOOps.h new file mode 100644 index 0000000000..0ce7643b83 --- /dev/null +++ b/mlir/include/RTIO/IR/RTIOOps.h @@ -0,0 +1,26 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "RTIO/IR/RTIODialect.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_ATTRDEF_CLASSES +#include "RTIO/IR/RTIOAttributes.h.inc" +#define GET_OP_CLASSES +#include "RTIO/IR/RTIOOps.h.inc" diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td new file mode 100644 index 0000000000..88fabbcb63 --- /dev/null +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -0,0 +1,357 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RTIO_OPS +#define RTIO_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "RTIO/IR/RTIODialect.td" + +//===----------------------------------------------------------------------===// +// Event-Based API +//===----------------------------------------------------------------------===// + +def RTIOChannelOp : RTIO_Op<"channel"> { + let summary = "Define a static channel"; + let description = [{ + The channel's identity (kind, qualifiers, and channel id) is + encoded in its result type. + + The channel ID must be explicitly specified in the type as either: + - `?` for dynamic channel (to be resolved by channel resolution stage) + - Non-negative integer for static channel id + + Example: + ```mlir + // DDS on hardware channel 0 + %ch0_dds = rtio.channel : !rtio.channel<"dds", 0> + + // DDS on hardware channel 0 for qualifiers [0] + %ch0_t0 = rtio.channel : !rtio.channel<"dds", [0], 0> + + // DDS for dynamic channel with qualifiers [0] + %ch_t0 = rtio.channel : !rtio.channel<"dds", [0], ?> + ``` + }]; + + let arguments = (ins); + let results = (outs RTIOChannelType:$channel); + + let assemblyFormat = [{ + attr-dict `:` type(results) + }]; +} + +def RTIOQubitToChannelOp : RTIO_Op<"qubit_to_channel"> { + let summary = "Map a qubit to an RTIO channel"; + let description = [{ + It's a temporary operation that will be lowered to a static `rtio.channel` operation during + the channel resolution stage. The purpose of this operation is to allow the qubit from a + from a high-level dialect to be mapped to an `!rtio.channel`. + + Example: + ```mlir + // Map ion qubit to dds channel with dynamic channel id + %ch = rtio.qubit_to_channel %ion_qubit : !ion.qubit -> !rtio.channel<"dds", ?> + ``` + + During the channel resolution stage, this operation will be replaced by a + static `rtio.channel` operation once the qubit identity and channel mapping + are determined at compile time. + }]; + + let arguments = (ins AnyType:$qubit); + let results = (outs RTIOChannelType:$channel); + + let assemblyFormat = [{ + $qubit attr-dict `:` type(operands) `->` type(results) + }]; +} + +def RTIOPulseOp : RTIO_Op<"pulse"> { + let summary = "Generate a event-based pulse on a channel"; + let description = [{ + Generate a pulse on the channel. Returns an `!rtio.event` handle that can be used with + `rtio.sync` to wait for completion or to be used by other pulse operations with the `wait` + operand. + + Parameters: + - channel: Target output channel + - duration: Pulse duration in machine units (mu) + - frequency: frequency in Hz + - phase: Phase in radians + - wait: Optional events to wait for before starting this pulse (for sequencing) + + Returns: + - event: Handle to the pending pulse `!rtio.event` + + Example: + ```mlir + %dur = arith.constant ... : i64 + %freq = arith.constant ... : f64 + %phase = arith.constant ... : f64 + + %event0 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) + : !rtio.channel<"dds", 0> -> !rtio.event + %event1 = rtio.pulse %ch1 duration(%dur) frequency(%freq) phase(%phase) + : !rtio.channel<"dds", 1> -> !rtio.event + + %sync = rtio.sync %event0, %event1 : !rtio.event + + %event2 = rtio.pulse %ch2 duration(%dur) frequency(%freq) phase(%phase) wait(%sync) + : !rtio.channel<"dds", 0> -> !rtio.event + ``` + }]; + + let arguments = (ins + RTIOChannelType:$channel, + I64:$duration, + F64:$frequency, + F64:$phase, + Optional:$wait + ); + + let results = (outs RTIOEventType:$event); + + let assemblyFormat = [{ + $channel `duration` `(` $duration `)` `frequency` `(` $frequency `)` `phase` `(` $phase `)` + (`wait` `(` $wait^ `)`)? attr-dict `:` type($channel) `->` type($event) + }]; +} + +def RTIOSyncOp : RTIO_Op<"sync"> { + let summary = "Synchronization barrier for event-based operations"; + let description = [{ + Wait for all specified events to complete before proceeding. + Returns a new event handle + + Example: + ```mlir + // Three pulses start in parallel + %event0 = rtio.pulse %ch0 duration(%dur0) frequency(%freq0) phase(%phase0) + : !rtio.channel<"dds", 0> -> !rtio.event + %event1 = rtio.pulse %ch1 duration(%dur1) frequency(%freq1) phase(%phase1) + : !rtio.channel<"dds", 1> -> !rtio.event + + // Wait for both to complete, get sync point + %sync = rtio.sync %event0, %event1 : !rtio.event + ``` + + The sync operation tracks dependencies between events, making it easier for the + event scheduling. + }]; + + let arguments = (ins Variadic:$events); + let results = (outs RTIOEventType:$sync_event); + let assemblyFormat = "$events attr-dict `:` type($sync_event)"; + + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// Timeline-Based IR (ARTIQ compatible) +//===----------------------------------------------------------------------===// + +def RTIONowOp : RTIO_Op<"now", [Pure]> { + let summary = "Read the current timeline cursor"; + let description = [{ + Returns the current value of the timeline cursor in machine units (mu). + + Example: + ```mlir + %t0 = rtio.now : i64 + rtio.delay 100 : i64 + %t1 = rtio.now : i64 + // %t1 = %t0 + 100 + ``` + }]; + + let results = (outs I64:$time); + let assemblyFormat = "attr-dict `:` type($time)"; +} + +def RTIODelayOp : RTIO_Op<"delay"> { + let summary = "Advance the timeline cursor by a relative duration"; + let description = [{ + Advances the timeline cursor by the specified duration in machine units. + Equivalent to `rtio.at(rtio.now() + duration)`. + + Example: + ```mlir + // case1 + rtio.delay 1000 : i64 + + // case2 + rtio.delay %dur : i64 + ``` + }]; + + let arguments = (ins I64:$duration); + let assemblyFormat = "$duration attr-dict `:` type($duration)"; +} + +def RTIOAtOp : RTIO_Op<"at"> { + let summary = "Set the timeline cursor to an absolute timestamp"; + let description = [{ + Moves the timeline cursor to a specific absolute timestamp. + + Pre-condition: The new time must satisfy `t >= rtio_counter_mu()` (hardware time), + otherwise an RTIO underflow error occurs at runtime. Note that `t` can be less than + the current timeline cursor value (from `rtio.now`), which is how parallel execution + is implemented. + + Example: + ```mlir + %t_start = rtio.now : i64 // e.g. t_start = 1000 mu + + // First pulse lane + rtio.on %ch0 + rtio.delay 500 : i64 + rtio.off %ch0 + // rtio.now would return 1500 mu (1000 mu + 500 mu) + %t0 = rtio.now : i64 + + // Second pulse lane (parallel with first lane) + + // Rewind to t_start = 1000 mu (< 1500 mu, but >= hardware counter) + rtio.at %t_start : i64 + rtio.on %ch1 + rtio.delay 300 : i64 + rtio.off %ch1 + // rtio.now would return 1300 mu (1000 mu + 300 mu) + %t1 = rtio.now : i64 + + // Sync: advance to the maximum end time (1500 mu) + %t_max = arith.maxui %t0, %t1 : i64 + rtio.at %t_max : i64 // advance to 1500 mu + ``` + }]; + + let arguments = (ins I64:$time); + let assemblyFormat = "$time attr-dict `:` type($time)"; +} + +def RTIOSetFrequencyOp : RTIO_Op<"set_frequency"> { + let summary = "Configure the frequency of a DDS channel"; + let description = [{ + Programs the DDS frequency tuning word (ftw) to generate a specific frequency. + The frequency is specified in Hz as a f64. + + Example: + ```mlir + %freq = arith.constant 1.266300000e10 : f64 // 12.663 GHz + rtio.set_frequency %ch0, %freq : !rtio.channel<"dds", 0>, f64 + ``` + + Assume phase is 0 and amplitude is 1. + It's equivalent to call the `AD9910.set` function with the following arguments: + + ```llvm + %set_func = load ptr, ptr @F.artiq.coredevice.ad9910.AD9910.set + call double %set_func( + ptr %env, + ptr %ch0, + double 1.266300000e+10, ; frequency + double 0.000000e+00, ; phase + double 1.000000e+00 ; amplitude + ) + ``` + }]; + + let arguments = (ins + RTIOChannelType:$channel, + F64:$frequency + ); + let assemblyFormat = "$channel `,` $frequency attr-dict `:` type($channel) `,` type($frequency)"; +} + +def RTIOSetPhaseOp : RTIO_Op<"set_phase"> { + let summary = "Configure the phase of a DDS channel"; + let description = [{ + Programs the DDS phase offset register to set the carrier phase. + The phase is specified in radians as a f64. + + Example: + ```mlir + // pi/2 phase shift + %pi_2 = arith.constant 1.5707963267948966 : f64 + rtio.set_phase %ch0, %pi_2 : !rtio.channel<"dds", 0>, f64 + ``` + }]; + + let arguments = (ins + RTIOChannelType:$channel, + F64:$phase + ); + let assemblyFormat = "$channel `,` $phase attr-dict `:` type($channel) `,` type($phase)"; +} + +def RTIOSetAmplitudeOp : RTIO_Op<"set_amplitude"> { + let summary = "Configure the amplitude of a DDS channel"; + let description = [{ + Programs the DDS amplitude + + Example: + ```mlir + %amp = arith.constant 1.0 : f64 + rtio.set_amplitude %ch0, %amp : !rtio.channel<"dds", 0>, f64 + ``` + }]; + + let arguments = (ins + RTIOChannelType:$channel, + F64:$amplitude + ); + let assemblyFormat = "$channel `,` $amplitude attr-dict `:` type($channel) `,` type($amplitude)"; +} + +def RTIOOnOp : RTIO_Op<"on"> { + let summary = "Turn on a channel output"; + let description = [{ + Activates the output of a channel at the current cursor. + For DDS channels, this typically enables the RF output switch. + + Example: + ```mlir + rtio.at %t_start : !rtio.time + rtio.on %ch + rtio.delay 1000 : i64 + rtio.off %ch + ``` + }]; + + let arguments = (ins RTIOChannelType:$channel); + let assemblyFormat = "$channel attr-dict `:` type($channel)"; +} + +def RTIOOffOp : RTIO_Op<"off"> { + let summary = "Turn off a channel output"; + let description = [{ + Deactivates the output of a channel at the current cursor. + For DDS channels, this typically disables the RF output switch. + + Example: + ```mlir + rtio.off %ch : !rtio.channel<"dds", 0> + ``` + }]; + + let arguments = (ins RTIOChannelType:$channel); + let assemblyFormat = "$channel attr-dict `:` type($channel)"; +} + +#endif // RTIO_OPS + diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index 687046ef64..dadd487587 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -8,4 +8,5 @@ add_subdirectory(MBQC) add_subdirectory(Mitigation) add_subdirectory(QEC) add_subdirectory(Quantum) +add_subdirectory(RTIO) add_subdirectory(Test) diff --git a/mlir/lib/RTIO/CMakeLists.txt b/mlir/lib/RTIO/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/mlir/lib/RTIO/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/RTIO/IR/CMakeLists.txt b/mlir/lib/RTIO/IR/CMakeLists.txt new file mode 100644 index 0000000000..ba921b2f67 --- /dev/null +++ b/mlir/lib/RTIO/IR/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_library(MLIRRTIO + RTIODialect.cpp + RTIOOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/RTIO + + DEPENDS + MLIRRTIOOpsIncGen + MLIRRTIOAttributesIncGen +) diff --git a/mlir/lib/RTIO/IR/RTIODialect.cpp b/mlir/lib/RTIO/IR/RTIODialect.cpp new file mode 100644 index 0000000000..5bb9fa4944 --- /dev/null +++ b/mlir/lib/RTIO/IR/RTIODialect.cpp @@ -0,0 +1,150 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RTIO/IR/RTIODialect.h" +#include "RTIO/IR/RTIOOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace catalyst::rtio; + +//===----------------------------------------------------------------------===// +// RTIO Dialect +//===----------------------------------------------------------------------===// + +#include "RTIO/IR/RTIOOpsDialect.cpp.inc" + +static ParseResult parseChannelTypeBody(AsmParser &parser, std::string &kind, ArrayAttr &qualifiers, + IntegerAttr &channelId) +{ + // 1. Parse kind (string) + if (failed(parser.parseString(&kind))) + return failure(); + + // 2. Parse optional qualifiers: `, [...]` + qualifiers = nullptr; + if (succeeded(parser.parseOptionalComma())) { + if (succeeded(parser.parseOptionalLSquare())) { + SmallVector quals; + if (failed(parser.parseOptionalRSquare())) { + do { + Attribute attr; + if (failed(parser.parseAttribute(attr))) + return failure(); + quals.push_back(attr); + } while (succeeded(parser.parseOptionalComma())); + + if (failed(parser.parseRSquare())) + return failure(); + } + qualifiers = parser.getBuilder().getArrayAttr(quals); + + // After qualifiers, parse comma for channelId + if (failed(parser.parseOptionalComma())) { + channelId = parser.getBuilder().getI64IntegerAttr(-1); + return success(); + } + } + // Comma but no `[`, so this comma is for channelId + } + else { + // No comma at all, no qualifiers and no channelId + channelId = parser.getBuilder().getI64IntegerAttr(-1); + return success(); + } + + // 3. Parse channelId: `?` or non-negative integer + if (succeeded(parser.parseOptionalQuestion())) { + channelId = parser.getBuilder().getI64IntegerAttr(-1); + return success(); + } + + int64_t id; + if (failed(parser.parseInteger(id))) + return failure(); + + if (id < 0) + return parser.emitError(parser.getCurrentLocation(), + "static channel ID must be non-negative"); + + channelId = parser.getBuilder().getI64IntegerAttr(id); + return success(); +} + +// Custom printer for the entire channel type body +static void printChannelTypeBody(AsmPrinter &printer, StringRef kind, ArrayAttr qualifiers, + IntegerAttr channelId) +{ + // 1. Print kind + printer << "\"" << kind << "\""; + + // 2. Print qualifiers if present + if (qualifiers && !qualifiers.empty()) { + printer << ", ["; + llvm::interleaveComma(qualifiers, printer, + [&](Attribute attr) { printer.printAttribute(attr); }); + printer << "]"; + } + + // 3. Print channelId if present (and not default -1) + if (channelId) { + int64_t id = channelId.getInt(); + if (id >= 0 || (qualifiers && !qualifiers.empty())) { + printer << ", "; + if (id < 0) { + printer << "?"; + } + else { + printer << id; + } + } + } +} + +void catalyst::rtio::RTIODialect::initialize() +{ + addTypes< +#define GET_TYPEDEF_LIST +#include "RTIO/IR/RTIOOpsTypes.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "RTIO/IR/RTIOOps.cpp.inc" + >(); +} + +// Not support any custom attributes yet, might be supported in the future +Attribute catalyst::rtio::RTIODialect::parseAttribute(DialectAsmParser &parser, Type type) const +{ + parser.emitError(parser.getNameLoc(), "no dialect attributes are supported"); + return {}; +} + +void catalyst::rtio::RTIODialect::printAttribute(Attribute attr, DialectAsmPrinter &printer) const +{ + llvm_unreachable("no dialect attributes are supported"); +} + +//===----------------------------------------------------------------------===// +// RTIO Type Definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "RTIO/IR/RTIOOpsTypes.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "RTIO/IR/RTIOAttributes.cpp.inc" diff --git a/mlir/lib/RTIO/IR/RTIOOps.cpp b/mlir/lib/RTIO/IR/RTIOOps.cpp new file mode 100644 index 0000000000..19d5a147fc --- /dev/null +++ b/mlir/lib/RTIO/IR/RTIOOps.cpp @@ -0,0 +1,36 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RTIO/IR/RTIOOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" + +using namespace mlir; +using namespace catalyst::rtio; + +//===----------------------------------------------------------------------===// +// RTIO Operations +//===----------------------------------------------------------------------===// + +LogicalResult RTIOSyncOp::verify() +{ + // Ensure at least one event is provided + if (getEvents().empty()) { + return emitOpError("requires at least one event to synchronize"); + } + return success(); +} + +#define GET_OP_CLASSES +#include "RTIO/IR/RTIOOps.cpp.inc" diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 06339bc7ec..417bbba640 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -27,6 +27,7 @@ set(TEST_SUITES MBQC cli QEC + RTIO ) diff --git a/mlir/test/RTIO/VerifierTest.mlir b/mlir/test/RTIO/VerifierTest.mlir new file mode 100644 index 0000000000..3cd171140b --- /dev/null +++ b/mlir/test/RTIO/VerifierTest.mlir @@ -0,0 +1,192 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt %s --split-input-file --verify-diagnostics + +//////////////////////// +// Channel Operations // +//////////////////////// + +func.func @channel_good() { + // Smoke test for valid channel operations + %ch0 = rtio.channel : !rtio.channel<"dds", 0> + %ch0_q1 = rtio.channel : !rtio.channel<"dds", [0], 0> + %ch0_q2 = rtio.channel : !rtio.channel<"dds", [0, "t0"], 0> + %ch_explict_dyn = rtio.channel : !rtio.channel<"dds", ?> + %ch_explict_dyn_q1 = rtio.channel : !rtio.channel<"dds", [0], ?> + %ch_explict_dyn_q2 = rtio.channel : !rtio.channel<"dds", [0, "t0"], ?> + + %ch0_implicit_dyn = rtio.channel : !rtio.channel<"dds"> + %ch_implicit_dyn_q1 = rtio.channel : !rtio.channel<"dds", [0]> + %ch_implicit_dyn_q2 = rtio.channel : !rtio.channel<"dds", [0, "t0"]> + return +} + +// ----- + +func.func @channel_negative_id() { + // expected-error@+1 {{static channel ID must be non-negative}} + %ch0_negative = rtio.channel : !rtio.channel<"dds", -1> +} + +// ----- + +func.func @qubit_to_channel_good(%qubit: !ion.qubit) { + // Smoke test for qubit_to_channel + %ch0 = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ["transition_0"], ?> + return +} + +// ----- + +//////////////////////////// +// Event-Based Operations // +//////////////////////////// + +func.func @pulse_basic(%dur: i64, %freq: f64, %phase: f64) { + %ch0 = rtio.channel : !rtio.channel<"dds", 0> + + %event = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) + : !rtio.channel<"dds", 0> -> !rtio.event + + return +} + +// ----- + +func.func @pulse_with_wait(%dur: i64, %freq: f64, %phase: f64) { + %ch0 = rtio.channel : !rtio.channel<"dds", 0> + + %event0 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) + : !rtio.channel<"dds", 0> -> !rtio.event + + %event1 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) wait(%event0) + : !rtio.channel<"dds", 0> -> !rtio.event + + return +} + + +// ----- + +func.func @sync_basic(%dur: i64, %freq: f64, %phase: f64) { + %ch0 = rtio.channel : !rtio.channel<"dds", 0> + %ch1 = rtio.channel : !rtio.channel<"dds", 1> + %ch2 = rtio.channel : !rtio.channel<"dds", 2> + %ch3 = rtio.channel : !rtio.channel<"dds", 3> + + // sync single event + %event0 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) + : !rtio.channel<"dds", 0> -> !rtio.event + + %event1 = rtio.pulse %ch1 duration(%dur) frequency(%freq) phase(%phase) + : !rtio.channel<"dds", 1> -> !rtio.event + + // sync multiple events + %sync1 = rtio.sync %event0, %event1 : !rtio.event + + %event2 = rtio.pulse %ch2 duration(%dur) frequency(%freq) phase(%phase) + : !rtio.channel<"dds", 2> -> !rtio.event + + %event3 = rtio.pulse %ch3 duration(%dur) frequency(%freq) phase(%phase) + : !rtio.channel<"dds", 3> -> !rtio.event + + %sync2 = rtio.sync %sync1, %event2, %event3 : !rtio.event + + return +} + +// ----- + +func.func @sync_no_events() { + // expected-error@+1 {{requires at least one event to synchronize}} + %sync = rtio.sync : !rtio.event + + return +} + +// ----- + +/////////////////////////////// +// Timeline-Based Operations // +/////////////////////////////// + +func.func @timeline_now() { + %t = rtio.now : i64 + return +} + +// ----- + +func.func @timeline_at() { + %t = arith.constant 1000 : i64 + %delay = arith.constant 500 : i64 + %now = rtio.now : i64 + rtio.at %t : i64 + rtio.delay %delay : i64 + + // rewind to the start time + rtio.at %now : i64 + return +} + +// ----- + +func.func @timeline_dalay() { + %delay = arith.constant 500 : i64 + rtio.delay %delay : i64 + return +} + +// ----- + +func.func @set_frequency_dds_good(%freq: f64) { + %ch = rtio.channel : !rtio.channel<"dds", 0> + rtio.set_frequency %ch, %freq : !rtio.channel<"dds", 0>, f64 + return +} + + +// ----- + +func.func @set_phase_dds_good(%phase: f64) { + %ch = rtio.channel : !rtio.channel<"dds", 0> + rtio.set_phase %ch, %phase : !rtio.channel<"dds", 0>, f64 + + return +} + +// ----- + +func.func @set_amplitude_dds_good(%amp: f64) { + %ch = rtio.channel : !rtio.channel<"dds", 0> + rtio.set_amplitude %ch, %amp : !rtio.channel<"dds", 0>, f64 + return +} + +// ----- + +func.func @ttl_on_dds_good() { + %ch = rtio.channel : !rtio.channel<"dds", 0> + rtio.on %ch : !rtio.channel<"dds", 0> + return +} + +// ----- + +func.func @ttl_off_dds_good() { + %ch = rtio.channel : !rtio.channel<"dds", 0> + rtio.off %ch : !rtio.channel<"dds", 0> + return +} diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 617398b03c..6739a31db2 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -24,6 +24,7 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms + MLIRRTIO MLIRCatalystTest MLIRCatalystUtils MLIRTestDialect diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index a9b4140c9d..cef7ee29fa 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -43,6 +43,7 @@ #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" +#include "RTIO/IR/RTIODialect.h" #include "RegisterAllPasses.h" namespace test { @@ -69,6 +70,7 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); catalyst::registerBufferizableOpInterfaceExternalModels(registry); From f127e6a2e861b1b7b6960cd2a233c53b7eaf3328 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 10 Nov 2025 17:18:55 -0500 Subject: [PATCH 02/51] reformatting --- mlir/include/RTIO/IR/RTIODialect.h | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.h b/mlir/include/RTIO/IR/RTIODialect.h index 38ca8c694b..4f0e7be685 100644 --- a/mlir/include/RTIO/IR/RTIODialect.h +++ b/mlir/include/RTIO/IR/RTIODialect.h @@ -22,4 +22,3 @@ #define GET_TYPEDEF_CLASSES #include "RTIO/IR/RTIOOpsTypes.h.inc" - From eb19fa3048f7961107b8a1f2d8f3d698d480b715 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 10 Nov 2025 23:18:26 -0500 Subject: [PATCH 03/51] pulse should use f64 for duration --- mlir/include/RTIO/IR/RTIOOps.td | 10 +++++----- mlir/test/RTIO/VerifierTest.mlir | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index 88fabbcb63..9b5599f924 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -90,7 +90,7 @@ def RTIOPulseOp : RTIO_Op<"pulse"> { Parameters: - channel: Target output channel - - duration: Pulse duration in machine units (mu) + - duration: Pulse duration in seconds (f64) - frequency: frequency in Hz - phase: Phase in radians - wait: Optional events to wait for before starting this pulse (for sequencing) @@ -100,9 +100,9 @@ def RTIOPulseOp : RTIO_Op<"pulse"> { Example: ```mlir - %dur = arith.constant ... : i64 - %freq = arith.constant ... : f64 - %phase = arith.constant ... : f64 + %dur = arith.constant 1.66e-07 : f64 + %freq = arith.constant 1.266e10 : f64 + %phase = arith.constant 0.0 : f64 %event0 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) : !rtio.channel<"dds", 0> -> !rtio.event @@ -118,7 +118,7 @@ def RTIOPulseOp : RTIO_Op<"pulse"> { let arguments = (ins RTIOChannelType:$channel, - I64:$duration, + F64:$duration, F64:$frequency, F64:$phase, Optional:$wait diff --git a/mlir/test/RTIO/VerifierTest.mlir b/mlir/test/RTIO/VerifierTest.mlir index 3cd171140b..c2f7875444 100644 --- a/mlir/test/RTIO/VerifierTest.mlir +++ b/mlir/test/RTIO/VerifierTest.mlir @@ -54,7 +54,7 @@ func.func @qubit_to_channel_good(%qubit: !ion.qubit) { // Event-Based Operations // //////////////////////////// -func.func @pulse_basic(%dur: i64, %freq: f64, %phase: f64) { +func.func @pulse_basic(%dur: f64, %freq: f64, %phase: f64) { %ch0 = rtio.channel : !rtio.channel<"dds", 0> %event = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) @@ -65,7 +65,7 @@ func.func @pulse_basic(%dur: i64, %freq: f64, %phase: f64) { // ----- -func.func @pulse_with_wait(%dur: i64, %freq: f64, %phase: f64) { +func.func @pulse_with_wait(%dur: f64, %freq: f64, %phase: f64) { %ch0 = rtio.channel : !rtio.channel<"dds", 0> %event0 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) @@ -80,7 +80,7 @@ func.func @pulse_with_wait(%dur: i64, %freq: f64, %phase: f64) { // ----- -func.func @sync_basic(%dur: i64, %freq: f64, %phase: f64) { +func.func @sync_basic(%dur: f64, %freq: f64, %phase: f64) { %ch0 = rtio.channel : !rtio.channel<"dds", 0> %ch1 = rtio.channel : !rtio.channel<"dds", 1> %ch2 = rtio.channel : !rtio.channel<"dds", 2> From e5b170f29fdf7d3d15ae4fc340564f1953e0781e Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 17 Nov 2025 09:55:16 -0500 Subject: [PATCH 04/51] Update mlir/include/RTIO/IR/RTIODialect.td Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/include/RTIO/IR/RTIODialect.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index 0de3669c13..dda62321b4 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -80,7 +80,7 @@ def RTIOChannelType : RTIO_Type<"Channel", "channel"> { // Simple DDS, channel TBD !rtio.channel<"dds", ?> - // DDS for one two qualifiers 0 and "t0", channel TBD + // DDS channel with two qualifiers 0 and "t0", channel ID TBD !rtio.channel<"dds", [0, "transition_0"], ?> // DDS on hardware channel 0 From 6c9e90dd555b823853bc4b3221a2e480de5d978d Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 17 Nov 2025 09:55:38 -0500 Subject: [PATCH 05/51] Update mlir/include/RTIO/IR/RTIODialect.td Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/include/RTIO/IR/RTIODialect.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index dda62321b4..8c5ca32aaa 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -69,7 +69,7 @@ def RTIOChannelType : RTIO_Type<"Channel", "channel"> { Syntax: ``` !rtio.channel // Dynamic: kind only - !rtio.channel // Static: kind + channel N + !rtio.channel // Static: kind + channel ID !rtio.channel // Dynamic: with qualifiers !rtio.channel // Static: with qualifiers + channel N ``` From a50294489d358283223fe061d114ba389349597f Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 17 Nov 2025 09:55:48 -0500 Subject: [PATCH 06/51] Update mlir/include/RTIO/IR/RTIODialect.td Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/include/RTIO/IR/RTIODialect.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index 8c5ca32aaa..d3c6cf2403 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -71,7 +71,7 @@ def RTIOChannelType : RTIO_Type<"Channel", "channel"> { !rtio.channel // Dynamic: kind only !rtio.channel // Static: kind + channel ID !rtio.channel // Dynamic: with qualifiers - !rtio.channel // Static: with qualifiers + channel N + !rtio.channel // Static: with qualifiers + channel ID ``` Examples: From 94927c3c429a45e988c6732a269af9738eebb362 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 17 Nov 2025 09:56:02 -0500 Subject: [PATCH 07/51] Update mlir/include/RTIO/IR/RTIODialect.td Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/include/RTIO/IR/RTIODialect.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index d3c6cf2403..af55265dd6 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -94,8 +94,8 @@ def RTIOChannelType : RTIO_Type<"Channel", "channel"> { // After channel resolution: Resolved to hardware channel 1 !rtio.channel<"dds", [0], 1> - // Note: The qualifiers are used provided addtional information to the channel type. - // And it will used to distinguish different channels with the same kind. + // Note: The qualifiers are used to provide additional information to the channel type. + // And they will be used to distinguish different channels with the same kind. ``` }]; From e9b1e8447467c900639bb1e6ac099fb2235edd5b Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 17 Nov 2025 09:56:09 -0500 Subject: [PATCH 08/51] Update mlir/include/RTIO/IR/RTIOOps.td Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/include/RTIO/IR/RTIOOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index 9b5599f924..b7ad983963 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -82,7 +82,7 @@ def RTIOQubitToChannelOp : RTIO_Op<"qubit_to_channel"> { } def RTIOPulseOp : RTIO_Op<"pulse"> { - let summary = "Generate a event-based pulse on a channel"; + let summary = "Generate an event-based pulse on a channel"; let description = [{ Generate a pulse on the channel. Returns an `!rtio.event` handle that can be used with `rtio.sync` to wait for completion or to be used by other pulse operations with the `wait` From 1452e7bc4cb6976c25dfa78921e89ddbd5218ae2 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 17 Nov 2025 09:56:20 -0500 Subject: [PATCH 09/51] Update mlir/include/RTIO/IR/RTIODialect.td Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/include/RTIO/IR/RTIODialect.td | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index af55265dd6..f19dc267bf 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -34,7 +34,6 @@ def RTIO_Dialect : Dialect { // TODO: Do we need the separate Timeline IR for artiq family? 2. Timeline IR (low-level): Stateful operations with implicit time cursor - ``` }]; let name = "rtio"; From 38242f5b730549172542224314d1d66cf7c8c642 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 17 Nov 2025 09:56:30 -0500 Subject: [PATCH 10/51] Update mlir/include/RTIO/IR/RTIOOps.td Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/include/RTIO/IR/RTIOOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index b7ad983963..eec63bd746 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -59,8 +59,8 @@ def RTIOQubitToChannelOp : RTIO_Op<"qubit_to_channel"> { let summary = "Map a qubit to an RTIO channel"; let description = [{ It's a temporary operation that will be lowered to a static `rtio.channel` operation during - the channel resolution stage. The purpose of this operation is to allow the qubit from a - from a high-level dialect to be mapped to an `!rtio.channel`. + the channel resolution stage. The purpose of this operation is to allow a qubit from a + high-level dialect to be mapped to an `!rtio.channel`. Example: ```mlir From 6c00116b8858098e78201be3e61003bd98432806 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 00:05:54 -0500 Subject: [PATCH 11/51] add empty op --- mlir/include/RTIO/IR/RTIOOps.td | 41 +++++++++++++++++++++++++++++++- mlir/lib/RTIO/IR/RTIODialect.cpp | 5 ++++ mlir/test/RTIO/VerifierTest.mlir | 7 ++++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index 9b5599f924..9b02ad15fb 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -65,7 +65,7 @@ def RTIOQubitToChannelOp : RTIO_Op<"qubit_to_channel"> { Example: ```mlir // Map ion qubit to dds channel with dynamic channel id - %ch = rtio.qubit_to_channel %ion_qubit : !ion.qubit -> !rtio.channel<"dds", ?> + %ch = rtio.qubit_tos_channel %ion_qubit : !ion.qubit -> !rtio.channel<"dds", ?> ``` During the channel resolution stage, this operation will be replaced by a @@ -130,6 +130,16 @@ def RTIOPulseOp : RTIO_Op<"pulse"> { $channel `duration` `(` $duration `)` `frequency` `(` $frequency `)` `phase` `(` $phase `)` (`wait` `(` $wait^ `)`)? attr-dict `:` type($channel) `->` type($event) }]; + + let extraClassDeclaration = [{ + void setWait(mlir::Value waitEvent) { + if (waitEvent) { + getWaitMutable().assign(waitEvent); + } else { + getWaitMutable().clear(); + } + } + }]; } def RTIOSyncOp : RTIO_Op<"sync"> { @@ -161,6 +171,35 @@ def RTIOSyncOp : RTIO_Op<"sync"> { let hasVerifier = 1; } +def RTIOEmptyOp : RTIO_Op<"empty", [Pure]> { + let summary = "Create an empty event for sequencing"; + let description = [{ + Creates an empty event that can be used for initializing the first event + or when you need an event handle without performing any actual hardware operation. + + The empty event acts as a no-op synchronization point that completes immediately, + allowing to establish event chains without triggering actions. + + Example: + ```mlir + // Initialize with empty event for loop + %init_event = rtio.empty : !rtio.event + %result = scf.for %i = %c0 to %c10 step %c1 iter_args(%event = %init_event) -> (!rtio.event) { + // First pulse waits on event from previous iteration (or init_event in first iteration) + %e0 = rtio.pulse %ch duration(%dur) frequency(%freq) phase(%phase) wait(%event) + : !rtio.channel<"dds", 0> -> !rtio.event + %e1 = rtio.pulse %ch duration(%dur) frequency(%freq) phase(%phase) wait(%event) + : !rtio.channel<"dds", 0> -> !rtio.event + %sync = rtio.sync %e0, %e1 : !rtio.event + scf.yield %sync : !rtio.event + } + ``` + }]; + + let results = (outs RTIOEventType:$event); + let assemblyFormat = "attr-dict `:` type($event)"; +} + //===----------------------------------------------------------------------===// // Timeline-Based IR (ARTIQ compatible) //===----------------------------------------------------------------------===// diff --git a/mlir/lib/RTIO/IR/RTIODialect.cpp b/mlir/lib/RTIO/IR/RTIODialect.cpp index 5bb9fa4944..cf7c3f3e1e 100644 --- a/mlir/lib/RTIO/IR/RTIODialect.cpp +++ b/mlir/lib/RTIO/IR/RTIODialect.cpp @@ -111,6 +111,11 @@ static void printChannelTypeBody(AsmPrinter &printer, StringRef kind, ArrayAttr printer << id; } } + } else { + if (qualifiers && !qualifiers.empty()) { + printer << ", "; + } + printer << "?"; } } diff --git a/mlir/test/RTIO/VerifierTest.mlir b/mlir/test/RTIO/VerifierTest.mlir index c2f7875444..1ebd3db529 100644 --- a/mlir/test/RTIO/VerifierTest.mlir +++ b/mlir/test/RTIO/VerifierTest.mlir @@ -190,3 +190,10 @@ func.func @ttl_off_dds_good() { rtio.off %ch : !rtio.channel<"dds", 0> return } + +// ----- + +func.func @empty_good() { + %empty = rtio.empty : !rtio.event + return +} From ef5c3a61f54b2b95c0e73e6075e375e6ad1d2791 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 00:10:07 -0500 Subject: [PATCH 12/51] fix formatting --- mlir/lib/RTIO/IR/RTIODialect.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/RTIO/IR/RTIODialect.cpp b/mlir/lib/RTIO/IR/RTIODialect.cpp index cf7c3f3e1e..64e60ac774 100644 --- a/mlir/lib/RTIO/IR/RTIODialect.cpp +++ b/mlir/lib/RTIO/IR/RTIODialect.cpp @@ -111,7 +111,8 @@ static void printChannelTypeBody(AsmPrinter &printer, StringRef kind, ArrayAttr printer << id; } } - } else { + } + else { if (qualifiers && !qualifiers.empty()) { printer << ", "; } From c8193334f0d4118d47d126be989ddeff6e4e6287 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 00:17:05 -0500 Subject: [PATCH 13/51] update td --- mlir/include/RTIO/IR/RTIODialect.td | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index f19dc267bf..8b3609c664 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -76,10 +76,15 @@ def RTIOChannelType : RTIO_Type<"Channel", "channel"> { Examples: ```mlir - // Simple DDS, channel TBD + // Simple DDS, channel ID TBD (to be resolved during channel resolution) !rtio.channel<"dds", ?> - // DDS channel with two qualifiers 0 and "t0", channel ID TBD + // DDS channel with qualifiers: + // Qualifiers distinguish different logical channels of the same kind. + // Example: ion 0, transition 0 -> will be distinguished from other transitions of + // the same ion. The mapping of the logical channel to the hardware channel will + // be resolved during the lowering pass from given dialect to RTIO dialect. + // User need to specifiy the mapping logic. !rtio.channel<"dds", [0, "transition_0"], ?> // DDS on hardware channel 0 From a26864c360bfdb8565ce78e2bb9b6845591d1108 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 00:53:19 -0500 Subject: [PATCH 14/51] Add pass to lower ion dialect to rtio dialect --- mlir/include/Ion/Transforms/Passes.td | 21 + mlir/lib/Ion/Transforms/CMakeLists.txt | 3 + mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 937 ++++++++++++++++++++++++ 3 files changed, 961 insertions(+) create mode 100644 mlir/lib/Ion/Transforms/ion-to-rtio.cpp diff --git a/mlir/include/Ion/Transforms/Passes.td b/mlir/include/Ion/Transforms/Passes.td index 7395214414..ef14d16b23 100644 --- a/mlir/include/Ion/Transforms/Passes.td +++ b/mlir/include/Ion/Transforms/Passes.td @@ -52,4 +52,25 @@ def IonConversionPass : Pass<"convert-ion-to-llvm"> { ]; } +def IonToRTIOPass : Pass<"convert-ion-to-rtio", "mlir::ModuleOp"> { + let summary = "Convert Ion dialect operations to RTIO dialect"; + + let dependentDialects = [ + "rtio::RTIODialect", + "arith::ArithDialect", + "func::FuncDialect", + "memref::MemRefDialect", + "scf::SCFDialect", + "quantum::QuantumDialect", + "ion::IonDialect", + "linalg::LinalgDialect", + ]; + + let options = [ + Option<"kernelName", "kernel-name", + "std::string", /*default=*/"\"__kernel__\"", + "Name of the generated kernel function"> + ]; +} + #endif // ION_PASSES diff --git a/mlir/lib/Ion/Transforms/CMakeLists.txt b/mlir/lib/Ion/Transforms/CMakeLists.txt index aea11a3376..45cf7be159 100644 --- a/mlir/lib/Ion/Transforms/CMakeLists.txt +++ b/mlir/lib/Ion/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ set(LIBRARY_NAME ion-transforms) file(GLOB SRC ion-to-llvm.cpp + ion-to-rtio.cpp ConversionPatterns.cpp gates_to_pulses.cpp GatesToPulsesPatterns.cpp @@ -13,6 +14,7 @@ set(LIBS ${dialect_libs} ${conversion_libs} MLIRIon + MLIRRTIO ) set(DEPENDS @@ -36,6 +38,7 @@ target_link_libraries(${LIBRARY_NAME} PRIVATE if(CMAKE_CXX_COMPILER_ID MATCHES ".*Clang") set_source_files_properties( ion-to-llvm.cpp + ion-to-rtio.cpp ConversionPatterns.cpp gates_to_pulses.cpp GatesToPulsesPatterns.cpp diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp new file mode 100644 index 0000000000..8b93a2b5b1 --- /dev/null +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -0,0 +1,937 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "Ion/IR/IonDialect.h" +#include "Ion/IR/IonOps.h" +#include "Ion/Transforms/Passes.h" +#include "Quantum/IR/QuantumDialect.h" +#include "Quantum/IR/QuantumOps.h" +#include "RTIO/IR/RTIODialect.h" +#include "RTIO/IR/RTIOOps.h" +#include + +using namespace mlir; +using namespace catalyst; + +namespace catalyst { +namespace ion { + +namespace { + +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +enum class TraceMode { + Qreg = 0, + Event = 1, +}; + +/// Traces a Value backward through the IR by tracing its dataflow dependencies +/// across control flow and specific quantum operations. +/// +/// Template Parameters: +/// - ModeT: TraceMode enum (Qreg or Event) that controls how quantum.insert +/// operations are handled +/// Qreg mode: Trace to find the source qreg of the given value +/// Event mode: Trace to find all events that contribute to the given value +/// - CallbackT: Callable type that will be invoked for each visited value. +/// May optionally return WalkResult for early termination. +/// +/// Supported Operations: +/// - scf.for +/// - scf.if +/// - ion.parallelprotocol +/// - unrealized_conversion_cast +/// - quantum.extract +/// - quantum.insert +template +auto traceValueWithCallback(Value value, CallbackT &&callback) +{ + WalkResult walkResult = WalkResult::advance(); + std::queue visited; + visited.push(value); + + while (!visited.empty()) { + Value value = visited.front(); + visited.pop(); + + if constexpr (std::is_same_v, WalkResult>) { + if (callback(value).wasInterrupted()) { + walkResult = WalkResult::interrupt(); + continue; + } + } + else { + callback(value); + } + + if (auto arg = mlir::dyn_cast(value)) { + Block *block = arg.getOwner(); + Operation *parentOp = block->getParentOp(); + + if (auto forOp = dyn_cast(parentOp)) { + unsigned argIndex = arg.getArgNumber(); + Value iterArg = forOp.getInitArgs()[argIndex - 1]; + visited.push(iterArg); + continue; + } + else if (auto parallelProtocolOp = dyn_cast(parentOp)) { + unsigned argIndex = arg.getArgNumber(); + Value inQubit = parallelProtocolOp.getInQubits()[argIndex]; + visited.push(inQubit); + continue; + } + parentOp->emitError("Unsupported parent operation for block argument: ") << value; + llvm::reportFatalInternalError("Unsupported block argument"); + } + + Operation *defOp = value.getDefiningOp(); + if (defOp == nullptr) { + continue; + } + + if (auto forOp = dyn_cast(defOp)) { + unsigned resultIdx = llvm::cast(value).getResultNumber(); + BlockArgument iterArg = forOp.getRegionIterArg(resultIdx); + visited.push(iterArg); + } + else if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIdx = llvm::cast(value).getResultNumber(); + Value thenValue = ifOp.thenYield().getOperand(resultIdx); + Value elseValue = ifOp.elseYield().getOperand(resultIdx); + visited.push(thenValue); + visited.push(elseValue); + } + else if (auto parallelProtocolOp = dyn_cast(defOp)) { + unsigned resultIdx = llvm::cast(value).getResultNumber(); + Value inQubit = parallelProtocolOp.getInQubits()[resultIdx]; + visited.push(inQubit); + } + else if (auto op = dyn_cast(defOp)) { + visited.push(op.getInputs().front()); + } + else if (auto op = dyn_cast(defOp)) { + visited.push(op.getQreg()); + } + else if (auto op = dyn_cast(defOp)) { + Value inQreg = op.getInQreg(); + Value qubit = op.getQubit(); + if constexpr (ModeT == TraceMode::Qreg) { + visited.push(inQreg); + } + else if constexpr (ModeT == TraceMode::Event) { + visited.push(qubit); + // only trace qreg if it defined op is also come from insert op + if (llvm::isa_and_present(inQreg.getDefiningOp())) { + visited.push(inQreg); + } + } + } + } + + if constexpr (std::is_same_v, WalkResult>) { + return walkResult; + } +} + +Value createSyncEvent(ArrayRef events, PatternRewriter &rewriter) +{ + if (events.size() == 1) { + return events.front(); + } + auto eventType = rtio::EventType::get(rewriter.getContext()); + return rewriter.create(rewriter.getUnknownLoc(), eventType, events); +} + +// Helper class to store ion information +class IonInfo { + private: + llvm::StringMap levelEnergyMap; + + struct TransitionInfo { + std::string level0; + std::string level1; + double einstein_a; + std::string multipole; + }; + SmallVector transitions; + + public: + IonInfo(ion::IonOp op) + { + auto levelAttrs = op.getLevels(); + auto transitionsAttr = op.getTransitions(); + + // Map from Level label to Energy value + for (auto levelAttr : levelAttrs) { + auto level = cast(levelAttr); + std::string label = level.getLabel().getValue().str(); + double energy = level.getEnergy().getValueAsDouble(); + levelEnergyMap[label] = energy; + } + + // Store transition information + for (auto transitionAttr : transitionsAttr) { + auto transition = cast(transitionAttr); + TransitionInfo info; + info.level0 = transition.getLevel_0().getValue().str(); + info.level1 = transition.getLevel_1().getValue().str(); + info.einstein_a = transition.getEinsteinA().getValueAsDouble(); + info.multipole = transition.getMultipole().getValue().str(); + transitions.push_back(info); + } + } + + // Get energy of a level by label + std::optional getLevelEnergy(StringRef label) const + { + auto it = levelEnergyMap.find(label.str()); + if (it != levelEnergyMap.end()) { + return it->second; + } + return std::nullopt; + } + + // Get level label of a transition by index + template + std::optional getTransitionLevelEnergy(size_t transitionIndex) const + { + static_assert(IndexT == 0 || IndexT == 1, "IndexT must be 0 or 1"); + + if (transitionIndex >= transitions.size()) { + return std::nullopt; + } + + const auto &transition = transitions[transitionIndex]; + if constexpr (IndexT == 0) { + return getLevelEnergy(transition.level0); + } + else { + return getLevelEnergy(transition.level1); + } + } + + // Get energy difference of a transition (level1 energy - level0 energy) + std::optional getTransitionEnergyDiff(size_t index) const + { + if (index >= transitions.size()) { + return std::nullopt; + } + + auto energy0 = getTransitionLevelEnergy<0>(index); + auto energy1 = getTransitionLevelEnergy<1>(index); + + if (energy0.has_value() && energy1.has_value()) { + return energy1.value() - energy0.value(); + } + + return std::nullopt; + } + + // Get number of transitions + size_t getNumTransitions() const { return transitions.size(); } + + // Get transition info by index + std::optional getTransition(size_t index) const + { + if (index < transitions.size()) { + return transitions[index]; + } + return std::nullopt; + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +/// Convert ion.parallelprotocol and introduce rtio.sync to ensure the order +/// +/// Example: +/// ``` +/// %0, %1 = ion.parallelprotocol(%q0, %q1) { +/// ^bb0(%arg0, %arg1): +/// %p0 = rtio.pulse(...) : !rtio.event +/// %p1 = rtio.pulse(...) : !rtio.event +/// ion.yield %arg0, %arg1 +/// } +/// ``` +/// will be converted to: +/// ``` +/// %event0 = unrealized_conversion_cast %q0 : !ion.qubit -> !rtio.event +/// %event1 = unrealized_conversion_cast %q1 : !ion.qubit -> !rtio.event +/// %p0 = rtio.pulse(..., wait = %event0) : !rtio.event +/// %p1 = rtio.pulse(..., wait = %event1) : !rtio.event +/// %sync = rtio.sync %p0, %p1 : !rtio.event +/// %0 = unrealized_conversion_cast %sync : !rtio.event -> !ion.qubit +/// %1 = unrealized_conversion_cast %sync : !rtio.event -> !ion.qubit +/// ``` +/// Those unrealized conversion casts are used to establish the dependency but will be +/// resolved by the subsequent stages. +struct ParallelProtocolToRTIOPattern : public OpConversionPattern { + ParallelProtocolToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) + { + } + + LogicalResult matchAndRewrite(ion::ParallelProtocolOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + MLIRContext *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + Block *regionBlock = &op.getBodyRegion().front(); + IRMapping irMapping; + SmallVector inQubits; + for (auto [blockArg, operand] : + llvm::zip(regionBlock->getArguments(), adaptor.getOperands())) { + irMapping.map(blockArg, operand); + + // collect qubits to trace the events + if (isa(operand.getType())) { + inQubits.push_back(operand); + } + } + + rewriter.setInsertionPointAfter(op); + + // create events for each qubit + auto events = llvm::map_range(inQubits, [&](Value qubit) { + auto eventType = rtio::EventType::get(ctx); + return rewriter.create(loc, eventType, qubit).getResult(0); + }); + + Value inputSyncEvent = createSyncEvent(llvm::to_vector(events), rewriter); + + // Clone operations from the region to outside + SmallVector pulseEvents; + for (auto ®ionOp : regionBlock->without_terminator()) { + auto *clonedOp = rewriter.clone(regionOp, irMapping); + if (auto pulseOp = dyn_cast(clonedOp)) { + // set wait event for the pulse operation + pulseOp.setWait(inputSyncEvent); + pulseEvents.push_back(pulseOp.getEvent()); + } + irMapping.map(regionOp.getResults(), clonedOp->getResults()); + } + + // Create sync operation from pulse events (must have at least one after Phase 1) + assert(pulseEvents.size() > 0 && + "must have at least one pulse operation after parallel protocol conversion"); + + Value outputSyncEvent = createSyncEvent(llvm::to_vector(pulseEvents), rewriter); + + SmallVector results; + for (Value result : op.getResults()) { + // unrealized conversion cast sync event to result type + auto event = + rewriter.create(loc, result.getType(), outputSyncEvent); + results.push_back(event.getResult(0)); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +/// Convert ion.pulse to rtio.pulse +/// +/// Example: +/// ``` +/// %pulse = ion.pulse(%duration) %qubit { +/// beam = #ion.beam<...> +/// } : !ion.pulse +/// ``` +/// will be converted to: +/// ``` +/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?> +/// ... // other pulse parameters settings +/// %event = rtio.pulse %ch duration(%duration) frequency(%freq) phase(%phase) +/// : !rtio.channel<"dds", ?> -> !rtio.event +/// ``` +struct PulseToRTIOPattern : public OpConversionPattern { + IonInfo ionInfo; + DenseMap &qextractToMemrefMap; + PulseToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx, IonInfo ionInfo, + DenseMap &qextractToMemrefMap) + : OpConversionPattern(typeConverter, ctx), ionInfo(ionInfo), + qextractToMemrefMap(qextractToMemrefMap) + { + } + + double calculateFrequency(int64_t transitionIndex, double detuning, + const IonInfo &ionInfo) const + { + // TODO: raman1_frequency can be passed as a pass option for extensibility + double raman1_frequency = 2 * llvm::numbers::pi * 844.485e12 - + 2 * llvm::numbers::pi * 12.643e9 - 2 * llvm::numbers::pi * 20e6; + + auto energyDiff = ionInfo.getTransitionEnergyDiff(transitionIndex); + assert(energyDiff.has_value() && "energyDiff must have a value"); + + double reference_energy = energyDiff.value(); + double frequency = + (reference_energy + detuning - raman1_frequency) / (2.0 * llvm::numbers::pi); + return frequency; + } + + LogicalResult matchAndRewrite(ion::PulseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + // Get pulse parameters + Value duration = op.getTime(); + auto beamAttr = op.getBeam(); + auto phaseAttr = op.getPhase(); + + // Extract beam parameters + double detuning = beamAttr.getDetuning().getValueAsDouble(); + double phase = phaseAttr.getValueAsDouble(); + int64_t transitionIndex = beamAttr.getTransitionIndex().getInt(); + double frequency = calculateFrequency(transitionIndex, detuning, ionInfo); + Value freqValue = + rewriter.create(loc, rewriter.getF64FloatAttr(frequency)); + Value phaseValue = rewriter.create(loc, rewriter.getF64FloatAttr(phase)); + + // Convert the qubit to a channel + ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(transitionIndex)}); + auto channelType = rtio::ChannelType::get(ctx, "dds", qualifiers, nullptr); + + Value memrefLoadValue = nullptr; + traceValueWithCallback(op.getInQubit(), [&](Value value) -> WalkResult { + if (qextractToMemrefMap.count(value)) { + memrefLoadValue = qextractToMemrefMap[value]; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + assert(memrefLoadValue != nullptr && "memrefLoadValue must not be null"); + + Value channel = + rewriter.create(loc, channelType, memrefLoadValue); + + // Create rtio.pulse + auto eventType = rtio::EventType::get(ctx); + Value event = rewriter.create(loc, eventType, channel, duration, + freqValue, phaseValue, nullptr); + rewriter.replaceOp(op, event); + + return success(); + } +}; + +/// Propagates RTIO events from chain of operations to event types. +/// +/// Steps: +/// 1. Traces backward to find all events that contribute to the current event value +/// 2. Creates a sync event from all collected events +/// 3. Replaces the cast operation with the sync event +struct PropagateEventsPattern : public OpRewritePattern { + MLIRContext *ctx; + + PropagateEventsPattern(MLIRContext *ctx) + : OpRewritePattern(ctx), ctx(ctx) + { + } + + LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, + PatternRewriter &rewriter) const override + { + if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return failure(); + + Type srcType = op.getInputs()[0].getType(); + Type dstType = op.getResult(0).getType(); + + // Only match casts from quantum/ion types to event type + // quantum.qreg -> event, quantum.qubit -> event, ion.qubit -> event + bool validSrcType = + llvm::isa(srcType); + bool validDstType = llvm::isa(dstType); + if (!validSrcType || !validDstType) + return failure(); + + Value input = op.getInputs()[0]; + + // Find associated events + // Skip over intermediate cast/extract/insert operations to collect events + bool reachedAllocOp = false; + SetVector events; + traceValueWithCallback(input, [&](Value value) -> WalkResult { + auto defOp = value.getDefiningOp(); + if (defOp && + isa(defOp)) { + return WalkResult::advance(); + } + + // collect event and stop tracing this path + if (isa(value.getType())) { + events.insert(value); + return WalkResult::interrupt(); + } + + if (isa(defOp)) { + reachedAllocOp = true; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + + if (reachedAllocOp && events.empty()) { + auto eventType = rtio::EventType::get(ctx); + Value emptyEvent = rewriter.create(op.getLoc(), eventType); + rewriter.replaceOp(op, emptyEvent); + return success(); + } + + if (events.empty()) { + op.emitError("No events found for cast op"); + llvm::reportFatalInternalError("No events found for cast op"); + } + + // Create a sync event from all collected events + // TODO: check domination, so that we can avoid creating a sync event if events are + // already dominated by one of the events + Value syncEvent = createSyncEvent(events.getArrayRef(), rewriter); + rewriter.replaceOp(op, syncEvent); + return success(); + } +}; + +/// Clean up quantum/ion related ops that are not needed after conversion +struct CleanQuantumOpsPattern : public RewritePattern { + CleanQuantumOpsPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) + { + } + + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override + { + Dialect *dialect = op->getDialect(); + if (!dialect || !isa(dialect)) + return failure(); + + if (!op->use_empty()) { + return failure(); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +LogicalResult CleanQuantumOps(func::FuncOp funcOp, MLIRContext *ctx) +{ + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + return success(); +} + +LogicalResult CanonicalizeKernelFunction(func::FuncOp funcOp, MLIRContext *ctx) +{ + RewritePatternSet patterns(ctx); + for (auto *dialect : ctx->getLoadedDialects()) { + dialect->getCanonicalizationPatterns(patterns); + } + for (RegisteredOperationName op : ctx->getRegisteredOperations()) { + op.getCanonicalizationPatterns(patterns, ctx); + } + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Pass Implementation +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_IONTORTIOPASS +#include "Ion/Transforms/Passes.h.inc" + +struct IonToRTIOPass : public impl::IonToRTIOPassBase { + using impl::IonToRTIOPassBase::IonToRTIOPassBase; + + LogicalResult IonPulseConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, + TypeConverter &typeConverter, IonInfo ionInfo, + DenseMap &qextractToMemrefMap, MLIRContext *ctx) + { + ConversionTarget target(baseTarget); + target.addIllegalOp(); + + RewritePatternSet patterns(ctx); + patterns.add(typeConverter, ctx, ionInfo, qextractToMemrefMap); + if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { + return failure(); + } + return success(); + } + + LogicalResult ParallelProtocolConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, + TypeConverter &typeConverter, MLIRContext *ctx) + { + ConversionTarget target(baseTarget); + target.addIllegalOp(); + + RewritePatternSet patterns(ctx); + patterns.add(typeConverter, ctx); + if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { + return failure(); + } + return success(); + } + + LogicalResult SCFStructuralConversion(func::FuncOp funcOp, ConversionTarget &target, + TypeConverter &typeConverter, MLIRContext *ctx) + { + TypeConverter scfTypeConverter(typeConverter); + scfTypeConverter.addConversion( + [ctx](quantum::QubitType) -> Type { return rtio::EventType::get(ctx); }); + scfTypeConverter.addConversion( + [ctx](quantum::QuregType) -> Type { return rtio::EventType::get(ctx); }); + scfTypeConverter.addConversion( + [ctx](ion::QubitType) -> Type { return rtio::EventType::get(ctx); }); + // Add materialization for quantum/ion -> event + scfTypeConverter.addSourceMaterialization( + [](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return nullptr; + Type inputType = inputs.front().getType(); + if (inputType != resultType) { + return builder.create(loc, resultType, inputs) + .getResult(0); + } + return inputs[0]; + }); + // Add target materialization for event -> quantum/ion + scfTypeConverter.addTargetMaterialization( + [](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return nullptr; + Type inputType = inputs.front().getType(); + if (inputType != resultType) { + return builder.create(loc, resultType, inputs) + .getResult(0); + } + return inputs[0]; + }); + + ConversionTarget scfTarget(getContext()); + scfTarget.addLegalDialect(); + + // Mark SCF ops as illegal only if they use quantum/ion types + scfTarget.addDynamicallyLegalOp([&](scf::ForOp op) { + for (auto arg : op.getRegionIterArgs()) { + Type type = arg.getType(); + if (llvm::isa(type)) { + return false; + } + } + for (auto result : op.getResults()) { + Type type = result.getType(); + if (llvm::isa(type)) { + return false; + } + } + return true; + }); + + scfTarget.addDynamicallyLegalOp([&](scf::IfOp op) { + for (auto result : op.getResults()) { + Type type = result.getType(); + if (llvm::isa(type)) { + return false; + } + } + return true; + }); + + scfTarget.addLegalOp(); + + // restructure SCF Operations + RewritePatternSet scfPatterns(&getContext()); + mlir::scf::populateSCFStructuralTypeConversionsAndLegality(scfTypeConverter, scfPatterns, + scfTarget); + + if (failed(applyPartialConversion(funcOp, scfTarget, std::move(scfPatterns)))) { + return failure(); + } + + return success(); + } + + LogicalResult PropagateEvents(func::FuncOp funcOp, MLIRContext *ctx) + { + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + return success(); + } + + SmallVector getIonInfos() + { + SmallVector ionInfos; + getOperation()->walk([&](ion::IonOp ionOp) { ionInfos.emplace_back(IonInfo(ionOp)); }); + return ionInfos; + } + + func::FuncOp createKernelFunction(func::FuncOp qnodeFunc, std::string kernelName, + OpBuilder &builder) + { + MLIRContext *ctx = builder.getContext(); + + auto newQnodeFunc = qnodeFunc.clone(); + newQnodeFunc.setName(kernelName); + auto oldFuncType = qnodeFunc.getFunctionType(); + // create new function type with empty results + auto newFuncType = FunctionType::get(ctx, oldFuncType.getInputs(), {}); + newQnodeFunc.setFunctionType(newFuncType); + + // Replace all return ops with empty returns + SmallVector returnsToReplace; + newQnodeFunc.walk([&](func::ReturnOp returnOp) { returnsToReplace.push_back(returnOp); }); + + for (auto returnOp : returnsToReplace) { + builder.setInsertionPoint(returnOp); + builder.create(returnOp.getLoc()); + returnOp.erase(); + } + + return newQnodeFunc; + } + + void initializeMemrefMap(func::FuncOp funcOp, ModuleOp module, + DenseMap &qregToMemrefMap, + DenseMap &qextractToMemrefMap, MLIRContext *ctx) + { + OpBuilder builder(ctx); + + int globalCounter = 0; + funcOp.walk([&](quantum::AllocOp allocOp) { + size_t numQubits = allocOp.getNqubitsAttr().value(); + auto memrefType = + MemRefType::get({static_cast(numQubits)}, builder.getIndexType()); + + // Create a unique symbol name for this global + std::string globalNameStr = "__qubit_map_" + std::to_string(globalCounter++); + StringRef globalName = globalNameStr; + + // Create dense attribute with values [0, 1, 2, ..., numQubits-1] * 2 + auto tensorType = + RankedTensorType::get({static_cast(numQubits)}, builder.getIndexType()); + SmallVector values; + // Use IndexType::kInternalStorageBitWidth for index type + unsigned indexWidth = IndexType::kInternalStorageBitWidth; + for (size_t i = 0; i < numQubits; i++) { + values.push_back(APInt(indexWidth, i * 2)); + } + auto denseAttr = DenseIntElementsAttr::get(tensorType, values); + + // Create global memref at module level + builder.setInsertionPointToStart(module.getBody()); + auto globalOp = memref::GlobalOp::create( + builder, allocOp.getLoc(), + builder.getStringAttr(globalName), // sym_name + builder.getStringAttr("private"), // sym_visibility + TypeAttr::get(memrefType), // type + denseAttr, // initial_value + builder.getUnitAttr(), // constant + IntegerAttr()); // alignment + + // Get the global memref in the function + builder.setInsertionPointAfter(allocOp); + Value qubitMap = builder.create(allocOp.getLoc(), memrefType, + globalOp.getSymName()); + + qregToMemrefMap[allocOp.getResult()] = qubitMap; + }); + + funcOp.walk([&](quantum::ExtractOp extractOp) { + traceValueWithCallback( + extractOp.getQreg(), [&](Value value) -> WalkResult { + if (qregToMemrefMap.count(value)) { + builder.setInsertionPointAfter(extractOp); + auto memref = qregToMemrefMap[value]; + + Value memrefLoadValue = nullptr; + if (Value idx = extractOp.getIdx()) { + // idx is an operand (i64), need to cast to index + Value indexValue = builder.create( + extractOp.getLoc(), builder.getIndexType(), idx); + memrefLoadValue = builder.create( + extractOp.getLoc(), memref, ValueRange{indexValue}); + } + else if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { + Value indexValue = builder.create( + extractOp.getLoc(), idxAttr.getInt()); + memrefLoadValue = builder.create( + extractOp.getLoc(), memref, ValueRange{indexValue}); + } + if (memrefLoadValue) { + qextractToMemrefMap[extractOp.getResult()] = memrefLoadValue; + } + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + }); + } + + LogicalResult updateEntryFunction(func::FuncOp entryFunc, func::FuncOp newQnodeFunc, + MLIRContext *ctx) + { + // Update entry function return type to empty + auto oldEntryFuncType = entryFunc.getFunctionType(); + auto newEntryFuncType = FunctionType::get(ctx, oldEntryFuncType.getInputs(), {}); + entryFunc.setFunctionType(newEntryFuncType); + + // Clear the function body + Block *entryBlock = &entryFunc.getBody().front(); + SmallVector opsToErase; + for (Operation &op : entryBlock->getOperations()) { + opsToErase.push_back(&op); + } + for (auto op : opsToErase) { + op->dropAllUses(); + op->erase(); + } + + // Create call to kernel function + OpBuilder entryBuilder(ctx); + entryBuilder.setInsertionPointToStart(entryBlock); + + SmallVector kernelArgs(entryFunc.getArguments().begin(), + entryFunc.getArguments().end()); + + // compare args type with kernel function arguments + if (kernelArgs.size() != newQnodeFunc.getArguments().size()) { + entryFunc->emitError("Failed to update entry function: number of arguments mismatch"); + return failure(); + } + for (size_t i = 0; i < kernelArgs.size(); i++) { + if (kernelArgs[i].getType() != newQnodeFunc.getArguments()[i].getType()) { + entryFunc->emitError("Failed to update entry function: argument type mismatch"); + return failure(); + } + } + + entryBuilder.create(entryFunc.getLoc(), newQnodeFunc.getName(), TypeRange{}, + kernelArgs); + + entryBuilder.setInsertionPointToEnd(entryBlock); + entryBuilder.create(entryFunc.getLoc()); + return success(); + } + + void runOnOperation() override + { + MLIRContext *ctx = &getContext(); + auto module = cast(getOperation()); + + // check if there is only one qnode function + func::FuncOp qnodeFunc = nullptr; + int qnodeCounts = 0; + module.walk([&](func::FuncOp funcOp) { + if (funcOp->hasAttr("qnode")) { + qnodeFunc = funcOp; + qnodeCounts++; + } + }); + assert(qnodeCounts == 1 && "only one qnode function is allowed"); + + // collect all ion information for calculating frequency when converting ion.pulse + SmallVector ionInfos = getIonInfos(); + if (ionInfos.empty()) { + getOperation()->emitError("Failed to get ion information"); + return signalPassFailure(); + } + + // currently, we only support one ion information + assert(ionInfos.size() == 1 && "only one ion information is allowed"); + IonInfo &ionInfo = ionInfos.front(); + + // clone qnode function as new kernel function + OpBuilder builder(ctx); + func::FuncOp newQnodeFunc = createKernelFunction(qnodeFunc, kernelName, builder); + module.insert(qnodeFunc, newQnodeFunc); + + // Construct mapping from qreg alloc and qreg extract to memref + // In the later conversion, we use the mapping to construct the channel for rtio.pulse + DenseMap qregToMemrefMap; + DenseMap qextractToMemrefMap; + initializeMemrefMap(newQnodeFunc, module, qregToMemrefMap, qextractToMemrefMap, ctx); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [&](ion::PulseType type) -> Type { return rtio::EventType::get(ctx); }); + + ConversionTarget target(*ctx); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + + // prepare kernel function + if (failed(IonPulseConversion(newQnodeFunc, target, typeConverter, ionInfo, + qextractToMemrefMap, ctx)) || + failed(ParallelProtocolConversion(newQnodeFunc, target, typeConverter, ctx)) || + failed(SCFStructuralConversion(newQnodeFunc, target, typeConverter, ctx)) || + failed(PropagateEvents(newQnodeFunc, ctx)) || + failed(CleanQuantumOps(newQnodeFunc, ctx)) || + failed(CanonicalizeKernelFunction(newQnodeFunc, ctx))) { + newQnodeFunc->emitError("Failed to convert to rtio dialect"); + return signalPassFailure(); + } + + // TODO: Channel Mapping (qubit N, transition M) -> channel N * NUM_TRANSITIONS + M + // TODO: Naive scheduling to generate the simple Timeline RTIO IR + + // remove body of entry function and add call to kernel function + auto entryFunc = module.lookupSymbol("jit_circuit"); + if (!entryFunc) { + module.emitError("Cannot find entry function 'jit_circuit'"); + return signalPassFailure(); + } + + if (failed(updateEntryFunction(entryFunc, newQnodeFunc, ctx))) { + module.emitError("Failed to update entry function"); + return signalPassFailure(); + } + } +}; + +} // namespace ion +} // namespace catalyst From c0e7234fe9b7bc0e08c2a02f644049d9d0bf32ab Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 00:55:09 -0500 Subject: [PATCH 15/51] update --- mlir/include/RTIO/IR/RTIOOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index c00d0bf639..7fdca59ced 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -25,7 +25,7 @@ include "RTIO/IR/RTIODialect.td" //===----------------------------------------------------------------------===// def RTIOChannelOp : RTIO_Op<"channel"> { - let summary = "Define a static channel"; + let summary = "Define a channel"; let description = [{ The channel's identity (kind, qualifiers, and channel id) is encoded in its result type. @@ -365,7 +365,7 @@ def RTIOOnOp : RTIO_Op<"on"> { Example: ```mlir - rtio.at %t_start : !rtio.time + rtio.at %t_start : i64 rtio.on %ch rtio.delay 1000 : i64 rtio.off %ch From f1ad990259c254909df1cfb0298aabae6031e992 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 02:01:25 -0500 Subject: [PATCH 16/51] add missing include --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 8b93a2b5b1..ab5ae99124 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -764,14 +766,14 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { // Create global memref at module level builder.setInsertionPointToStart(module.getBody()); - auto globalOp = memref::GlobalOp::create( - builder, allocOp.getLoc(), - builder.getStringAttr(globalName), // sym_name - builder.getStringAttr("private"), // sym_visibility - TypeAttr::get(memrefType), // type - denseAttr, // initial_value - builder.getUnitAttr(), // constant - IntegerAttr()); // alignment + auto globalOp = + memref::GlobalOp::create(builder, allocOp.getLoc(), + builder.getStringAttr(globalName), // sym_name + builder.getStringAttr("private"), // sym_visibility + TypeAttr::get(memrefType), // type + denseAttr, // initial_value + builder.getUnitAttr(), // constant + IntegerAttr()); // alignment // Get the global memref in the function builder.setInsertionPointAfter(allocOp); From 4cf83fec25261de44932a16a34ff00403af64db7 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 09:58:06 -0500 Subject: [PATCH 17/51] update --- mlir/lib/Driver/CompilerDriver.cpp | 2 ++ mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp | 2 ++ 2 files changed, 4 insertions(+) diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index 28d34d0e66..f3ac3569e4 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -73,6 +73,7 @@ #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" +#include "RTIO/IR/RTIODialect.h" #include "RegisterAllPasses.h" #include "Enzyme.h" @@ -302,6 +303,7 @@ void registerAllCatalystDialects(DialectRegistry ®istry) registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); } diff --git a/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp b/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp index a98de69942..1578f350cf 100644 --- a/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp +++ b/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp @@ -23,6 +23,7 @@ #include "Mitigation/IR/MitigationDialect.h" #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" +#include "RTIO/IR/RTIODialect.h" #include "stablehlo/dialect/Register.h" @@ -37,6 +38,7 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); registry.insert(); + registry.insert(); mlir::stablehlo::registerAllDialects(registry); From 27a4b981e011d17d71c04ec81a5d9495ef03a1ca Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 11:44:19 -0500 Subject: [PATCH 18/51] add missing lib --- mlir/lib/Driver/CMakeLists.txt | 1 + mlir/tools/catalyst-cli/CMakeLists.txt | 1 + mlir/tools/quantum-lsp-server/CMakeLists.txt | 1 + 3 files changed, 3 insertions(+) diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index 5ec6857426..3ffd467987 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -44,6 +44,7 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms + MLIRRTIO MLIRCatalystTest ${ENZYME_LIB} ) diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index 1dd3d9693d..46fd64cc05 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -40,6 +40,7 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms + MLIRRTIO MLIRCatalystTest ${ENZYME_LIB} CatalystCompilerDriver diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index 507480ef00..7fbbcf8c3d 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -13,6 +13,7 @@ set(LIBS MLIRMBQC MLIRMitigation MLIRIon + MLIRRTIO ) add_llvm_executable(quantum-lsp-server quantum-lsp-server.cpp) From 29ffc7002749c53018c83625bdac81bad9f4d22c Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 18 Nov 2025 12:12:22 -0500 Subject: [PATCH 19/51] remove redundant include --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index ab5ae99124..85f4e972e7 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -31,7 +31,6 @@ #include "Quantum/IR/QuantumOps.h" #include "RTIO/IR/RTIODialect.h" #include "RTIO/IR/RTIOOps.h" -#include using namespace mlir; using namespace catalyst; From fcd540d87570ea3cfc818d6191a257cf7ed67028 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 19 Nov 2025 10:00:41 -0500 Subject: [PATCH 20/51] update --- mlir/include/RTIO/IR/RTIOOps.h | 1 + mlir/include/RTIO/IR/RTIOOps.td | 8 +- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 141 +++++++++++++++++++++++- mlir/lib/RTIO/IR/RTIOOps.cpp | 42 ++++++- 4 files changed, 188 insertions(+), 4 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIOOps.h b/mlir/include/RTIO/IR/RTIOOps.h index 0ce7643b83..04cd66d41e 100644 --- a/mlir/include/RTIO/IR/RTIOOps.h +++ b/mlir/include/RTIO/IR/RTIOOps.h @@ -18,6 +18,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #define GET_ATTRDEF_CLASSES diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index 7fdca59ced..896a13259b 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -24,7 +24,7 @@ include "RTIO/IR/RTIODialect.td" // Event-Based API //===----------------------------------------------------------------------===// -def RTIOChannelOp : RTIO_Op<"channel"> { +def RTIOChannelOp : RTIO_Op<"channel", [Pure]> { let summary = "Define a channel"; let description = [{ The channel's identity (kind, qualifiers, and channel id) is @@ -50,12 +50,14 @@ def RTIOChannelOp : RTIO_Op<"channel"> { let arguments = (ins); let results = (outs RTIOChannelType:$channel); + let hasCanonicalizeMethod = 1; + let assemblyFormat = [{ attr-dict `:` type(results) }]; } -def RTIOQubitToChannelOp : RTIO_Op<"qubit_to_channel"> { +def RTIOQubitToChannelOp : RTIO_Op<"qubit_to_channel", [Pure]> { let summary = "Map a qubit to an RTIO channel"; let description = [{ It's a temporary operation that will be lowered to a static `rtio.channel` operation during @@ -76,6 +78,8 @@ def RTIOQubitToChannelOp : RTIO_Op<"qubit_to_channel"> { let arguments = (ins AnyType:$qubit); let results = (outs RTIOChannelType:$channel); + let hasCanonicalizeMethod = 1; + let assemblyFormat = [{ $qubit attr-dict `:` type(operands) `->` type(results) }]; diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 85f4e972e7..7a26c07891 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "mlir/Dialect/Arith/IR/Arith.h" @@ -21,6 +22,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -447,6 +449,109 @@ struct PulseToRTIOPattern : public OpConversionPattern { } }; +/// Resolve the static channel mapping for the rtio.qubit_to_channel operation +/// +/// It's expecting `qubit_to_channel` has the following def-use chain: +/// memref.global w/ constants -> memref.get_global -> memref.load -> qubit_to_channel +/// +/// Example: +/// ``` +/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?> +/// ``` +/// will be converted to: +/// ``` +/// %ch = rtio.channel "dds" { channel_id = 0 } : !rtio.channel<"dds"> +/// ``` +struct ResolveChannelMappingPattern : public OpRewritePattern { + ResolveChannelMappingPattern(MLIRContext *ctx) + : OpRewritePattern(ctx) + { + } + + LogicalResult matchAndRewrite(rtio::RTIOQubitToChannelOp op, + PatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + Value qubit = op.getQubit(); + + auto loadOp = qubit.getDefiningOp(); + if (!loadOp) { + return failure(); + } + + Value memref = loadOp.getMemRef(); + auto getGlobalOp = memref.getDefiningOp(); + if (!getGlobalOp) { + return failure(); + } + + StringRef globalName = getGlobalOp.getName(); + ModuleOp module = op->getParentOfType(); + if (!module) { + return failure(); + } + auto globalOp = module.lookupSymbol(globalName); + if (!globalOp) { + return failure(); + } + + auto initialValue = globalOp.getInitialValue(); + if (!initialValue) { + return failure(); + } + + auto denseAttr = llvm::dyn_cast(*initialValue); + if (!denseAttr) { + return failure(); + } + + ValueRange indices = loadOp.getIndices(); + if (indices.size() != 1) { + return failure(); + } + + IntegerAttr indexAttr; + if (!matchPattern(indices[0], m_Constant(&indexAttr))) { + return failure(); + } + + int64_t index = indexAttr.getInt(); + + size_t denseSize = denseAttr.size(); + if (index < 0 || static_cast(index) >= denseSize) { + return failure(); + } + + APInt channelIdValue = denseAttr.getValues()[index]; + + auto originalChannelType = llvm::dyn_cast(op.getChannel().getType()); + if (!originalChannelType) { + return failure(); + } + StringRef kind = originalChannelType.getKind(); + ArrayAttr qualifiers = originalChannelType.getQualifiers(); + + int offset = 0; + // If the qualifiers is not empty, get the first qualifier and check if it is 0 or 1 + if (qualifiers.size() >= 1) { + IntegerAttr qualifier0 = llvm::dyn_cast(qualifiers[0]); + offset = qualifier0.getInt() == 0 ? 0 : 1; + } + + IntegerAttr channelIdAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), + channelIdValue.getSExtValue() + offset); + + auto resolvedChannelType = + rtio::ChannelType::get(rewriter.getContext(), kind, qualifiers, channelIdAttr); + + Value channel = rewriter.create(loc, resolvedChannelType); + + rewriter.replaceOp(op, channel); + + return success(); + } +}; + /// Propagates RTIO events from chain of operations to event types. /// /// Steps: @@ -574,6 +679,16 @@ LogicalResult CanonicalizeKernelFunction(func::FuncOp funcOp, MLIRContext *ctx) return success(); } +LogicalResult ResolveChannelMapping(func::FuncOp funcOp, MLIRContext *ctx) +{ + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // Pass Implementation //===----------------------------------------------------------------------===// @@ -891,6 +1006,16 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { func::FuncOp newQnodeFunc = createKernelFunction(qnodeFunc, kernelName, builder); module.insert(qnodeFunc, newQnodeFunc); + // drop one of the pulse from the certain protocol + // the way we handle the dropped pulse will be updated in the future + newQnodeFunc.walk([&](ion::ParallelProtocolOp parallelProtocolOp) { + parallelProtocolOp.walk([&](ion::PulseOp pulseOp) { + if (pulseOp.getBeamAttr().getTransitionIndex().getInt() == 0) { + pulseOp.erase(); + } + }); + }); + // Construct mapping from qreg alloc and qreg extract to memref // In the later conversion, we use the mapping to construct the channel for rtio.pulse DenseMap qregToMemrefMap; @@ -917,8 +1042,20 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { return signalPassFailure(); } - // TODO: Channel Mapping (qubit N, transition M) -> channel N * NUM_TRANSITIONS + M + // Resolve the static channel, the dynamic channel will be remained as `?` + if (failed(ResolveChannelMapping(newQnodeFunc, ctx))) { + newQnodeFunc->emitError("Failed to resolve channel mapping"); + return signalPassFailure(); + } + // TODO: Naive scheduling to generate the simple Timeline RTIO IR + // To shorten the `timeline`: `list scheduling`, `graph scheduling`, ... etc. can also be + // used to schedule the operations. Here we just linearly mapping the operation based on the + // order in the function without caring any latency issues. + // if (failed(NaiveScheduling(newQnodeFunc, ctx))) { + // newQnodeFunc->emitError("Failed to schedule"); + // return signalPassFailure(); + // } // remove body of entry function and add call to kernel function auto entryFunc = module.lookupSymbol("jit_circuit"); @@ -931,6 +1068,8 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { module.emitError("Failed to update entry function"); return signalPassFailure(); } + + qnodeFunc->erase(); } }; diff --git a/mlir/lib/RTIO/IR/RTIOOps.cpp b/mlir/lib/RTIO/IR/RTIOOps.cpp index 19d5a147fc..4bf12fb4cc 100644 --- a/mlir/lib/RTIO/IR/RTIOOps.cpp +++ b/mlir/lib/RTIO/IR/RTIOOps.cpp @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "RTIO/IR/RTIOOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" + +#include "RTIO/IR/RTIOOps.h" using namespace mlir; using namespace catalyst::rtio; @@ -32,5 +34,43 @@ LogicalResult RTIOSyncOp::verify() return success(); } +LogicalResult RTIOQubitToChannelOp::canonicalize(RTIOQubitToChannelOp op, + mlir::PatternRewriter &rewriter) +{ + Block *currentBlock = op->getBlock(); + Value qubit = op.getQubit(); + Type channelType = op.getChannel().getType(); + + // Try to find the same qubit_to_channel operation between [block->begin, op) + for (Operation &prevOp : llvm::make_range(currentBlock->begin(), op->getIterator())) { + if (auto prevQubitToChannel = dyn_cast(&prevOp)) { + if (prevQubitToChannel.getQubit() == qubit && + prevQubitToChannel.getChannel().getType() == channelType) { + rewriter.replaceOp(op, prevQubitToChannel.getChannel()); + return success(); + } + } + } + + return failure(); +} + +LogicalResult RTIOChannelOp::canonicalize(RTIOChannelOp op, mlir::PatternRewriter &rewriter) +{ + Block *currentBlock = op->getBlock(); + Value channel = op.getChannel(); + Type channelType = channel.getType(); + + for (Operation &prevOp : llvm::make_range(currentBlock->begin(), op->getIterator())) { + if (auto prevChannelOp = dyn_cast(&prevOp)) { + if (prevChannelOp.getChannel().getType() == channelType) { + rewriter.replaceOp(op, prevChannelOp.getChannel()); + return success(); + } + } + } + return failure(); +} + #define GET_OP_CLASSES #include "RTIO/IR/RTIOOps.cpp.inc" From 82bb3aa937aff2f3ebc4049a15e2b3708d5182da Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 20 Nov 2025 13:51:54 -0500 Subject: [PATCH 21/51] remove other redundant functions --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 7a26c07891..ba7dbb75ed 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -12,9 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -971,6 +968,16 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { entryBuilder.setInsertionPointToEnd(entryBlock); entryBuilder.create(entryFunc.getLoc()); + + // Remove other functions + // We currently just lower to the kernel function + auto module = entryFunc->getParentOfType(); + module.walk([&](func::FuncOp funcOp) { + if (funcOp.getName().str() == "teardown" || funcOp.getName().str() == "setup") { + funcOp.erase(); + } + }); + return success(); } From 4a7430a68bc0748369534010a06ad6d7218c2172 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 27 Nov 2025 15:13:28 -0500 Subject: [PATCH 22/51] update pass --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 116 ++++++++++++++---------- 1 file changed, 67 insertions(+), 49 deletions(-) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index ba7dbb75ed..fd50470bf1 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -19,7 +19,9 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" +#include "mlir/Transforms/CSE.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -135,6 +137,9 @@ auto traceValueWithCallback(Value value, CallbackT &&callback) else if (auto op = dyn_cast(defOp)) { visited.push(op.getQreg()); } + else if (auto op = dyn_cast(defOp)) { + visited.push(op.getQubit()); + } else if (auto op = dyn_cast(defOp)) { Value inQreg = op.getInQreg(); Value qubit = op.getQubit(); @@ -156,7 +161,7 @@ auto traceValueWithCallback(Value value, CallbackT &&callback) } } -Value createSyncEvent(ArrayRef events, PatternRewriter &rewriter) +Value awaitEvents(ArrayRef events, PatternRewriter &rewriter) { if (events.size() == 1) { return events.front(); @@ -317,23 +322,59 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern(loc, eventType, qubit).getResult(0); }); - Value inputSyncEvent = createSyncEvent(llvm::to_vector(events), rewriter); + Value inputSyncEvent = awaitEvents(llvm::to_vector(events), rewriter); // Clone operations from the region to outside SmallVector pulseEvents; + DenseMap qubitToOffset; + + // we cache the channel to index mapping to avoid multiple lookups + DenseMap cache; for (auto ®ionOp : regionBlock->without_terminator()) { auto *clonedOp = rewriter.clone(regionOp, irMapping); if (auto pulseOp = dyn_cast(clonedOp)) { // set wait event for the pulse operation pulseOp.setWait(inputSyncEvent); + + Value index = nullptr; + + SmallVector chain; + traceValueWithCallback( + pulseOp.getChannel(), [&](Value value) -> WalkResult { + if (cache.count(value)) { + index = cache[value]; + return WalkResult::interrupt(); + } + chain.push_back(value); + if (auto loadOp = + llvm::dyn_cast_if_present(value.getDefiningOp())) { + index = loadOp.getIndices()[0]; + + // cache the channel to index mapping + cache[pulseOp.getChannel()] = index; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + assert(index != nullptr && "index must not be null"); + + // update cache + for (Value value : chain) { + cache[value] = index; + } + pulseOp->setAttr("offset", rewriter.getI64IntegerAttr(qubitToOffset[index])); + + // the same qubit may appear multiple times in the parallel protocol + // so we need to increment the offset for each appearance + qubitToOffset[index]++; + pulseEvents.push_back(pulseOp.getEvent()); } irMapping.map(regionOp.getResults(), clonedOp->getResults()); @@ -343,7 +384,7 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern 0 && "must have at least one pulse operation after parallel protocol conversion"); - Value outputSyncEvent = createSyncEvent(llvm::to_vector(pulseEvents), rewriter); + Value outputSyncEvent = awaitEvents(llvm::to_vector(pulseEvents), rewriter); SmallVector results; for (Value result : op.getResults()) { @@ -528,15 +569,14 @@ struct ResolveChannelMappingPattern : public OpRewritePattern= 1) { - IntegerAttr qualifier0 = llvm::dyn_cast(qualifiers[0]); - offset = qualifier0.getInt() == 0 ? 0 : 1; - } + // channel should have exactly one use before lowering to channel op + assert(op.getChannel().hasOneUse() && "channel should have exactly one use"); + + auto pulseOp = cast(*op.getChannel().getUsers().begin()); + int64_t offset = cast(pulseOp->getAttr("offset")).getInt(); - IntegerAttr channelIdAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), - channelIdValue.getSExtValue() + offset); + IntegerAttr channelIdAttr = rewriter.getIntegerAttr( + rewriter.getIndexType(), (channelIdValue.getSExtValue() + offset)); auto resolvedChannelType = rtio::ChannelType::get(rewriter.getContext(), kind, qualifiers, channelIdAttr); @@ -622,7 +662,7 @@ struct PropagateEventsPattern : public OpRewritePattern { auto newFuncType = FunctionType::get(ctx, oldFuncType.getInputs(), {}); newQnodeFunc.setFunctionType(newFuncType); - // Replace all return ops with empty returns - SmallVector returnsToReplace; - newQnodeFunc.walk([&](func::ReturnOp returnOp) { returnsToReplace.push_back(returnOp); }); - - for (auto returnOp : returnsToReplace) { - builder.setInsertionPoint(returnOp); - builder.create(returnOp.getLoc()); - returnOp.erase(); - } + // Clear operands from all return ops (make them return nothing) + newQnodeFunc.walk([](func::ReturnOp returnOp) { returnOp.getOperandsMutable().clear(); }); return newQnodeFunc; } @@ -864,14 +902,14 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { std::string globalNameStr = "__qubit_map_" + std::to_string(globalCounter++); StringRef globalName = globalNameStr; - // Create dense attribute with values [0, 1, 2, ..., numQubits-1] * 2 + // Create dense attribute with values [0, 1, 2, ..., numQubits-1] auto tensorType = RankedTensorType::get({static_cast(numQubits)}, builder.getIndexType()); SmallVector values; // Use IndexType::kInternalStorageBitWidth for index type unsigned indexWidth = IndexType::kInternalStorageBitWidth; for (size_t i = 0; i < numQubits; i++) { - values.push_back(APInt(indexWidth, i * 2)); + values.push_back(APInt(indexWidth, i)); } auto denseAttr = DenseIntElementsAttr::get(tensorType, values); @@ -933,16 +971,10 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { auto newEntryFuncType = FunctionType::get(ctx, oldEntryFuncType.getInputs(), {}); entryFunc.setFunctionType(newEntryFuncType); - // Clear the function body + // Clear the function body (reverse order so uses are erased before defs) Block *entryBlock = &entryFunc.getBody().front(); - SmallVector opsToErase; - for (Operation &op : entryBlock->getOperations()) { - opsToErase.push_back(&op); - } - for (auto op : opsToErase) { - op->dropAllUses(); - op->erase(); - } + for (auto &op : llvm::make_early_inc_range(llvm::reverse(*entryBlock))) + op.erase(); // Create call to kernel function OpBuilder entryBuilder(ctx); @@ -1044,26 +1076,12 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { failed(SCFStructuralConversion(newQnodeFunc, target, typeConverter, ctx)) || failed(PropagateEvents(newQnodeFunc, ctx)) || failed(CleanQuantumOps(newQnodeFunc, ctx)) || + failed(ResolveChannelMapping(newQnodeFunc, ctx)) || failed(CanonicalizeKernelFunction(newQnodeFunc, ctx))) { newQnodeFunc->emitError("Failed to convert to rtio dialect"); return signalPassFailure(); } - // Resolve the static channel, the dynamic channel will be remained as `?` - if (failed(ResolveChannelMapping(newQnodeFunc, ctx))) { - newQnodeFunc->emitError("Failed to resolve channel mapping"); - return signalPassFailure(); - } - - // TODO: Naive scheduling to generate the simple Timeline RTIO IR - // To shorten the `timeline`: `list scheduling`, `graph scheduling`, ... etc. can also be - // used to schedule the operations. Here we just linearly mapping the operation based on the - // order in the function without caring any latency issues. - // if (failed(NaiveScheduling(newQnodeFunc, ctx))) { - // newQnodeFunc->emitError("Failed to schedule"); - // return signalPassFailure(); - // } - // remove body of entry function and add call to kernel function auto entryFunc = module.lookupSymbol("jit_circuit"); if (!entryFunc) { From a349b1ce12f2f813d6dc8cd3960c991bd7f85224 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 27 Nov 2025 15:39:31 -0500 Subject: [PATCH 23/51] remove redundant code --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 81 ++++--------------------- 1 file changed, 13 insertions(+), 68 deletions(-) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index fd50470bf1..1c68448886 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -880,6 +880,9 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { auto newFuncType = FunctionType::get(ctx, oldFuncType.getInputs(), {}); newQnodeFunc.setFunctionType(newFuncType); + // set public visibility for kernel function + newQnodeFunc.setPublic(); + // Clear operands from all return ops (make them return nothing) newQnodeFunc.walk([](func::ReturnOp returnOp) { returnOp.getOperandsMutable().clear(); }); @@ -963,56 +966,6 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { }); } - LogicalResult updateEntryFunction(func::FuncOp entryFunc, func::FuncOp newQnodeFunc, - MLIRContext *ctx) - { - // Update entry function return type to empty - auto oldEntryFuncType = entryFunc.getFunctionType(); - auto newEntryFuncType = FunctionType::get(ctx, oldEntryFuncType.getInputs(), {}); - entryFunc.setFunctionType(newEntryFuncType); - - // Clear the function body (reverse order so uses are erased before defs) - Block *entryBlock = &entryFunc.getBody().front(); - for (auto &op : llvm::make_early_inc_range(llvm::reverse(*entryBlock))) - op.erase(); - - // Create call to kernel function - OpBuilder entryBuilder(ctx); - entryBuilder.setInsertionPointToStart(entryBlock); - - SmallVector kernelArgs(entryFunc.getArguments().begin(), - entryFunc.getArguments().end()); - - // compare args type with kernel function arguments - if (kernelArgs.size() != newQnodeFunc.getArguments().size()) { - entryFunc->emitError("Failed to update entry function: number of arguments mismatch"); - return failure(); - } - for (size_t i = 0; i < kernelArgs.size(); i++) { - if (kernelArgs[i].getType() != newQnodeFunc.getArguments()[i].getType()) { - entryFunc->emitError("Failed to update entry function: argument type mismatch"); - return failure(); - } - } - - entryBuilder.create(entryFunc.getLoc(), newQnodeFunc.getName(), TypeRange{}, - kernelArgs); - - entryBuilder.setInsertionPointToEnd(entryBlock); - entryBuilder.create(entryFunc.getLoc()); - - // Remove other functions - // We currently just lower to the kernel function - auto module = entryFunc->getParentOfType(); - module.walk([&](func::FuncOp funcOp) { - if (funcOp.getName().str() == "teardown" || funcOp.getName().str() == "setup") { - funcOp.erase(); - } - }); - - return success(); - } - void runOnOperation() override { MLIRContext *ctx = &getContext(); @@ -1047,13 +1000,13 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { // drop one of the pulse from the certain protocol // the way we handle the dropped pulse will be updated in the future - newQnodeFunc.walk([&](ion::ParallelProtocolOp parallelProtocolOp) { - parallelProtocolOp.walk([&](ion::PulseOp pulseOp) { - if (pulseOp.getBeamAttr().getTransitionIndex().getInt() == 0) { - pulseOp.erase(); - } - }); + SmallVector pulsesToErase; + newQnodeFunc.walk([&](ion::PulseOp pulseOp) { + if (pulseOp.getBeamAttr().getTransitionIndex().getInt() == 0) + pulsesToErase.push_back(pulseOp); }); + for (auto pulseOp : pulsesToErase) + pulseOp.erase(); // Construct mapping from qreg alloc and qreg extract to memref // In the later conversion, we use the mapping to construct the channel for rtio.pulse @@ -1082,19 +1035,11 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { return signalPassFailure(); } - // remove body of entry function and add call to kernel function - auto entryFunc = module.lookupSymbol("jit_circuit"); - if (!entryFunc) { - module.emitError("Cannot find entry function 'jit_circuit'"); - return signalPassFailure(); - } - - if (failed(updateEntryFunction(entryFunc, newQnodeFunc, ctx))) { - module.emitError("Failed to update entry function"); - return signalPassFailure(); + for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { + if (funcOp.getName().str() != newQnodeFunc.getName().str()) { + funcOp.erase(); + } } - - qnodeFunc->erase(); } }; From cc814da35de5fef24adf6d2fbe6b55216b4d7026 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Fri, 28 Nov 2025 00:07:46 -0500 Subject: [PATCH 24/51] update linkage --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 1c68448886..b70d2ad598 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -576,7 +576,7 @@ struct ResolveChannelMappingPattern : public OpRewritePattern(pulseOp->getAttr("offset")).getInt(); IntegerAttr channelIdAttr = rewriter.getIntegerAttr( - rewriter.getIndexType(), (channelIdValue.getSExtValue() + offset)); + rewriter.getIndexType(), (channelIdValue.getSExtValue() * 2 + offset)); auto resolvedChannelType = rtio::ChannelType::get(rewriter.getContext(), kind, qualifiers, channelIdAttr); @@ -880,8 +880,9 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { auto newFuncType = FunctionType::get(ctx, oldFuncType.getInputs(), {}); newQnodeFunc.setFunctionType(newFuncType); - // set public visibility for kernel function + // set public visibility and remove internal linkage for kernel function newQnodeFunc.setPublic(); + newQnodeFunc->removeAttr("llvm.linkage"); // Clear operands from all return ops (make them return nothing) newQnodeFunc.walk([](func::ReturnOp returnOp) { returnOp.getOperandsMutable().clear(); }); From 5681bb30d30785d316510313462002145a7250cc Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 14:38:07 -0500 Subject: [PATCH 25/51] add rtio.config --- mlir/include/RTIO/IR/RTIODialect.td | 83 +- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 1137 +++++++++++++++++++++++ mlir/lib/RTIO/IR/RTIODialect.cpp | 17 +- mlir/test/RTIO/VerifierTest.mlir | 24 + 4 files changed, 1240 insertions(+), 21 deletions(-) create mode 100644 mlir/lib/Ion/Transforms/ion-to-rtio.cpp diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index 8b3609c664..e211140a52 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -39,15 +39,7 @@ def RTIO_Dialect : Dialect { let name = "rtio"; let cppNamespace = "::catalyst::rtio"; let useDefaultTypePrinterParser = 1; - let useDefaultAttributePrinterParser = 0; - - let extraClassDeclaration = [{ - mlir::Attribute parseAttribute(mlir::DialectAsmParser &parser, - mlir::Type type) const override; - - void printAttribute(mlir::Attribute attr, - mlir::DialectAsmPrinter &printer) const override; - }]; + let useDefaultAttributePrinterParser = 1; } //===----------------------------------------------------------------------===// @@ -164,5 +156,78 @@ def RTIOEventType : RTIO_Type<"Event", "event"> { class RTIO_Op traits = []> : Op; +//===----------------------------------------------------------------------===// +// RTIO Attributes +//===----------------------------------------------------------------------===// + +class RTIO_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +def RTIOConfigAttr : RTIO_Attr<"Config", "config"> { + let summary = "A dictionary attribute for RTIO configuration"; + let description = [{ + A configuration attribute that wraps a DictionaryAttr for RTIO-specific metadata. + Can be used as a module-level attribute. + + Example: + ```mlir + module @my_module attributes { + rtio.config = #rtio.config<{ + config1 = 1 : i32, + config2 = "test", + }> + } { + // ... + } + ``` + }]; + + let parameters = (ins + "mlir::DictionaryAttr":$dict + ); + + let assemblyFormat = "`<` $dict `>`"; + + let extraClassDeclaration = [{ + /// The canonical attribute name for module-level config. + static llvm::StringRef getModuleAttrName() { + return "rtio.config"; + } + + /// Return the value for the given key, or null if not found. + mlir::Attribute get(llvm::StringRef key) const { + return getDict().get(key); + } + mlir::Attribute get(mlir::StringAttr key) const { + return getDict().get(key); + } + + /// Return whether the config contains the given key. + bool contains(llvm::StringRef key) const { + return getDict().contains(key); + } + bool contains(mlir::StringAttr key) const { + return getDict().contains(key); + } + + /// Return the specified named attribute if present. + std::optional getNamed(llvm::StringRef key) const { + return getDict().getNamed(key); + } + std::optional getNamed(mlir::StringAttr key) const { + return getDict().getNamed(key); + } + + /// Support range iteration (delegates to DictionaryAttr). + using iterator = mlir::DictionaryAttr::iterator; + iterator begin() const { return getDict().begin(); } + iterator end() const { return getDict().end(); } + bool empty() const { return getDict().empty(); } + size_t size() const { return getDict().size(); } + }]; +} + #endif // RTIO_DIALECT diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp new file mode 100644 index 0000000000..77765b424e --- /dev/null +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -0,0 +1,1137 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBuffer.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "Ion/IR/IonDialect.h" +#include "Ion/IR/IonOps.h" +#include "Ion/Transforms/Passes.h" +#include "Quantum/IR/QuantumDialect.h" +#include "Quantum/IR/QuantumOps.h" +#include "RTIO/IR/RTIODialect.h" +#include "RTIO/IR/RTIOOps.h" + +using namespace mlir; +using namespace catalyst; + +namespace catalyst { +namespace ion { + +namespace { + +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +enum class TraceMode { + Qreg = 0, + Event = 1, +}; + +/// Traces a Value backward through the IR by tracing its dataflow dependencies +/// across control flow and specific quantum operations. +/// +/// Template Parameters: +/// - ModeT: TraceMode enum (Qreg or Event) that controls how quantum.insert +/// operations are handled +/// Qreg mode: Trace to find the source qreg of the given value +/// Event mode: Trace to find all events that contribute to the given value +/// - CallbackT: Callable type that will be invoked for each visited value. +/// May optionally return WalkResult for early termination. +/// +/// Supported Operations: +/// - scf.for +/// - scf.if +/// - ion.parallelprotocol +/// - unrealized_conversion_cast +/// - quantum.extract +/// - quantum.insert +template +auto traceValueWithCallback(Value value, CallbackT &&callback) +{ + WalkResult walkResult = WalkResult::advance(); + std::queue visited; + visited.push(value); + + while (!visited.empty()) { + Value value = visited.front(); + visited.pop(); + + if constexpr (std::is_same_v, WalkResult>) { + if (callback(value).wasInterrupted()) { + walkResult = WalkResult::interrupt(); + continue; + } + } + else { + callback(value); + } + + if (auto arg = mlir::dyn_cast(value)) { + Block *block = arg.getOwner(); + Operation *parentOp = block->getParentOp(); + + if (auto forOp = dyn_cast(parentOp)) { + unsigned argIndex = arg.getArgNumber(); + Value iterArg = forOp.getInitArgs()[argIndex - 1]; + visited.push(iterArg); + continue; + } + else if (auto parallelProtocolOp = dyn_cast(parentOp)) { + unsigned argIndex = arg.getArgNumber(); + Value inQubit = parallelProtocolOp.getInQubits()[argIndex]; + visited.push(inQubit); + continue; + } + parentOp->emitError("Unsupported parent operation for block argument: ") << value; + llvm::reportFatalInternalError("Unsupported block argument"); + } + + Operation *defOp = value.getDefiningOp(); + if (defOp == nullptr) { + continue; + } + + if (auto forOp = dyn_cast(defOp)) { + unsigned resultIdx = llvm::cast(value).getResultNumber(); + BlockArgument iterArg = forOp.getRegionIterArg(resultIdx); + visited.push(iterArg); + } + else if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIdx = llvm::cast(value).getResultNumber(); + Value thenValue = ifOp.thenYield().getOperand(resultIdx); + Value elseValue = ifOp.elseYield().getOperand(resultIdx); + visited.push(thenValue); + visited.push(elseValue); + } + else if (auto parallelProtocolOp = dyn_cast(defOp)) { + unsigned resultIdx = llvm::cast(value).getResultNumber(); + Value inQubit = parallelProtocolOp.getInQubits()[resultIdx]; + visited.push(inQubit); + } + else if (auto op = dyn_cast(defOp)) { + visited.push(op.getInputs().front()); + } + else if (auto op = dyn_cast(defOp)) { + visited.push(op.getQreg()); + } + else if (auto op = dyn_cast(defOp)) { + visited.push(op.getQubit()); + } + else if (auto op = dyn_cast(defOp)) { + Value inQreg = op.getInQreg(); + Value qubit = op.getQubit(); + if constexpr (ModeT == TraceMode::Qreg) { + visited.push(inQreg); + } + else if constexpr (ModeT == TraceMode::Event) { + visited.push(qubit); + // only trace qreg if it defined op is also come from insert op + if (llvm::isa_and_present(inQreg.getDefiningOp())) { + visited.push(inQreg); + } + } + } + } + + if constexpr (std::is_same_v, WalkResult>) { + return walkResult; + } +} + +Value awaitEvents(ArrayRef events, PatternRewriter &rewriter) +{ + if (events.size() == 1) { + return events.front(); + } + auto eventType = rtio::EventType::get(rewriter.getContext()); + return rewriter.create(rewriter.getUnknownLoc(), eventType, events); +} + +// Helper class to store ion information +class IonInfo { + private: + llvm::StringMap levelEnergyMap; + + struct TransitionInfo { + std::string level0; + std::string level1; + double einstein_a; + std::string multipole; + }; + SmallVector transitions; + + public: + IonInfo(ion::IonOp op) + { + auto levelAttrs = op.getLevels(); + auto transitionsAttr = op.getTransitions(); + + // Map from Level label to Energy value + for (auto levelAttr : levelAttrs) { + auto level = cast(levelAttr); + std::string label = level.getLabel().getValue().str(); + double energy = level.getEnergy().getValueAsDouble(); + levelEnergyMap[label] = energy; + } + + // Store transition information + for (auto transitionAttr : transitionsAttr) { + auto transition = cast(transitionAttr); + TransitionInfo info; + info.level0 = transition.getLevel_0().getValue().str(); + info.level1 = transition.getLevel_1().getValue().str(); + info.einstein_a = transition.getEinsteinA().getValueAsDouble(); + info.multipole = transition.getMultipole().getValue().str(); + transitions.push_back(info); + } + } + + // Get energy of a level by label + std::optional getLevelEnergy(StringRef label) const + { + auto it = levelEnergyMap.find(label.str()); + if (it != levelEnergyMap.end()) { + return it->second; + } + return std::nullopt; + } + + // Get level label of a transition by index + template + std::optional getTransitionLevelEnergy(size_t transitionIndex) const + { + static_assert(IndexT == 0 || IndexT == 1, "IndexT must be 0 or 1"); + + if (transitionIndex >= transitions.size()) { + return std::nullopt; + } + + const auto &transition = transitions[transitionIndex]; + if constexpr (IndexT == 0) { + return getLevelEnergy(transition.level0); + } + else { + return getLevelEnergy(transition.level1); + } + } + + // Get energy difference of a transition (level1 energy - level0 energy) + std::optional getTransitionEnergyDiff(size_t index) const + { + if (index >= transitions.size()) { + return std::nullopt; + } + + auto energy0 = getTransitionLevelEnergy<0>(index); + auto energy1 = getTransitionLevelEnergy<1>(index); + + if (energy0.has_value() && energy1.has_value()) { + return energy1.value() - energy0.value(); + } + + return std::nullopt; + } + + // Get number of transitions + size_t getNumTransitions() const { return transitions.size(); } + + // Get transition info by index + std::optional getTransition(size_t index) const + { + if (index < transitions.size()) { + return transitions[index]; + } + return std::nullopt; + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +/// Convert ion.parallelprotocol and introduce rtio.sync to ensure the order +/// +/// Example: +/// ``` +/// %0, %1 = ion.parallelprotocol(%q0, %q1) { +/// ^bb0(%arg0, %arg1): +/// %p0 = rtio.pulse(...) : !rtio.event +/// %p1 = rtio.pulse(...) : !rtio.event +/// ion.yield %arg0, %arg1 +/// } +/// ``` +/// will be converted to: +/// ``` +/// %event0 = unrealized_conversion_cast %q0 : !ion.qubit -> !rtio.event +/// %event1 = unrealized_conversion_cast %q1 : !ion.qubit -> !rtio.event +/// %p0 = rtio.pulse(..., wait = %event0) : !rtio.event +/// %p1 = rtio.pulse(..., wait = %event1) : !rtio.event +/// %sync = rtio.sync %p0, %p1 : !rtio.event +/// %0 = unrealized_conversion_cast %sync : !rtio.event -> !ion.qubit +/// %1 = unrealized_conversion_cast %sync : !rtio.event -> !ion.qubit +/// ``` +/// Those unrealized conversion casts are used to establish the dependency but will be +/// resolved by the subsequent stages. +struct ParallelProtocolToRTIOPattern : public OpConversionPattern { + ParallelProtocolToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) + { + } + + LogicalResult matchAndRewrite(ion::ParallelProtocolOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + MLIRContext *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + Block *regionBlock = &op.getBodyRegion().front(); + IRMapping irMapping; + SmallVector inQubits; + for (auto [blockArg, operand] : + llvm::zip(regionBlock->getArguments(), adaptor.getOperands())) { + irMapping.map(blockArg, operand); + + // collect qubits to trace the events + if (isa(operand.getType())) { + inQubits.push_back(operand); + } + } + + // create events for each qubit + auto events = llvm::map_range(inQubits, [&](Value qubit) { + auto eventType = rtio::EventType::get(ctx); + return rewriter.create(loc, eventType, qubit).getResult(0); + }); + + Value inputSyncEvent = awaitEvents(llvm::to_vector(events), rewriter); + + // Clone operations from the region to outside + SmallVector pulseEvents; + DenseMap qubitToOffset; + + // we cache the channel to index mapping to avoid multiple lookups + DenseMap cache; + for (auto ®ionOp : regionBlock->without_terminator()) { + auto *clonedOp = rewriter.clone(regionOp, irMapping); + if (auto pulseOp = dyn_cast(clonedOp)) { + // set wait event for the pulse operation + pulseOp.setWait(inputSyncEvent); + + Value index = nullptr; + + SmallVector chain; + traceValueWithCallback( + pulseOp.getChannel(), [&](Value value) -> WalkResult { + if (cache.count(value)) { + index = cache[value]; + return WalkResult::interrupt(); + } + chain.push_back(value); + if (auto loadOp = + llvm::dyn_cast_if_present(value.getDefiningOp())) { + index = loadOp.getIndices()[0]; + + // cache the channel to index mapping + cache[pulseOp.getChannel()] = index; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + assert(index != nullptr && "index must not be null"); + + // update cache + for (Value value : chain) { + cache[value] = index; + } + pulseOp->setAttr("offset", rewriter.getI64IntegerAttr(qubitToOffset[index])); + + // the same qubit may appear multiple times in the parallel protocol + // so we need to increment the offset for each appearance + qubitToOffset[index]++; + + pulseEvents.push_back(pulseOp.getEvent()); + } + irMapping.map(regionOp.getResults(), clonedOp->getResults()); + } + + // Create sync operation from pulse events (must have at least one after Phase 1) + assert(pulseEvents.size() > 0 && + "must have at least one pulse operation after parallel protocol conversion"); + + Value outputSyncEvent = awaitEvents(llvm::to_vector(pulseEvents), rewriter); + + SmallVector results; + for (Value result : op.getResults()) { + // unrealized conversion cast sync event to result type + auto event = + rewriter.create(loc, result.getType(), outputSyncEvent); + results.push_back(event.getResult(0)); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +/// Convert ion.pulse to rtio.pulse +/// +/// Example: +/// ``` +/// %pulse = ion.pulse(%duration) %qubit { +/// beam = #ion.beam<...> +/// } : !ion.pulse +/// ``` +/// will be converted to: +/// ``` +/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?> +/// ... // other pulse parameters settings +/// %event = rtio.pulse %ch duration(%duration) frequency(%freq) phase(%phase) +/// : !rtio.channel<"dds", ?> -> !rtio.event +/// ``` +struct PulseToRTIOPattern : public OpConversionPattern { + IonInfo ionInfo; + DenseMap &qextractToMemrefMap; + PulseToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx, IonInfo ionInfo, + DenseMap &qextractToMemrefMap) + : OpConversionPattern(typeConverter, ctx), ionInfo(ionInfo), + qextractToMemrefMap(qextractToMemrefMap) + { + } + + double calculateFrequency(int64_t transitionIndex, double detuning, + const IonInfo &ionInfo) const + { + // TODO: raman1_frequency can be passed as a pass option for extensibility + double raman1_frequency = 2 * llvm::numbers::pi * 844.485e12 - + 2 * llvm::numbers::pi * 12.643e9 - 2 * llvm::numbers::pi * 20e6; + + auto energyDiff = ionInfo.getTransitionEnergyDiff(transitionIndex); + assert(energyDiff.has_value() && "energyDiff must have a value"); + + double reference_energy = energyDiff.value(); + double frequency = + (reference_energy + detuning - raman1_frequency) / (2.0 * llvm::numbers::pi); + return frequency; + } + + LogicalResult matchAndRewrite(ion::PulseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + // Get pulse parameters + Value duration = op.getTime(); + auto beamAttr = op.getBeam(); + auto phaseAttr = op.getPhase(); + + // Extract beam parameters + double detuning = beamAttr.getDetuning().getValueAsDouble(); + double phase = phaseAttr.getValueAsDouble(); + int64_t transitionIndex = beamAttr.getTransitionIndex().getInt(); + double frequency = calculateFrequency(transitionIndex, detuning, ionInfo); + Value freqValue = + rewriter.create(loc, rewriter.getF64FloatAttr(frequency)); + Value phaseValue = rewriter.create(loc, rewriter.getF64FloatAttr(phase)); + + // Convert the qubit to a channel + ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(transitionIndex)}); + auto channelType = rtio::ChannelType::get(ctx, "dds", qualifiers, nullptr); + + Value memrefLoadValue = nullptr; + traceValueWithCallback(op.getInQubit(), [&](Value value) -> WalkResult { + if (qextractToMemrefMap.count(value)) { + memrefLoadValue = qextractToMemrefMap[value]; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + assert(memrefLoadValue != nullptr && "memrefLoadValue must not be null"); + + Value channel = + rewriter.create(loc, channelType, memrefLoadValue); + + // Create rtio.pulse + auto eventType = rtio::EventType::get(ctx); + Value event = rewriter.create(loc, eventType, channel, duration, + freqValue, phaseValue, nullptr); + rewriter.replaceOp(op, event); + + return success(); + } +}; + +/// Resolve the static channel mapping for the rtio.qubit_to_channel operation +/// +/// It's expecting `qubit_to_channel` has the following def-use chain: +/// memref.global w/ constants -> memref.get_global -> memref.load -> qubit_to_channel +/// +/// Example: +/// ``` +/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?> +/// ``` +/// will be converted to: +/// ``` +/// %ch = rtio.channel "dds" { channel_id = 0 } : !rtio.channel<"dds"> +/// ``` +struct ResolveChannelMappingPattern : public OpRewritePattern { + ResolveChannelMappingPattern(MLIRContext *ctx) + : OpRewritePattern(ctx) + { + } + + LogicalResult matchAndRewrite(rtio::RTIOQubitToChannelOp op, + PatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + Value qubit = op.getQubit(); + + auto loadOp = qubit.getDefiningOp(); + if (!loadOp) { + return failure(); + } + + Value memref = loadOp.getMemRef(); + auto getGlobalOp = memref.getDefiningOp(); + if (!getGlobalOp) { + return failure(); + } + + StringRef globalName = getGlobalOp.getName(); + ModuleOp module = op->getParentOfType(); + if (!module) { + return failure(); + } + auto globalOp = module.lookupSymbol(globalName); + if (!globalOp) { + return failure(); + } + + auto initialValue = globalOp.getInitialValue(); + if (!initialValue) { + return failure(); + } + + auto denseAttr = llvm::dyn_cast(*initialValue); + if (!denseAttr) { + return failure(); + } + + ValueRange indices = loadOp.getIndices(); + if (indices.size() != 1) { + return failure(); + } + + IntegerAttr indexAttr; + if (!matchPattern(indices[0], m_Constant(&indexAttr))) { + return failure(); + } + + int64_t index = indexAttr.getInt(); + + size_t denseSize = denseAttr.size(); + if (index < 0 || static_cast(index) >= denseSize) { + return failure(); + } + + APInt channelIdValue = denseAttr.getValues()[index]; + + auto originalChannelType = llvm::dyn_cast(op.getChannel().getType()); + if (!originalChannelType) { + return failure(); + } + StringRef kind = originalChannelType.getKind(); + ArrayAttr qualifiers = originalChannelType.getQualifiers(); + + // channel should have exactly one use before lowering to channel op + assert(op.getChannel().hasOneUse() && "channel should have exactly one use"); + + auto pulseOp = cast(*op.getChannel().getUsers().begin()); + int64_t offset = cast(pulseOp->getAttr("offset")).getInt(); + + IntegerAttr channelIdAttr = rewriter.getIntegerAttr( + rewriter.getIndexType(), (channelIdValue.getSExtValue() * 2 + offset)); + + auto resolvedChannelType = + rtio::ChannelType::get(rewriter.getContext(), kind, qualifiers, channelIdAttr); + + Value channel = rewriter.create(loc, resolvedChannelType); + + rewriter.replaceOp(op, channel); + + return success(); + } +}; + +/// Propagates RTIO events from chain of operations to event types. +/// +/// Steps: +/// 1. Traces backward to find all events that contribute to the current event value +/// 2. Creates a sync event from all collected events +/// 3. Replaces the cast operation with the sync event +struct PropagateEventsPattern : public OpRewritePattern { + MLIRContext *ctx; + + PropagateEventsPattern(MLIRContext *ctx) + : OpRewritePattern(ctx), ctx(ctx) + { + } + + LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, + PatternRewriter &rewriter) const override + { + if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return failure(); + + Type srcType = op.getInputs()[0].getType(); + Type dstType = op.getResult(0).getType(); + + // Only match casts from quantum/ion types to event type + // quantum.qreg -> event, quantum.qubit -> event, ion.qubit -> event + bool validSrcType = + llvm::isa(srcType); + bool validDstType = llvm::isa(dstType); + if (!validSrcType || !validDstType) + return failure(); + + Value input = op.getInputs()[0]; + + // Find associated events + // Skip over intermediate cast/extract/insert operations to collect events + bool reachedAllocOp = false; + SetVector events; + traceValueWithCallback(input, [&](Value value) -> WalkResult { + auto defOp = value.getDefiningOp(); + if (defOp && + isa(defOp)) { + return WalkResult::advance(); + } + + // collect event and stop tracing this path + if (isa(value.getType())) { + events.insert(value); + return WalkResult::interrupt(); + } + + if (isa(defOp)) { + reachedAllocOp = true; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + + if (reachedAllocOp && events.empty()) { + auto eventType = rtio::EventType::get(ctx); + Value emptyEvent = rewriter.create(op.getLoc(), eventType); + rewriter.replaceOp(op, emptyEvent); + return success(); + } + + if (events.empty()) { + op.emitError("No events found for cast op"); + llvm::reportFatalInternalError("No events found for cast op"); + } + + // Create a sync event from all collected events + // TODO: check domination, so that we can avoid creating a sync event if events are + // already dominated by one of the events + Value syncEvent = awaitEvents(events.getArrayRef(), rewriter); + rewriter.replaceOp(op, syncEvent); + return success(); + } +}; + +/// Clean up quantum/ion related ops that are not needed after conversion +struct CleanQuantumOpsPattern : public RewritePattern { + CleanQuantumOpsPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) + { + } + + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override + { + Dialect *dialect = op->getDialect(); + if (!dialect || !isa(dialect)) + return failure(); + + if (!op->use_empty()) { + return failure(); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +LogicalResult CleanQuantumOps(func::FuncOp funcOp, MLIRContext *ctx) +{ + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + return success(); +} + +LogicalResult CanonicalizeKernelFunction(func::FuncOp funcOp, MLIRContext *ctx) +{ + RewritePatternSet patterns(ctx); + for (auto *dialect : ctx->getLoadedDialects()) { + dialect->getCanonicalizationPatterns(patterns); + } + for (RegisteredOperationName op : ctx->getRegisteredOperations()) { + op.getCanonicalizationPatterns(patterns, ctx); + } + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + + IRRewriter rewriter(ctx); + DominanceInfo domInfo(funcOp); + eliminateCommonSubExpressions(rewriter, domInfo, funcOp); + + return success(); +} + +LogicalResult ResolveChannelMapping(func::FuncOp funcOp, MLIRContext *ctx) +{ + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// JSON to MLIR Attribute Conversion +//===----------------------------------------------------------------------===// + +/// Convert a JSON value to an MLIR Attribute +Attribute jsonToAttribute(MLIRContext *ctx, const llvm::json::Value &json) +{ + if (auto str = json.getAsString()) { + return StringAttr::get(ctx, *str); + } + if (auto num = json.getAsInteger()) { + return IntegerAttr::get(IntegerType::get(ctx, 64), *num); + } + if (auto num = json.getAsNumber()) { + return FloatAttr::get(Float64Type::get(ctx), *num); + } + if (auto b = json.getAsBoolean()) { + return BoolAttr::get(ctx, *b); + } + if (auto arr = json.getAsArray()) { + SmallVector attrs; + for (const auto &elem : *arr) { + attrs.push_back(jsonToAttribute(ctx, elem)); + } + return ArrayAttr::get(ctx, attrs); + } + if (auto *obj = json.getAsObject()) { + SmallVector entries; + for (const auto &kv : *obj) { + StringRef key = kv.first; + entries.emplace_back(StringAttr::get(ctx, key), jsonToAttribute(ctx, kv.second)); + } + // Sort entries by name for DictionaryAttr + llvm::sort(entries, [](const NamedAttribute &lhs, const NamedAttribute &rhs) { + return lhs.getName().getValue() < rhs.getName().getValue(); + }); + return DictionaryAttr::get(ctx, entries); + } + // null + return UnitAttr::get(ctx); +} + +/// Load a JSON file and convert it to an rtio.config attribute +FailureOr loadDeviceDbAsConfig(MLIRContext *ctx, StringRef filePath) +{ + auto fileOrErr = llvm::MemoryBuffer::getFile(filePath); + if (!fileOrErr) { + return failure(); + } + + auto json = llvm::json::parse((*fileOrErr)->getBuffer()); + if (!json) { + llvm::errs() << "Failed to parse JSON: " << llvm::toString(json.takeError()) << "\n"; + return failure(); + } + + auto *obj = json->getAsObject(); + if (!obj) { + llvm::errs() << "Device DB JSON must be an object\n"; + return failure(); + } + + // Convert JSON object to DictionaryAttr + SmallVector entries; + for (const auto &kv : *obj) { + StringRef key = kv.first; + entries.emplace_back(StringAttr::get(ctx, key), jsonToAttribute(ctx, kv.second)); + } + llvm::sort(entries, [](const NamedAttribute &lhs, const NamedAttribute &rhs) { + return lhs.getName().getValue() < rhs.getName().getValue(); + }); + + auto dictAttr = DictionaryAttr::get(ctx, entries); + return rtio::ConfigAttr::get(ctx, dictAttr); +} + +//===----------------------------------------------------------------------===// +// Pass Implementation +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_IONTORTIOPASS +#include "Ion/Transforms/Passes.h.inc" + +struct IonToRTIOPass : public impl::IonToRTIOPassBase { + using impl::IonToRTIOPassBase::IonToRTIOPassBase; + + LogicalResult IonPulseConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, + TypeConverter &typeConverter, IonInfo ionInfo, + DenseMap &qextractToMemrefMap, MLIRContext *ctx) + { + ConversionTarget target(baseTarget); + target.addIllegalOp(); + + RewritePatternSet patterns(ctx); + patterns.add(typeConverter, ctx, ionInfo, qextractToMemrefMap); + if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { + return failure(); + } + return success(); + } + + LogicalResult ParallelProtocolConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, + TypeConverter &typeConverter, MLIRContext *ctx) + { + ConversionTarget target(baseTarget); + target.addIllegalOp(); + + RewritePatternSet patterns(ctx); + patterns.add(typeConverter, ctx); + if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { + return failure(); + } + return success(); + } + + LogicalResult SCFStructuralConversion(func::FuncOp funcOp, ConversionTarget &target, + TypeConverter &typeConverter, MLIRContext *ctx) + { + TypeConverter scfTypeConverter(typeConverter); + scfTypeConverter.addConversion( + [ctx](quantum::QubitType) -> Type { return rtio::EventType::get(ctx); }); + scfTypeConverter.addConversion( + [ctx](quantum::QuregType) -> Type { return rtio::EventType::get(ctx); }); + scfTypeConverter.addConversion( + [ctx](ion::QubitType) -> Type { return rtio::EventType::get(ctx); }); + // Add materialization for quantum/ion -> event + scfTypeConverter.addSourceMaterialization( + [](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return nullptr; + Type inputType = inputs.front().getType(); + if (inputType != resultType) { + return builder.create(loc, resultType, inputs) + .getResult(0); + } + return inputs[0]; + }); + // Add target materialization for event -> quantum/ion + scfTypeConverter.addTargetMaterialization( + [](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return nullptr; + Type inputType = inputs.front().getType(); + if (inputType != resultType) { + return builder.create(loc, resultType, inputs) + .getResult(0); + } + return inputs[0]; + }); + + ConversionTarget scfTarget(getContext()); + scfTarget.addLegalDialect(); + + // Mark SCF ops as illegal only if they use quantum/ion types + scfTarget.addDynamicallyLegalOp([&](scf::ForOp op) { + for (auto arg : op.getRegionIterArgs()) { + Type type = arg.getType(); + if (llvm::isa(type)) { + return false; + } + } + for (auto result : op.getResults()) { + Type type = result.getType(); + if (llvm::isa(type)) { + return false; + } + } + return true; + }); + + scfTarget.addDynamicallyLegalOp([&](scf::IfOp op) { + for (auto result : op.getResults()) { + Type type = result.getType(); + if (llvm::isa(type)) { + return false; + } + } + return true; + }); + + scfTarget.addLegalOp(); + + // restructure SCF Operations + RewritePatternSet scfPatterns(&getContext()); + mlir::scf::populateSCFStructuralTypeConversionsAndLegality(scfTypeConverter, scfPatterns, + scfTarget); + + if (failed(applyPartialConversion(funcOp, scfTarget, std::move(scfPatterns)))) { + return failure(); + } + + return success(); + } + + LogicalResult PropagateEvents(func::FuncOp funcOp, MLIRContext *ctx) + { + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + return success(); + } + + SmallVector getIonInfos() + { + SmallVector ionInfos; + getOperation()->walk([&](ion::IonOp ionOp) { ionInfos.emplace_back(IonInfo(ionOp)); }); + return ionInfos; + } + + func::FuncOp createKernelFunction(func::FuncOp qnodeFunc, std::string kernelName, + OpBuilder &builder) + { + MLIRContext *ctx = builder.getContext(); + + auto newQnodeFunc = qnodeFunc.clone(); + newQnodeFunc.setName(kernelName); + auto oldFuncType = qnodeFunc.getFunctionType(); + // create new function type with empty results + auto newFuncType = FunctionType::get(ctx, oldFuncType.getInputs(), {}); + newQnodeFunc.setFunctionType(newFuncType); + + // set public visibility and remove internal linkage for kernel function + newQnodeFunc.setPublic(); + newQnodeFunc->removeAttr("llvm.linkage"); + + // Clear operands from all return ops (make them return nothing) + newQnodeFunc.walk([](func::ReturnOp returnOp) { returnOp.getOperandsMutable().clear(); }); + + return newQnodeFunc; + } + + void initializeMemrefMap(func::FuncOp funcOp, ModuleOp module, + DenseMap &qregToMemrefMap, + DenseMap &qextractToMemrefMap, MLIRContext *ctx) + { + OpBuilder builder(ctx); + + int globalCounter = 0; + funcOp.walk([&](quantum::AllocOp allocOp) { + size_t numQubits = allocOp.getNqubitsAttr().value(); + auto memrefType = + MemRefType::get({static_cast(numQubits)}, builder.getIndexType()); + + // Create a unique symbol name for this global + std::string globalNameStr = "__qubit_map_" + std::to_string(globalCounter++); + StringRef globalName = globalNameStr; + + // Create dense attribute with values [0, 1, 2, ..., numQubits-1] + auto tensorType = + RankedTensorType::get({static_cast(numQubits)}, builder.getIndexType()); + SmallVector values; + // Use IndexType::kInternalStorageBitWidth for index type + unsigned indexWidth = IndexType::kInternalStorageBitWidth; + for (size_t i = 0; i < numQubits; i++) { + values.push_back(APInt(indexWidth, i)); + } + auto denseAttr = DenseIntElementsAttr::get(tensorType, values); + + // Create global memref at module level + builder.setInsertionPointToStart(module.getBody()); + auto globalOp = + memref::GlobalOp::create(builder, allocOp.getLoc(), + builder.getStringAttr(globalName), // sym_name + builder.getStringAttr("private"), // sym_visibility + TypeAttr::get(memrefType), // type + denseAttr, // initial_value + builder.getUnitAttr(), // constant + IntegerAttr()); // alignment + + // Get the global memref in the function + builder.setInsertionPointAfter(allocOp); + Value qubitMap = builder.create(allocOp.getLoc(), memrefType, + globalOp.getSymName()); + + qregToMemrefMap[allocOp.getResult()] = qubitMap; + }); + + funcOp.walk([&](quantum::ExtractOp extractOp) { + traceValueWithCallback( + extractOp.getQreg(), [&](Value value) -> WalkResult { + if (qregToMemrefMap.count(value)) { + builder.setInsertionPointAfter(extractOp); + auto memref = qregToMemrefMap[value]; + + Value memrefLoadValue = nullptr; + if (Value idx = extractOp.getIdx()) { + // idx is an operand (i64), need to cast to index + Value indexValue = builder.create( + extractOp.getLoc(), builder.getIndexType(), idx); + memrefLoadValue = builder.create( + extractOp.getLoc(), memref, ValueRange{indexValue}); + } + else if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { + Value indexValue = builder.create( + extractOp.getLoc(), idxAttr.getInt()); + memrefLoadValue = builder.create( + extractOp.getLoc(), memref, ValueRange{indexValue}); + } + if (memrefLoadValue) { + qextractToMemrefMap[extractOp.getResult()] = memrefLoadValue; + } + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + }); + } + + void runOnOperation() override + { + MLIRContext *ctx = &getContext(); + auto module = cast(getOperation()); + + // Load device_db JSON file and set rtio.config attribute on module + if (!deviceDb.empty()) { + auto configOrErr = loadDeviceDbAsConfig(ctx, deviceDb); + if (failed(configOrErr)) { + module->emitError("Failed to load device database from: ") << deviceDb; + return signalPassFailure(); + } + module->setAttr(rtio::ConfigAttr::getModuleAttrName(), *configOrErr); + } + + // check if there is only one qnode function + func::FuncOp qnodeFunc = nullptr; + int qnodeCounts = 0; + module.walk([&](func::FuncOp funcOp) { + if (funcOp->hasAttr("qnode")) { + qnodeFunc = funcOp; + qnodeCounts++; + } + }); + assert(qnodeCounts == 1 && "only one qnode function is allowed"); + + // collect all ion information for calculating frequency when converting ion.pulse + SmallVector ionInfos = getIonInfos(); + if (ionInfos.empty()) { + getOperation()->emitError("Failed to get ion information"); + return signalPassFailure(); + } + + // currently, we only support one ion information + assert(ionInfos.size() == 1 && "only one ion information is allowed"); + IonInfo &ionInfo = ionInfos.front(); + + // clone qnode function as new kernel function + OpBuilder builder(ctx); + func::FuncOp newQnodeFunc = createKernelFunction(qnodeFunc, kernelName, builder); + module.insert(qnodeFunc, newQnodeFunc); + + // drop one of the pulse from the certain protocol + // the way we handle the dropped pulse will be updated in the future + SmallVector pulsesToErase; + newQnodeFunc.walk([&](ion::PulseOp pulseOp) { + if (pulseOp.getBeamAttr().getTransitionIndex().getInt() == 0) + pulsesToErase.push_back(pulseOp); + }); + for (auto pulseOp : pulsesToErase) + pulseOp.erase(); + + // Construct mapping from qreg alloc and qreg extract to memref + // In the later conversion, we use the mapping to construct the channel for rtio.pulse + DenseMap qregToMemrefMap; + DenseMap qextractToMemrefMap; + initializeMemrefMap(newQnodeFunc, module, qregToMemrefMap, qextractToMemrefMap, ctx); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [&](ion::PulseType type) -> Type { return rtio::EventType::get(ctx); }); + + ConversionTarget target(*ctx); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + + // prepare kernel function + if (failed(IonPulseConversion(newQnodeFunc, target, typeConverter, ionInfo, + qextractToMemrefMap, ctx)) || + failed(ParallelProtocolConversion(newQnodeFunc, target, typeConverter, ctx)) || + failed(SCFStructuralConversion(newQnodeFunc, target, typeConverter, ctx)) || + failed(PropagateEvents(newQnodeFunc, ctx)) || + failed(CleanQuantumOps(newQnodeFunc, ctx)) || + failed(ResolveChannelMapping(newQnodeFunc, ctx)) || + failed(CanonicalizeKernelFunction(newQnodeFunc, ctx))) { + newQnodeFunc->emitError("Failed to convert to rtio dialect"); + return signalPassFailure(); + } + + for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { + if (funcOp.getName().str() != newQnodeFunc.getName().str()) { + funcOp.erase(); + } + } + } +}; + +} // namespace ion +} // namespace catalyst diff --git a/mlir/lib/RTIO/IR/RTIODialect.cpp b/mlir/lib/RTIO/IR/RTIODialect.cpp index 64e60ac774..a08a9e377e 100644 --- a/mlir/lib/RTIO/IR/RTIODialect.cpp +++ b/mlir/lib/RTIO/IR/RTIODialect.cpp @@ -127,24 +127,17 @@ void catalyst::rtio::RTIODialect::initialize() #include "RTIO/IR/RTIOOpsTypes.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "RTIO/IR/RTIOAttributes.cpp.inc" + >(); + addOperations< #define GET_OP_LIST #include "RTIO/IR/RTIOOps.cpp.inc" >(); } -// Not support any custom attributes yet, might be supported in the future -Attribute catalyst::rtio::RTIODialect::parseAttribute(DialectAsmParser &parser, Type type) const -{ - parser.emitError(parser.getNameLoc(), "no dialect attributes are supported"); - return {}; -} - -void catalyst::rtio::RTIODialect::printAttribute(Attribute attr, DialectAsmPrinter &printer) const -{ - llvm_unreachable("no dialect attributes are supported"); -} - //===----------------------------------------------------------------------===// // RTIO Type Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/RTIO/VerifierTest.mlir b/mlir/test/RTIO/VerifierTest.mlir index 1ebd3db529..f34ba09d7f 100644 --- a/mlir/test/RTIO/VerifierTest.mlir +++ b/mlir/test/RTIO/VerifierTest.mlir @@ -197,3 +197,27 @@ func.func @empty_good() { %empty = rtio.empty : !rtio.event return } + +// ----- + +module @config_test attributes { + rtio.config = #rtio.config<{ + config1 = 1 : i32, + config2 = "test", + nested = {a = 0 : i32, b = "test"} + }> +} { + func.func @kernel() { + return + } +} + +// ----- + +module @empty_config attributes { + rtio.config = #rtio.config<{}> +} { + func.func @kernel() { + return + } +} From f326bf23eb360bd00cfca151cc9c635c2d017706 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 14:46:40 -0500 Subject: [PATCH 26/51] drop timeline IR --- mlir/include/RTIO/IR/RTIOOps.td | 196 ------------------------------- mlir/test/RTIO/VerifierTest.mlir | 75 ------------ 2 files changed, 271 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index 7fdca59ced..91b93a3743 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -20,10 +20,6 @@ include "mlir/IR/BuiltinAttributes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "RTIO/IR/RTIODialect.td" -//===----------------------------------------------------------------------===// -// Event-Based API -//===----------------------------------------------------------------------===// - def RTIOChannelOp : RTIO_Op<"channel"> { let summary = "Define a channel"; let description = [{ @@ -200,197 +196,5 @@ def RTIOEmptyOp : RTIO_Op<"empty", [Pure]> { let assemblyFormat = "attr-dict `:` type($event)"; } -//===----------------------------------------------------------------------===// -// Timeline-Based IR (ARTIQ compatible) -//===----------------------------------------------------------------------===// - -def RTIONowOp : RTIO_Op<"now", [Pure]> { - let summary = "Read the current timeline cursor"; - let description = [{ - Returns the current value of the timeline cursor in machine units (mu). - - Example: - ```mlir - %t0 = rtio.now : i64 - rtio.delay 100 : i64 - %t1 = rtio.now : i64 - // %t1 = %t0 + 100 - ``` - }]; - - let results = (outs I64:$time); - let assemblyFormat = "attr-dict `:` type($time)"; -} - -def RTIODelayOp : RTIO_Op<"delay"> { - let summary = "Advance the timeline cursor by a relative duration"; - let description = [{ - Advances the timeline cursor by the specified duration in machine units. - Equivalent to `rtio.at(rtio.now() + duration)`. - - Example: - ```mlir - // case1 - rtio.delay 1000 : i64 - - // case2 - rtio.delay %dur : i64 - ``` - }]; - - let arguments = (ins I64:$duration); - let assemblyFormat = "$duration attr-dict `:` type($duration)"; -} - -def RTIOAtOp : RTIO_Op<"at"> { - let summary = "Set the timeline cursor to an absolute timestamp"; - let description = [{ - Moves the timeline cursor to a specific absolute timestamp. - - Pre-condition: The new time must satisfy `t >= rtio_counter_mu()` (hardware time), - otherwise an RTIO underflow error occurs at runtime. Note that `t` can be less than - the current timeline cursor value (from `rtio.now`), which is how parallel execution - is implemented. - - Example: - ```mlir - %t_start = rtio.now : i64 // e.g. t_start = 1000 mu - - // First pulse lane - rtio.on %ch0 - rtio.delay 500 : i64 - rtio.off %ch0 - // rtio.now would return 1500 mu (1000 mu + 500 mu) - %t0 = rtio.now : i64 - - // Second pulse lane (parallel with first lane) - - // Rewind to t_start = 1000 mu (< 1500 mu, but >= hardware counter) - rtio.at %t_start : i64 - rtio.on %ch1 - rtio.delay 300 : i64 - rtio.off %ch1 - // rtio.now would return 1300 mu (1000 mu + 300 mu) - %t1 = rtio.now : i64 - - // Sync: advance to the maximum end time (1500 mu) - %t_max = arith.maxui %t0, %t1 : i64 - rtio.at %t_max : i64 // advance to 1500 mu - ``` - }]; - - let arguments = (ins I64:$time); - let assemblyFormat = "$time attr-dict `:` type($time)"; -} - -def RTIOSetFrequencyOp : RTIO_Op<"set_frequency"> { - let summary = "Configure the frequency of a DDS channel"; - let description = [{ - Programs the DDS frequency tuning word (ftw) to generate a specific frequency. - The frequency is specified in Hz as a f64. - - Example: - ```mlir - %freq = arith.constant 1.266300000e10 : f64 // 12.663 GHz - rtio.set_frequency %ch0, %freq : !rtio.channel<"dds", 0>, f64 - ``` - - Assume phase is 0 and amplitude is 1. - It's equivalent to call the `AD9910.set` function with the following arguments: - - ```llvm - %set_func = load ptr, ptr @F.artiq.coredevice.ad9910.AD9910.set - call double %set_func( - ptr %env, - ptr %ch0, - double 1.266300000e+10, ; frequency - double 0.000000e+00, ; phase - double 1.000000e+00 ; amplitude - ) - ``` - }]; - - let arguments = (ins - RTIOChannelType:$channel, - F64:$frequency - ); - let assemblyFormat = "$channel `,` $frequency attr-dict `:` type($channel) `,` type($frequency)"; -} - -def RTIOSetPhaseOp : RTIO_Op<"set_phase"> { - let summary = "Configure the phase of a DDS channel"; - let description = [{ - Programs the DDS phase offset register to set the carrier phase. - The phase is specified in radians as a f64. - - Example: - ```mlir - // pi/2 phase shift - %pi_2 = arith.constant 1.5707963267948966 : f64 - rtio.set_phase %ch0, %pi_2 : !rtio.channel<"dds", 0>, f64 - ``` - }]; - - let arguments = (ins - RTIOChannelType:$channel, - F64:$phase - ); - let assemblyFormat = "$channel `,` $phase attr-dict `:` type($channel) `,` type($phase)"; -} - -def RTIOSetAmplitudeOp : RTIO_Op<"set_amplitude"> { - let summary = "Configure the amplitude of a DDS channel"; - let description = [{ - Programs the DDS amplitude - - Example: - ```mlir - %amp = arith.constant 1.0 : f64 - rtio.set_amplitude %ch0, %amp : !rtio.channel<"dds", 0>, f64 - ``` - }]; - - let arguments = (ins - RTIOChannelType:$channel, - F64:$amplitude - ); - let assemblyFormat = "$channel `,` $amplitude attr-dict `:` type($channel) `,` type($amplitude)"; -} - -def RTIOOnOp : RTIO_Op<"on"> { - let summary = "Turn on a channel output"; - let description = [{ - Activates the output of a channel at the current cursor. - For DDS channels, this typically enables the RF output switch. - - Example: - ```mlir - rtio.at %t_start : i64 - rtio.on %ch - rtio.delay 1000 : i64 - rtio.off %ch - ``` - }]; - - let arguments = (ins RTIOChannelType:$channel); - let assemblyFormat = "$channel attr-dict `:` type($channel)"; -} - -def RTIOOffOp : RTIO_Op<"off"> { - let summary = "Turn off a channel output"; - let description = [{ - Deactivates the output of a channel at the current cursor. - For DDS channels, this typically disables the RF output switch. - - Example: - ```mlir - rtio.off %ch : !rtio.channel<"dds", 0> - ``` - }]; - - let arguments = (ins RTIOChannelType:$channel); - let assemblyFormat = "$channel attr-dict `:` type($channel)"; -} - #endif // RTIO_OPS diff --git a/mlir/test/RTIO/VerifierTest.mlir b/mlir/test/RTIO/VerifierTest.mlir index f34ba09d7f..9b4a39549c 100644 --- a/mlir/test/RTIO/VerifierTest.mlir +++ b/mlir/test/RTIO/VerifierTest.mlir @@ -118,81 +118,6 @@ func.func @sync_no_events() { // ----- -/////////////////////////////// -// Timeline-Based Operations // -/////////////////////////////// - -func.func @timeline_now() { - %t = rtio.now : i64 - return -} - -// ----- - -func.func @timeline_at() { - %t = arith.constant 1000 : i64 - %delay = arith.constant 500 : i64 - %now = rtio.now : i64 - rtio.at %t : i64 - rtio.delay %delay : i64 - - // rewind to the start time - rtio.at %now : i64 - return -} - -// ----- - -func.func @timeline_dalay() { - %delay = arith.constant 500 : i64 - rtio.delay %delay : i64 - return -} - -// ----- - -func.func @set_frequency_dds_good(%freq: f64) { - %ch = rtio.channel : !rtio.channel<"dds", 0> - rtio.set_frequency %ch, %freq : !rtio.channel<"dds", 0>, f64 - return -} - - -// ----- - -func.func @set_phase_dds_good(%phase: f64) { - %ch = rtio.channel : !rtio.channel<"dds", 0> - rtio.set_phase %ch, %phase : !rtio.channel<"dds", 0>, f64 - - return -} - -// ----- - -func.func @set_amplitude_dds_good(%amp: f64) { - %ch = rtio.channel : !rtio.channel<"dds", 0> - rtio.set_amplitude %ch, %amp : !rtio.channel<"dds", 0>, f64 - return -} - -// ----- - -func.func @ttl_on_dds_good() { - %ch = rtio.channel : !rtio.channel<"dds", 0> - rtio.on %ch : !rtio.channel<"dds", 0> - return -} - -// ----- - -func.func @ttl_off_dds_good() { - %ch = rtio.channel : !rtio.channel<"dds", 0> - rtio.off %ch : !rtio.channel<"dds", 0> - return -} - -// ----- - func.func @empty_good() { %empty = rtio.empty : !rtio.event return From 0e6b14f4ffd1cebf55e8cbf8806a09eb57acc2b7 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 14:55:35 -0500 Subject: [PATCH 27/51] remove unrelated thing --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 1137 ----------------------- 1 file changed, 1137 deletions(-) delete mode 100644 mlir/lib/Ion/Transforms/ion-to-rtio.cpp diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp deleted file mode 100644 index 77765b424e..0000000000 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ /dev/null @@ -1,1137 +0,0 @@ -// Copyright 2025 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "llvm/Support/JSON.h" -#include "llvm/Support/MemoryBuffer.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/Patterns.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Transforms/CSE.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "Ion/IR/IonDialect.h" -#include "Ion/IR/IonOps.h" -#include "Ion/Transforms/Passes.h" -#include "Quantum/IR/QuantumDialect.h" -#include "Quantum/IR/QuantumOps.h" -#include "RTIO/IR/RTIODialect.h" -#include "RTIO/IR/RTIOOps.h" - -using namespace mlir; -using namespace catalyst; - -namespace catalyst { -namespace ion { - -namespace { - -//===----------------------------------------------------------------------===// -// Helper functions -//===----------------------------------------------------------------------===// - -enum class TraceMode { - Qreg = 0, - Event = 1, -}; - -/// Traces a Value backward through the IR by tracing its dataflow dependencies -/// across control flow and specific quantum operations. -/// -/// Template Parameters: -/// - ModeT: TraceMode enum (Qreg or Event) that controls how quantum.insert -/// operations are handled -/// Qreg mode: Trace to find the source qreg of the given value -/// Event mode: Trace to find all events that contribute to the given value -/// - CallbackT: Callable type that will be invoked for each visited value. -/// May optionally return WalkResult for early termination. -/// -/// Supported Operations: -/// - scf.for -/// - scf.if -/// - ion.parallelprotocol -/// - unrealized_conversion_cast -/// - quantum.extract -/// - quantum.insert -template -auto traceValueWithCallback(Value value, CallbackT &&callback) -{ - WalkResult walkResult = WalkResult::advance(); - std::queue visited; - visited.push(value); - - while (!visited.empty()) { - Value value = visited.front(); - visited.pop(); - - if constexpr (std::is_same_v, WalkResult>) { - if (callback(value).wasInterrupted()) { - walkResult = WalkResult::interrupt(); - continue; - } - } - else { - callback(value); - } - - if (auto arg = mlir::dyn_cast(value)) { - Block *block = arg.getOwner(); - Operation *parentOp = block->getParentOp(); - - if (auto forOp = dyn_cast(parentOp)) { - unsigned argIndex = arg.getArgNumber(); - Value iterArg = forOp.getInitArgs()[argIndex - 1]; - visited.push(iterArg); - continue; - } - else if (auto parallelProtocolOp = dyn_cast(parentOp)) { - unsigned argIndex = arg.getArgNumber(); - Value inQubit = parallelProtocolOp.getInQubits()[argIndex]; - visited.push(inQubit); - continue; - } - parentOp->emitError("Unsupported parent operation for block argument: ") << value; - llvm::reportFatalInternalError("Unsupported block argument"); - } - - Operation *defOp = value.getDefiningOp(); - if (defOp == nullptr) { - continue; - } - - if (auto forOp = dyn_cast(defOp)) { - unsigned resultIdx = llvm::cast(value).getResultNumber(); - BlockArgument iterArg = forOp.getRegionIterArg(resultIdx); - visited.push(iterArg); - } - else if (auto ifOp = dyn_cast(defOp)) { - unsigned resultIdx = llvm::cast(value).getResultNumber(); - Value thenValue = ifOp.thenYield().getOperand(resultIdx); - Value elseValue = ifOp.elseYield().getOperand(resultIdx); - visited.push(thenValue); - visited.push(elseValue); - } - else if (auto parallelProtocolOp = dyn_cast(defOp)) { - unsigned resultIdx = llvm::cast(value).getResultNumber(); - Value inQubit = parallelProtocolOp.getInQubits()[resultIdx]; - visited.push(inQubit); - } - else if (auto op = dyn_cast(defOp)) { - visited.push(op.getInputs().front()); - } - else if (auto op = dyn_cast(defOp)) { - visited.push(op.getQreg()); - } - else if (auto op = dyn_cast(defOp)) { - visited.push(op.getQubit()); - } - else if (auto op = dyn_cast(defOp)) { - Value inQreg = op.getInQreg(); - Value qubit = op.getQubit(); - if constexpr (ModeT == TraceMode::Qreg) { - visited.push(inQreg); - } - else if constexpr (ModeT == TraceMode::Event) { - visited.push(qubit); - // only trace qreg if it defined op is also come from insert op - if (llvm::isa_and_present(inQreg.getDefiningOp())) { - visited.push(inQreg); - } - } - } - } - - if constexpr (std::is_same_v, WalkResult>) { - return walkResult; - } -} - -Value awaitEvents(ArrayRef events, PatternRewriter &rewriter) -{ - if (events.size() == 1) { - return events.front(); - } - auto eventType = rtio::EventType::get(rewriter.getContext()); - return rewriter.create(rewriter.getUnknownLoc(), eventType, events); -} - -// Helper class to store ion information -class IonInfo { - private: - llvm::StringMap levelEnergyMap; - - struct TransitionInfo { - std::string level0; - std::string level1; - double einstein_a; - std::string multipole; - }; - SmallVector transitions; - - public: - IonInfo(ion::IonOp op) - { - auto levelAttrs = op.getLevels(); - auto transitionsAttr = op.getTransitions(); - - // Map from Level label to Energy value - for (auto levelAttr : levelAttrs) { - auto level = cast(levelAttr); - std::string label = level.getLabel().getValue().str(); - double energy = level.getEnergy().getValueAsDouble(); - levelEnergyMap[label] = energy; - } - - // Store transition information - for (auto transitionAttr : transitionsAttr) { - auto transition = cast(transitionAttr); - TransitionInfo info; - info.level0 = transition.getLevel_0().getValue().str(); - info.level1 = transition.getLevel_1().getValue().str(); - info.einstein_a = transition.getEinsteinA().getValueAsDouble(); - info.multipole = transition.getMultipole().getValue().str(); - transitions.push_back(info); - } - } - - // Get energy of a level by label - std::optional getLevelEnergy(StringRef label) const - { - auto it = levelEnergyMap.find(label.str()); - if (it != levelEnergyMap.end()) { - return it->second; - } - return std::nullopt; - } - - // Get level label of a transition by index - template - std::optional getTransitionLevelEnergy(size_t transitionIndex) const - { - static_assert(IndexT == 0 || IndexT == 1, "IndexT must be 0 or 1"); - - if (transitionIndex >= transitions.size()) { - return std::nullopt; - } - - const auto &transition = transitions[transitionIndex]; - if constexpr (IndexT == 0) { - return getLevelEnergy(transition.level0); - } - else { - return getLevelEnergy(transition.level1); - } - } - - // Get energy difference of a transition (level1 energy - level0 energy) - std::optional getTransitionEnergyDiff(size_t index) const - { - if (index >= transitions.size()) { - return std::nullopt; - } - - auto energy0 = getTransitionLevelEnergy<0>(index); - auto energy1 = getTransitionLevelEnergy<1>(index); - - if (energy0.has_value() && energy1.has_value()) { - return energy1.value() - energy0.value(); - } - - return std::nullopt; - } - - // Get number of transitions - size_t getNumTransitions() const { return transitions.size(); } - - // Get transition info by index - std::optional getTransition(size_t index) const - { - if (index < transitions.size()) { - return transitions[index]; - } - return std::nullopt; - } -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// Conversion Patterns -//===----------------------------------------------------------------------===// - -/// Convert ion.parallelprotocol and introduce rtio.sync to ensure the order -/// -/// Example: -/// ``` -/// %0, %1 = ion.parallelprotocol(%q0, %q1) { -/// ^bb0(%arg0, %arg1): -/// %p0 = rtio.pulse(...) : !rtio.event -/// %p1 = rtio.pulse(...) : !rtio.event -/// ion.yield %arg0, %arg1 -/// } -/// ``` -/// will be converted to: -/// ``` -/// %event0 = unrealized_conversion_cast %q0 : !ion.qubit -> !rtio.event -/// %event1 = unrealized_conversion_cast %q1 : !ion.qubit -> !rtio.event -/// %p0 = rtio.pulse(..., wait = %event0) : !rtio.event -/// %p1 = rtio.pulse(..., wait = %event1) : !rtio.event -/// %sync = rtio.sync %p0, %p1 : !rtio.event -/// %0 = unrealized_conversion_cast %sync : !rtio.event -> !ion.qubit -/// %1 = unrealized_conversion_cast %sync : !rtio.event -> !ion.qubit -/// ``` -/// Those unrealized conversion casts are used to establish the dependency but will be -/// resolved by the subsequent stages. -struct ParallelProtocolToRTIOPattern : public OpConversionPattern { - ParallelProtocolToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx) - { - } - - LogicalResult matchAndRewrite(ion::ParallelProtocolOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - MLIRContext *ctx = rewriter.getContext(); - Location loc = op.getLoc(); - - Block *regionBlock = &op.getBodyRegion().front(); - IRMapping irMapping; - SmallVector inQubits; - for (auto [blockArg, operand] : - llvm::zip(regionBlock->getArguments(), adaptor.getOperands())) { - irMapping.map(blockArg, operand); - - // collect qubits to trace the events - if (isa(operand.getType())) { - inQubits.push_back(operand); - } - } - - // create events for each qubit - auto events = llvm::map_range(inQubits, [&](Value qubit) { - auto eventType = rtio::EventType::get(ctx); - return rewriter.create(loc, eventType, qubit).getResult(0); - }); - - Value inputSyncEvent = awaitEvents(llvm::to_vector(events), rewriter); - - // Clone operations from the region to outside - SmallVector pulseEvents; - DenseMap qubitToOffset; - - // we cache the channel to index mapping to avoid multiple lookups - DenseMap cache; - for (auto ®ionOp : regionBlock->without_terminator()) { - auto *clonedOp = rewriter.clone(regionOp, irMapping); - if (auto pulseOp = dyn_cast(clonedOp)) { - // set wait event for the pulse operation - pulseOp.setWait(inputSyncEvent); - - Value index = nullptr; - - SmallVector chain; - traceValueWithCallback( - pulseOp.getChannel(), [&](Value value) -> WalkResult { - if (cache.count(value)) { - index = cache[value]; - return WalkResult::interrupt(); - } - chain.push_back(value); - if (auto loadOp = - llvm::dyn_cast_if_present(value.getDefiningOp())) { - index = loadOp.getIndices()[0]; - - // cache the channel to index mapping - cache[pulseOp.getChannel()] = index; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - - assert(index != nullptr && "index must not be null"); - - // update cache - for (Value value : chain) { - cache[value] = index; - } - pulseOp->setAttr("offset", rewriter.getI64IntegerAttr(qubitToOffset[index])); - - // the same qubit may appear multiple times in the parallel protocol - // so we need to increment the offset for each appearance - qubitToOffset[index]++; - - pulseEvents.push_back(pulseOp.getEvent()); - } - irMapping.map(regionOp.getResults(), clonedOp->getResults()); - } - - // Create sync operation from pulse events (must have at least one after Phase 1) - assert(pulseEvents.size() > 0 && - "must have at least one pulse operation after parallel protocol conversion"); - - Value outputSyncEvent = awaitEvents(llvm::to_vector(pulseEvents), rewriter); - - SmallVector results; - for (Value result : op.getResults()) { - // unrealized conversion cast sync event to result type - auto event = - rewriter.create(loc, result.getType(), outputSyncEvent); - results.push_back(event.getResult(0)); - } - - rewriter.replaceOp(op, results); - return success(); - } -}; - -/// Convert ion.pulse to rtio.pulse -/// -/// Example: -/// ``` -/// %pulse = ion.pulse(%duration) %qubit { -/// beam = #ion.beam<...> -/// } : !ion.pulse -/// ``` -/// will be converted to: -/// ``` -/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?> -/// ... // other pulse parameters settings -/// %event = rtio.pulse %ch duration(%duration) frequency(%freq) phase(%phase) -/// : !rtio.channel<"dds", ?> -> !rtio.event -/// ``` -struct PulseToRTIOPattern : public OpConversionPattern { - IonInfo ionInfo; - DenseMap &qextractToMemrefMap; - PulseToRTIOPattern(TypeConverter &typeConverter, MLIRContext *ctx, IonInfo ionInfo, - DenseMap &qextractToMemrefMap) - : OpConversionPattern(typeConverter, ctx), ionInfo(ionInfo), - qextractToMemrefMap(qextractToMemrefMap) - { - } - - double calculateFrequency(int64_t transitionIndex, double detuning, - const IonInfo &ionInfo) const - { - // TODO: raman1_frequency can be passed as a pass option for extensibility - double raman1_frequency = 2 * llvm::numbers::pi * 844.485e12 - - 2 * llvm::numbers::pi * 12.643e9 - 2 * llvm::numbers::pi * 20e6; - - auto energyDiff = ionInfo.getTransitionEnergyDiff(transitionIndex); - assert(energyDiff.has_value() && "energyDiff must have a value"); - - double reference_energy = energyDiff.value(); - double frequency = - (reference_energy + detuning - raman1_frequency) / (2.0 * llvm::numbers::pi); - return frequency; - } - - LogicalResult matchAndRewrite(ion::PulseOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - - // Get pulse parameters - Value duration = op.getTime(); - auto beamAttr = op.getBeam(); - auto phaseAttr = op.getPhase(); - - // Extract beam parameters - double detuning = beamAttr.getDetuning().getValueAsDouble(); - double phase = phaseAttr.getValueAsDouble(); - int64_t transitionIndex = beamAttr.getTransitionIndex().getInt(); - double frequency = calculateFrequency(transitionIndex, detuning, ionInfo); - Value freqValue = - rewriter.create(loc, rewriter.getF64FloatAttr(frequency)); - Value phaseValue = rewriter.create(loc, rewriter.getF64FloatAttr(phase)); - - // Convert the qubit to a channel - ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(transitionIndex)}); - auto channelType = rtio::ChannelType::get(ctx, "dds", qualifiers, nullptr); - - Value memrefLoadValue = nullptr; - traceValueWithCallback(op.getInQubit(), [&](Value value) -> WalkResult { - if (qextractToMemrefMap.count(value)) { - memrefLoadValue = qextractToMemrefMap[value]; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - assert(memrefLoadValue != nullptr && "memrefLoadValue must not be null"); - - Value channel = - rewriter.create(loc, channelType, memrefLoadValue); - - // Create rtio.pulse - auto eventType = rtio::EventType::get(ctx); - Value event = rewriter.create(loc, eventType, channel, duration, - freqValue, phaseValue, nullptr); - rewriter.replaceOp(op, event); - - return success(); - } -}; - -/// Resolve the static channel mapping for the rtio.qubit_to_channel operation -/// -/// It's expecting `qubit_to_channel` has the following def-use chain: -/// memref.global w/ constants -> memref.get_global -> memref.load -> qubit_to_channel -/// -/// Example: -/// ``` -/// %ch = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ?> -/// ``` -/// will be converted to: -/// ``` -/// %ch = rtio.channel "dds" { channel_id = 0 } : !rtio.channel<"dds"> -/// ``` -struct ResolveChannelMappingPattern : public OpRewritePattern { - ResolveChannelMappingPattern(MLIRContext *ctx) - : OpRewritePattern(ctx) - { - } - - LogicalResult matchAndRewrite(rtio::RTIOQubitToChannelOp op, - PatternRewriter &rewriter) const override - { - Location loc = op.getLoc(); - Value qubit = op.getQubit(); - - auto loadOp = qubit.getDefiningOp(); - if (!loadOp) { - return failure(); - } - - Value memref = loadOp.getMemRef(); - auto getGlobalOp = memref.getDefiningOp(); - if (!getGlobalOp) { - return failure(); - } - - StringRef globalName = getGlobalOp.getName(); - ModuleOp module = op->getParentOfType(); - if (!module) { - return failure(); - } - auto globalOp = module.lookupSymbol(globalName); - if (!globalOp) { - return failure(); - } - - auto initialValue = globalOp.getInitialValue(); - if (!initialValue) { - return failure(); - } - - auto denseAttr = llvm::dyn_cast(*initialValue); - if (!denseAttr) { - return failure(); - } - - ValueRange indices = loadOp.getIndices(); - if (indices.size() != 1) { - return failure(); - } - - IntegerAttr indexAttr; - if (!matchPattern(indices[0], m_Constant(&indexAttr))) { - return failure(); - } - - int64_t index = indexAttr.getInt(); - - size_t denseSize = denseAttr.size(); - if (index < 0 || static_cast(index) >= denseSize) { - return failure(); - } - - APInt channelIdValue = denseAttr.getValues()[index]; - - auto originalChannelType = llvm::dyn_cast(op.getChannel().getType()); - if (!originalChannelType) { - return failure(); - } - StringRef kind = originalChannelType.getKind(); - ArrayAttr qualifiers = originalChannelType.getQualifiers(); - - // channel should have exactly one use before lowering to channel op - assert(op.getChannel().hasOneUse() && "channel should have exactly one use"); - - auto pulseOp = cast(*op.getChannel().getUsers().begin()); - int64_t offset = cast(pulseOp->getAttr("offset")).getInt(); - - IntegerAttr channelIdAttr = rewriter.getIntegerAttr( - rewriter.getIndexType(), (channelIdValue.getSExtValue() * 2 + offset)); - - auto resolvedChannelType = - rtio::ChannelType::get(rewriter.getContext(), kind, qualifiers, channelIdAttr); - - Value channel = rewriter.create(loc, resolvedChannelType); - - rewriter.replaceOp(op, channel); - - return success(); - } -}; - -/// Propagates RTIO events from chain of operations to event types. -/// -/// Steps: -/// 1. Traces backward to find all events that contribute to the current event value -/// 2. Creates a sync event from all collected events -/// 3. Replaces the cast operation with the sync event -struct PropagateEventsPattern : public OpRewritePattern { - MLIRContext *ctx; - - PropagateEventsPattern(MLIRContext *ctx) - : OpRewritePattern(ctx), ctx(ctx) - { - } - - LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, - PatternRewriter &rewriter) const override - { - if (op.getNumOperands() != 1 || op.getNumResults() != 1) - return failure(); - - Type srcType = op.getInputs()[0].getType(); - Type dstType = op.getResult(0).getType(); - - // Only match casts from quantum/ion types to event type - // quantum.qreg -> event, quantum.qubit -> event, ion.qubit -> event - bool validSrcType = - llvm::isa(srcType); - bool validDstType = llvm::isa(dstType); - if (!validSrcType || !validDstType) - return failure(); - - Value input = op.getInputs()[0]; - - // Find associated events - // Skip over intermediate cast/extract/insert operations to collect events - bool reachedAllocOp = false; - SetVector events; - traceValueWithCallback(input, [&](Value value) -> WalkResult { - auto defOp = value.getDefiningOp(); - if (defOp && - isa(defOp)) { - return WalkResult::advance(); - } - - // collect event and stop tracing this path - if (isa(value.getType())) { - events.insert(value); - return WalkResult::interrupt(); - } - - if (isa(defOp)) { - reachedAllocOp = true; - return WalkResult::interrupt(); - } - - return WalkResult::advance(); - }); - - if (reachedAllocOp && events.empty()) { - auto eventType = rtio::EventType::get(ctx); - Value emptyEvent = rewriter.create(op.getLoc(), eventType); - rewriter.replaceOp(op, emptyEvent); - return success(); - } - - if (events.empty()) { - op.emitError("No events found for cast op"); - llvm::reportFatalInternalError("No events found for cast op"); - } - - // Create a sync event from all collected events - // TODO: check domination, so that we can avoid creating a sync event if events are - // already dominated by one of the events - Value syncEvent = awaitEvents(events.getArrayRef(), rewriter); - rewriter.replaceOp(op, syncEvent); - return success(); - } -}; - -/// Clean up quantum/ion related ops that are not needed after conversion -struct CleanQuantumOpsPattern : public RewritePattern { - CleanQuantumOpsPattern(MLIRContext *ctx) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) - { - } - - LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override - { - Dialect *dialect = op->getDialect(); - if (!dialect || !isa(dialect)) - return failure(); - - if (!op->use_empty()) { - return failure(); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -LogicalResult CleanQuantumOps(func::FuncOp funcOp, MLIRContext *ctx) -{ - RewritePatternSet patterns(ctx); - patterns.add(ctx); - if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - return success(); -} - -LogicalResult CanonicalizeKernelFunction(func::FuncOp funcOp, MLIRContext *ctx) -{ - RewritePatternSet patterns(ctx); - for (auto *dialect : ctx->getLoadedDialects()) { - dialect->getCanonicalizationPatterns(patterns); - } - for (RegisteredOperationName op : ctx->getRegisteredOperations()) { - op.getCanonicalizationPatterns(patterns, ctx); - } - if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - - IRRewriter rewriter(ctx); - DominanceInfo domInfo(funcOp); - eliminateCommonSubExpressions(rewriter, domInfo, funcOp); - - return success(); -} - -LogicalResult ResolveChannelMapping(func::FuncOp funcOp, MLIRContext *ctx) -{ - RewritePatternSet patterns(ctx); - patterns.add(ctx); - if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// JSON to MLIR Attribute Conversion -//===----------------------------------------------------------------------===// - -/// Convert a JSON value to an MLIR Attribute -Attribute jsonToAttribute(MLIRContext *ctx, const llvm::json::Value &json) -{ - if (auto str = json.getAsString()) { - return StringAttr::get(ctx, *str); - } - if (auto num = json.getAsInteger()) { - return IntegerAttr::get(IntegerType::get(ctx, 64), *num); - } - if (auto num = json.getAsNumber()) { - return FloatAttr::get(Float64Type::get(ctx), *num); - } - if (auto b = json.getAsBoolean()) { - return BoolAttr::get(ctx, *b); - } - if (auto arr = json.getAsArray()) { - SmallVector attrs; - for (const auto &elem : *arr) { - attrs.push_back(jsonToAttribute(ctx, elem)); - } - return ArrayAttr::get(ctx, attrs); - } - if (auto *obj = json.getAsObject()) { - SmallVector entries; - for (const auto &kv : *obj) { - StringRef key = kv.first; - entries.emplace_back(StringAttr::get(ctx, key), jsonToAttribute(ctx, kv.second)); - } - // Sort entries by name for DictionaryAttr - llvm::sort(entries, [](const NamedAttribute &lhs, const NamedAttribute &rhs) { - return lhs.getName().getValue() < rhs.getName().getValue(); - }); - return DictionaryAttr::get(ctx, entries); - } - // null - return UnitAttr::get(ctx); -} - -/// Load a JSON file and convert it to an rtio.config attribute -FailureOr loadDeviceDbAsConfig(MLIRContext *ctx, StringRef filePath) -{ - auto fileOrErr = llvm::MemoryBuffer::getFile(filePath); - if (!fileOrErr) { - return failure(); - } - - auto json = llvm::json::parse((*fileOrErr)->getBuffer()); - if (!json) { - llvm::errs() << "Failed to parse JSON: " << llvm::toString(json.takeError()) << "\n"; - return failure(); - } - - auto *obj = json->getAsObject(); - if (!obj) { - llvm::errs() << "Device DB JSON must be an object\n"; - return failure(); - } - - // Convert JSON object to DictionaryAttr - SmallVector entries; - for (const auto &kv : *obj) { - StringRef key = kv.first; - entries.emplace_back(StringAttr::get(ctx, key), jsonToAttribute(ctx, kv.second)); - } - llvm::sort(entries, [](const NamedAttribute &lhs, const NamedAttribute &rhs) { - return lhs.getName().getValue() < rhs.getName().getValue(); - }); - - auto dictAttr = DictionaryAttr::get(ctx, entries); - return rtio::ConfigAttr::get(ctx, dictAttr); -} - -//===----------------------------------------------------------------------===// -// Pass Implementation -//===----------------------------------------------------------------------===// - -#define GEN_PASS_DEF_IONTORTIOPASS -#include "Ion/Transforms/Passes.h.inc" - -struct IonToRTIOPass : public impl::IonToRTIOPassBase { - using impl::IonToRTIOPassBase::IonToRTIOPassBase; - - LogicalResult IonPulseConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, - TypeConverter &typeConverter, IonInfo ionInfo, - DenseMap &qextractToMemrefMap, MLIRContext *ctx) - { - ConversionTarget target(baseTarget); - target.addIllegalOp(); - - RewritePatternSet patterns(ctx); - patterns.add(typeConverter, ctx, ionInfo, qextractToMemrefMap); - if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { - return failure(); - } - return success(); - } - - LogicalResult ParallelProtocolConversion(func::FuncOp funcOp, ConversionTarget &baseTarget, - TypeConverter &typeConverter, MLIRContext *ctx) - { - ConversionTarget target(baseTarget); - target.addIllegalOp(); - - RewritePatternSet patterns(ctx); - patterns.add(typeConverter, ctx); - if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { - return failure(); - } - return success(); - } - - LogicalResult SCFStructuralConversion(func::FuncOp funcOp, ConversionTarget &target, - TypeConverter &typeConverter, MLIRContext *ctx) - { - TypeConverter scfTypeConverter(typeConverter); - scfTypeConverter.addConversion( - [ctx](quantum::QubitType) -> Type { return rtio::EventType::get(ctx); }); - scfTypeConverter.addConversion( - [ctx](quantum::QuregType) -> Type { return rtio::EventType::get(ctx); }); - scfTypeConverter.addConversion( - [ctx](ion::QubitType) -> Type { return rtio::EventType::get(ctx); }); - // Add materialization for quantum/ion -> event - scfTypeConverter.addSourceMaterialization( - [](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1) - return nullptr; - Type inputType = inputs.front().getType(); - if (inputType != resultType) { - return builder.create(loc, resultType, inputs) - .getResult(0); - } - return inputs[0]; - }); - // Add target materialization for event -> quantum/ion - scfTypeConverter.addTargetMaterialization( - [](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1) - return nullptr; - Type inputType = inputs.front().getType(); - if (inputType != resultType) { - return builder.create(loc, resultType, inputs) - .getResult(0); - } - return inputs[0]; - }); - - ConversionTarget scfTarget(getContext()); - scfTarget.addLegalDialect(); - - // Mark SCF ops as illegal only if they use quantum/ion types - scfTarget.addDynamicallyLegalOp([&](scf::ForOp op) { - for (auto arg : op.getRegionIterArgs()) { - Type type = arg.getType(); - if (llvm::isa(type)) { - return false; - } - } - for (auto result : op.getResults()) { - Type type = result.getType(); - if (llvm::isa(type)) { - return false; - } - } - return true; - }); - - scfTarget.addDynamicallyLegalOp([&](scf::IfOp op) { - for (auto result : op.getResults()) { - Type type = result.getType(); - if (llvm::isa(type)) { - return false; - } - } - return true; - }); - - scfTarget.addLegalOp(); - - // restructure SCF Operations - RewritePatternSet scfPatterns(&getContext()); - mlir::scf::populateSCFStructuralTypeConversionsAndLegality(scfTypeConverter, scfPatterns, - scfTarget); - - if (failed(applyPartialConversion(funcOp, scfTarget, std::move(scfPatterns)))) { - return failure(); - } - - return success(); - } - - LogicalResult PropagateEvents(func::FuncOp funcOp, MLIRContext *ctx) - { - RewritePatternSet patterns(ctx); - patterns.add(ctx); - if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - return success(); - } - - SmallVector getIonInfos() - { - SmallVector ionInfos; - getOperation()->walk([&](ion::IonOp ionOp) { ionInfos.emplace_back(IonInfo(ionOp)); }); - return ionInfos; - } - - func::FuncOp createKernelFunction(func::FuncOp qnodeFunc, std::string kernelName, - OpBuilder &builder) - { - MLIRContext *ctx = builder.getContext(); - - auto newQnodeFunc = qnodeFunc.clone(); - newQnodeFunc.setName(kernelName); - auto oldFuncType = qnodeFunc.getFunctionType(); - // create new function type with empty results - auto newFuncType = FunctionType::get(ctx, oldFuncType.getInputs(), {}); - newQnodeFunc.setFunctionType(newFuncType); - - // set public visibility and remove internal linkage for kernel function - newQnodeFunc.setPublic(); - newQnodeFunc->removeAttr("llvm.linkage"); - - // Clear operands from all return ops (make them return nothing) - newQnodeFunc.walk([](func::ReturnOp returnOp) { returnOp.getOperandsMutable().clear(); }); - - return newQnodeFunc; - } - - void initializeMemrefMap(func::FuncOp funcOp, ModuleOp module, - DenseMap &qregToMemrefMap, - DenseMap &qextractToMemrefMap, MLIRContext *ctx) - { - OpBuilder builder(ctx); - - int globalCounter = 0; - funcOp.walk([&](quantum::AllocOp allocOp) { - size_t numQubits = allocOp.getNqubitsAttr().value(); - auto memrefType = - MemRefType::get({static_cast(numQubits)}, builder.getIndexType()); - - // Create a unique symbol name for this global - std::string globalNameStr = "__qubit_map_" + std::to_string(globalCounter++); - StringRef globalName = globalNameStr; - - // Create dense attribute with values [0, 1, 2, ..., numQubits-1] - auto tensorType = - RankedTensorType::get({static_cast(numQubits)}, builder.getIndexType()); - SmallVector values; - // Use IndexType::kInternalStorageBitWidth for index type - unsigned indexWidth = IndexType::kInternalStorageBitWidth; - for (size_t i = 0; i < numQubits; i++) { - values.push_back(APInt(indexWidth, i)); - } - auto denseAttr = DenseIntElementsAttr::get(tensorType, values); - - // Create global memref at module level - builder.setInsertionPointToStart(module.getBody()); - auto globalOp = - memref::GlobalOp::create(builder, allocOp.getLoc(), - builder.getStringAttr(globalName), // sym_name - builder.getStringAttr("private"), // sym_visibility - TypeAttr::get(memrefType), // type - denseAttr, // initial_value - builder.getUnitAttr(), // constant - IntegerAttr()); // alignment - - // Get the global memref in the function - builder.setInsertionPointAfter(allocOp); - Value qubitMap = builder.create(allocOp.getLoc(), memrefType, - globalOp.getSymName()); - - qregToMemrefMap[allocOp.getResult()] = qubitMap; - }); - - funcOp.walk([&](quantum::ExtractOp extractOp) { - traceValueWithCallback( - extractOp.getQreg(), [&](Value value) -> WalkResult { - if (qregToMemrefMap.count(value)) { - builder.setInsertionPointAfter(extractOp); - auto memref = qregToMemrefMap[value]; - - Value memrefLoadValue = nullptr; - if (Value idx = extractOp.getIdx()) { - // idx is an operand (i64), need to cast to index - Value indexValue = builder.create( - extractOp.getLoc(), builder.getIndexType(), idx); - memrefLoadValue = builder.create( - extractOp.getLoc(), memref, ValueRange{indexValue}); - } - else if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { - Value indexValue = builder.create( - extractOp.getLoc(), idxAttr.getInt()); - memrefLoadValue = builder.create( - extractOp.getLoc(), memref, ValueRange{indexValue}); - } - if (memrefLoadValue) { - qextractToMemrefMap[extractOp.getResult()] = memrefLoadValue; - } - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - }); - } - - void runOnOperation() override - { - MLIRContext *ctx = &getContext(); - auto module = cast(getOperation()); - - // Load device_db JSON file and set rtio.config attribute on module - if (!deviceDb.empty()) { - auto configOrErr = loadDeviceDbAsConfig(ctx, deviceDb); - if (failed(configOrErr)) { - module->emitError("Failed to load device database from: ") << deviceDb; - return signalPassFailure(); - } - module->setAttr(rtio::ConfigAttr::getModuleAttrName(), *configOrErr); - } - - // check if there is only one qnode function - func::FuncOp qnodeFunc = nullptr; - int qnodeCounts = 0; - module.walk([&](func::FuncOp funcOp) { - if (funcOp->hasAttr("qnode")) { - qnodeFunc = funcOp; - qnodeCounts++; - } - }); - assert(qnodeCounts == 1 && "only one qnode function is allowed"); - - // collect all ion information for calculating frequency when converting ion.pulse - SmallVector ionInfos = getIonInfos(); - if (ionInfos.empty()) { - getOperation()->emitError("Failed to get ion information"); - return signalPassFailure(); - } - - // currently, we only support one ion information - assert(ionInfos.size() == 1 && "only one ion information is allowed"); - IonInfo &ionInfo = ionInfos.front(); - - // clone qnode function as new kernel function - OpBuilder builder(ctx); - func::FuncOp newQnodeFunc = createKernelFunction(qnodeFunc, kernelName, builder); - module.insert(qnodeFunc, newQnodeFunc); - - // drop one of the pulse from the certain protocol - // the way we handle the dropped pulse will be updated in the future - SmallVector pulsesToErase; - newQnodeFunc.walk([&](ion::PulseOp pulseOp) { - if (pulseOp.getBeamAttr().getTransitionIndex().getInt() == 0) - pulsesToErase.push_back(pulseOp); - }); - for (auto pulseOp : pulsesToErase) - pulseOp.erase(); - - // Construct mapping from qreg alloc and qreg extract to memref - // In the later conversion, we use the mapping to construct the channel for rtio.pulse - DenseMap qregToMemrefMap; - DenseMap qextractToMemrefMap; - initializeMemrefMap(newQnodeFunc, module, qregToMemrefMap, qextractToMemrefMap, ctx); - - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addConversion( - [&](ion::PulseType type) -> Type { return rtio::EventType::get(ctx); }); - - ConversionTarget target(*ctx); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - - // prepare kernel function - if (failed(IonPulseConversion(newQnodeFunc, target, typeConverter, ionInfo, - qextractToMemrefMap, ctx)) || - failed(ParallelProtocolConversion(newQnodeFunc, target, typeConverter, ctx)) || - failed(SCFStructuralConversion(newQnodeFunc, target, typeConverter, ctx)) || - failed(PropagateEvents(newQnodeFunc, ctx)) || - failed(CleanQuantumOps(newQnodeFunc, ctx)) || - failed(ResolveChannelMapping(newQnodeFunc, ctx)) || - failed(CanonicalizeKernelFunction(newQnodeFunc, ctx))) { - newQnodeFunc->emitError("Failed to convert to rtio dialect"); - return signalPassFailure(); - } - - for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { - if (funcOp.getName().str() != newQnodeFunc.getName().str()) { - funcOp.erase(); - } - } - } -}; - -} // namespace ion -} // namespace catalyst From bd1c4c4c1ec1ce4116349c00b7ace09abe84f35f Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:01:38 -0500 Subject: [PATCH 28/51] update channel ID underlying number --- mlir/include/RTIO/IR/RTIODialect.td | 3 ++- mlir/lib/RTIO/IR/RTIODialect.cpp | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index e211140a52..a6208bd6d8 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -18,6 +18,7 @@ include "mlir/IR/OpBase.td" include "mlir/IR/DialectBase.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" //===----------------------------------------------------------------------===// // RTIO Dialect Definition @@ -125,7 +126,7 @@ def RTIOChannelType : RTIO_Type<"Channel", "channel"> { if (!channelId) { return false; } - return channelId.getInt() >= 0; + return channelId.getInt() != ShapedType::kDynamic; } bool isDynamic() const { diff --git a/mlir/lib/RTIO/IR/RTIODialect.cpp b/mlir/lib/RTIO/IR/RTIODialect.cpp index a08a9e377e..f7dc068f06 100644 --- a/mlir/lib/RTIO/IR/RTIODialect.cpp +++ b/mlir/lib/RTIO/IR/RTIODialect.cpp @@ -15,6 +15,7 @@ #include "RTIO/IR/RTIODialect.h" #include "RTIO/IR/RTIOOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" @@ -54,7 +55,7 @@ static ParseResult parseChannelTypeBody(AsmParser &parser, std::string &kind, Ar // After qualifiers, parse comma for channelId if (failed(parser.parseOptionalComma())) { - channelId = parser.getBuilder().getI64IntegerAttr(-1); + channelId = parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); return success(); } } @@ -62,13 +63,13 @@ static ParseResult parseChannelTypeBody(AsmParser &parser, std::string &kind, Ar } else { // No comma at all, no qualifiers and no channelId - channelId = parser.getBuilder().getI64IntegerAttr(-1); + channelId = parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); return success(); } // 3. Parse channelId: `?` or non-negative integer if (succeeded(parser.parseOptionalQuestion())) { - channelId = parser.getBuilder().getI64IntegerAttr(-1); + channelId = parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); return success(); } @@ -99,7 +100,7 @@ static void printChannelTypeBody(AsmPrinter &printer, StringRef kind, ArrayAttr printer << "]"; } - // 3. Print channelId if present (and not default -1) + // 3. Print channelId if present (and not default ShapedType::kDynamic) if (channelId) { int64_t id = channelId.getInt(); if (id >= 0 || (qualifiers && !qualifiers.empty())) { From 6d57469419acb8d292874f02644655d9fb857878 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:07:14 -0500 Subject: [PATCH 29/51] add bracket for if --- mlir/lib/RTIO/IR/RTIODialect.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/RTIO/IR/RTIODialect.cpp b/mlir/lib/RTIO/IR/RTIODialect.cpp index f7dc068f06..c30b010d25 100644 --- a/mlir/lib/RTIO/IR/RTIODialect.cpp +++ b/mlir/lib/RTIO/IR/RTIODialect.cpp @@ -74,12 +74,14 @@ static ParseResult parseChannelTypeBody(AsmParser &parser, std::string &kind, Ar } int64_t id; - if (failed(parser.parseInteger(id))) + if (failed(parser.parseInteger(id))) { return failure(); + } - if (id < 0) + if (id < 0) { return parser.emitError(parser.getCurrentLocation(), "static channel ID must be non-negative"); + } channelId = parser.getBuilder().getI64IntegerAttr(id); return success(); From cf5874e989b3f98b49fad4ab90a1adad6a3310db Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:10:36 -0500 Subject: [PATCH 30/51] fix --- mlir/include/RTIO/IR/RTIODialect.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index a6208bd6d8..396d262861 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -126,7 +126,7 @@ def RTIOChannelType : RTIO_Type<"Channel", "channel"> { if (!channelId) { return false; } - return channelId.getInt() != ShapedType::kDynamic; + return channelId.getInt() != mlir::ShapedType::kDynamic; } bool isDynamic() const { From c4a3afa2f8ae640d26e43a232101ba2487432ca3 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:13:30 -0500 Subject: [PATCH 31/51] add assert for oob --- mlir/include/RTIO/IR/RTIODialect.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index 396d262861..313470db99 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -116,8 +116,8 @@ def RTIOChannelType : RTIO_Type<"Channel", "channel"> { } mlir::Attribute getQualifier(size_t index) const { - if (!getQualifiers() || index >= getQualifiers().size()) - return nullptr; + assert(getQualifiers() && "qualifiers are not present"); + assert(index < getQualifiers().size() && "index out of bounds"); return getQualifiers()[index]; } From 5e136eb769b38d773ee0644b891cee799fd25915 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:26:00 -0500 Subject: [PATCH 32/51] update comment --- mlir/include/RTIO/IR/RTIODialect.td | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/include/RTIO/IR/RTIODialect.td b/mlir/include/RTIO/IR/RTIODialect.td index 313470db99..c9b3dc0aaa 100644 --- a/mlir/include/RTIO/IR/RTIODialect.td +++ b/mlir/include/RTIO/IR/RTIODialect.td @@ -30,11 +30,7 @@ def RTIO_Dialect : Dialect { The RTIO dialect provides operations for precise timing control and hardware signal generation on FPGAs for quantum computing. - This dialect supports two levels of abstraction: - 1. Event-Based IR (high-level): Declarative operations with explicit event dependencies - - // TODO: Do we need the separate Timeline IR for artiq family? - 2. Timeline IR (low-level): Stateful operations with implicit time cursor + It provides declarative operations with explicit event dependencies for hardware control. }]; let name = "rtio"; From b623d0d76d9d6c9b1e78253bbcfd4f6c2047e681 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:36:47 -0500 Subject: [PATCH 33/51] supports importing json to rtio.config --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 89 +++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index b70d2ad598..77765b424e 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBuffer.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -731,6 +734,82 @@ LogicalResult ResolveChannelMapping(func::FuncOp funcOp, MLIRContext *ctx) return success(); } +//===----------------------------------------------------------------------===// +// JSON to MLIR Attribute Conversion +//===----------------------------------------------------------------------===// + +/// Convert a JSON value to an MLIR Attribute +Attribute jsonToAttribute(MLIRContext *ctx, const llvm::json::Value &json) +{ + if (auto str = json.getAsString()) { + return StringAttr::get(ctx, *str); + } + if (auto num = json.getAsInteger()) { + return IntegerAttr::get(IntegerType::get(ctx, 64), *num); + } + if (auto num = json.getAsNumber()) { + return FloatAttr::get(Float64Type::get(ctx), *num); + } + if (auto b = json.getAsBoolean()) { + return BoolAttr::get(ctx, *b); + } + if (auto arr = json.getAsArray()) { + SmallVector attrs; + for (const auto &elem : *arr) { + attrs.push_back(jsonToAttribute(ctx, elem)); + } + return ArrayAttr::get(ctx, attrs); + } + if (auto *obj = json.getAsObject()) { + SmallVector entries; + for (const auto &kv : *obj) { + StringRef key = kv.first; + entries.emplace_back(StringAttr::get(ctx, key), jsonToAttribute(ctx, kv.second)); + } + // Sort entries by name for DictionaryAttr + llvm::sort(entries, [](const NamedAttribute &lhs, const NamedAttribute &rhs) { + return lhs.getName().getValue() < rhs.getName().getValue(); + }); + return DictionaryAttr::get(ctx, entries); + } + // null + return UnitAttr::get(ctx); +} + +/// Load a JSON file and convert it to an rtio.config attribute +FailureOr loadDeviceDbAsConfig(MLIRContext *ctx, StringRef filePath) +{ + auto fileOrErr = llvm::MemoryBuffer::getFile(filePath); + if (!fileOrErr) { + return failure(); + } + + auto json = llvm::json::parse((*fileOrErr)->getBuffer()); + if (!json) { + llvm::errs() << "Failed to parse JSON: " << llvm::toString(json.takeError()) << "\n"; + return failure(); + } + + auto *obj = json->getAsObject(); + if (!obj) { + llvm::errs() << "Device DB JSON must be an object\n"; + return failure(); + } + + // Convert JSON object to DictionaryAttr + SmallVector entries; + for (const auto &kv : *obj) { + StringRef key = kv.first; + entries.emplace_back(StringAttr::get(ctx, key), jsonToAttribute(ctx, kv.second)); + } + llvm::sort(entries, [](const NamedAttribute &lhs, const NamedAttribute &rhs) { + return lhs.getName().getValue() < rhs.getName().getValue(); + }); + + auto dictAttr = DictionaryAttr::get(ctx, entries); + return rtio::ConfigAttr::get(ctx, dictAttr); +} + //===----------------------------------------------------------------------===// // Pass Implementation //===----------------------------------------------------------------------===// @@ -972,6 +1051,16 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { MLIRContext *ctx = &getContext(); auto module = cast(getOperation()); + // Load device_db JSON file and set rtio.config attribute on module + if (!deviceDb.empty()) { + auto configOrErr = loadDeviceDbAsConfig(ctx, deviceDb); + if (failed(configOrErr)) { + module->emitError("Failed to load device database from: ") << deviceDb; + return signalPassFailure(); + } + module->setAttr(rtio::ConfigAttr::getModuleAttrName(), *configOrErr); + } + // check if there is only one qnode function func::FuncOp qnodeFunc = nullptr; int qnodeCounts = 0; From b127e9ecef2db6e786c58863fd97a6f9e0dd73ec Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:45:12 -0500 Subject: [PATCH 34/51] update changelog --- doc/releases/changelog-dev.md | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f4def776b0..dcd9091e42 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -70,7 +70,7 @@

Improvements 🛠

-* A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of +* A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of :func:`~.qjit`. This option saves intermediate IR files after each pass, but only when the IR is actually modified by the pass. [(#2186)](https://github.com/PennyLaneAI/catalyst/pull/2186) @@ -132,9 +132,9 @@ * The `--adjoint-lowering` pass can now handle PPR operations. [(#2227)](https://github.com/PennyLaneAI/catalyst/pull/2227) -* Catalyst now supports Pauli product rotations with arbitrary or dynamic angles in the - QEC dialect. This will allow :class:`qml.PauliRot` with arbitrary or dynamic angles, - angles not known at compile time, to be lowered to the QEC dialect. This is implemented +* Catalyst now supports Pauli product rotations with arbitrary or dynamic angles in the + QEC dialect. This will allow :class:`qml.PauliRot` with arbitrary or dynamic angles, + angles not known at compile time, to be lowered to the QEC dialect. This is implemented as a new `qec.ppr.arbitrary` operation, which takes a Pauli product and an arbitrary or dynamic angle as input. The arbitrary angles are specified as a double in terms of radian. [(#2232)](https://github.com/PennyLaneAI/catalyst/pull/2232) @@ -151,7 +151,7 @@ * The MLIR pipeline ``enforce-runtime-invariants-pipeline`` has been renamed to ``quantum-compilation-pipeline`` and the old ``quantum-compilation-pipeline`` has been renamed to - ``gradient-lowering-pipeline``. Users who referenced these pipeline names directly would need to + ``gradient-lowering-pipeline``. Users who referenced these pipeline names directly would need to update their code to use the new names. [(#2186)](https://github.com/PennyLaneAI/catalyst/pull/2186) @@ -277,7 +277,9 @@ * Decouple the ion dialect from the quantum dialect to support the new RTIO compilation flow. The ion dialect now uses its own `!ion.qubit` type instead of depending on `!quantum.bit`. Conversion between qubits of quantum and ion dialects is handled via unrealized conversion casts. + And we support the compiling from ION dialect to RTIO dilalect. [(#2163)](https://github.com/PennyLaneAI/catalyst/pull/2163) + [(#2204)](https://github.com/PennyLaneAI/catalyst/pull/2204) For an example, quantum qubits are converted to ion qubits as follows: ```mlir @@ -302,15 +304,15 @@ of identities. [(#2192)](https://github.com/PennyLaneAI/catalyst/pull/2192) - * Renamed `annotate-function` pass to `annotate-invalid-gradient-functions` and move it to the + * Renamed `annotate-function` pass to `annotate-invalid-gradient-functions` and move it to the gradient dialect and the `lower-gradients` compilation stage. [(#2241)](https://github.com/PennyLaneAI/catalyst/pull/2241) * Added support for PPRs to the :func:`~.passes.merge_rotations` pass to merge PPRs with equivalent angles, and cancelling of PPRs with opposite angles, or angles - that sum to identity. Also supports conditions on PPRs, merging when conditions are + that sum to identity. Also supports conditions on PPRs, merging when conditions are identical and not merging otherwise. - [(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224) + [(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224) [(#2245)](https://github.com/PennyLaneAI/catalyst/pull/2245) [(#2254)](https://github.com/PennyLaneAI/catalyst/pull/2254) From 1ce02ba413cfca7157c9b083e6519211a6e99dfa Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:54:40 -0500 Subject: [PATCH 35/51] add deviceDB option --- mlir/include/Ion/Transforms/Passes.td | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/include/Ion/Transforms/Passes.td b/mlir/include/Ion/Transforms/Passes.td index ef14d16b23..742d654656 100644 --- a/mlir/include/Ion/Transforms/Passes.td +++ b/mlir/include/Ion/Transforms/Passes.td @@ -70,6 +70,9 @@ def IonToRTIOPass : Pass<"convert-ion-to-rtio", "mlir::ModuleOp"> { Option<"kernelName", "kernel-name", "std::string", /*default=*/"\"__kernel__\"", "Name of the generated kernel function"> + Option<"deviceDb", "device_db", + "std::string", /*default=*/"\"\"", + "Path to the device database JSON file">, ]; } From a379528b4638478c051ff244be0176e254163cb8 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 15:59:25 -0500 Subject: [PATCH 36/51] add missing comma --- mlir/include/Ion/Transforms/Passes.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/Ion/Transforms/Passes.td b/mlir/include/Ion/Transforms/Passes.td index 742d654656..ea76ee780d 100644 --- a/mlir/include/Ion/Transforms/Passes.td +++ b/mlir/include/Ion/Transforms/Passes.td @@ -69,7 +69,7 @@ def IonToRTIOPass : Pass<"convert-ion-to-rtio", "mlir::ModuleOp"> { let options = [ Option<"kernelName", "kernel-name", "std::string", /*default=*/"\"__kernel__\"", - "Name of the generated kernel function"> + "Name of the generated kernel function">, Option<"deviceDb", "device_db", "std::string", /*default=*/"\"\"", "Path to the device database JSON file">, From 4f851ef3c2554602c1acc61c8dc4932b1834e255 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 2 Dec 2025 16:18:46 -0500 Subject: [PATCH 37/51] add missing include --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 77765b424e..7f78a166ff 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "llvm/Support/JSON.h" #include "llvm/Support/MemoryBuffer.h" From 2ff88f9ee67edc1f61076dfe16fbce736dfea45b Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 3 Dec 2025 14:08:24 -0500 Subject: [PATCH 38/51] Update mlir/lib/RTIO/IR/RTIODialect.cpp Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/lib/RTIO/IR/RTIODialect.cpp | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/mlir/lib/RTIO/IR/RTIODialect.cpp b/mlir/lib/RTIO/IR/RTIODialect.cpp index c30b010d25..283fb41f30 100644 --- a/mlir/lib/RTIO/IR/RTIODialect.cpp +++ b/mlir/lib/RTIO/IR/RTIODialect.cpp @@ -105,21 +105,16 @@ static void printChannelTypeBody(AsmPrinter &printer, StringRef kind, ArrayAttr // 3. Print channelId if present (and not default ShapedType::kDynamic) if (channelId) { int64_t id = channelId.getInt(); - if (id >= 0 || (qualifiers && !qualifiers.empty())) { - printer << ", "; - if (id < 0) { - printer << "?"; - } - else { - printer << id; - } + printer << ", "; + if (id >= 0) { + printer << id; } - } - else { - if (qualifiers && !qualifiers.empty()) { - printer << ", "; + else { + printer << "?"; } - printer << "?"; + } + else if (qualifiers && !qualifiers.empty()) { + printer << ", ?"; } } From 2486c74fd8a03be1eeaaeae42a993390ec2b1cc1 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 3 Dec 2025 14:08:40 -0500 Subject: [PATCH 39/51] Update mlir/include/RTIO/IR/RTIOOps.td Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- mlir/include/RTIO/IR/RTIOOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/RTIO/IR/RTIOOps.td b/mlir/include/RTIO/IR/RTIOOps.td index 91b93a3743..e826d56714 100644 --- a/mlir/include/RTIO/IR/RTIOOps.td +++ b/mlir/include/RTIO/IR/RTIOOps.td @@ -61,7 +61,7 @@ def RTIOQubitToChannelOp : RTIO_Op<"qubit_to_channel"> { Example: ```mlir // Map ion qubit to dds channel with dynamic channel id - %ch = rtio.qubit_tos_channel %ion_qubit : !ion.qubit -> !rtio.channel<"dds", ?> + %ch = rtio.qubit_to_channel %ion_qubit : !ion.qubit -> !rtio.channel<"dds", ?> ``` During the channel resolution stage, this operation will be replaced by a From e6d0c7c65829d13dc8e7a2aca40909bc6484eaa5 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 3 Dec 2025 14:22:51 -0500 Subject: [PATCH 40/51] Add filecheck line --- mlir/test/RTIO/VerifierTest.mlir | 39 +++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/mlir/test/RTIO/VerifierTest.mlir b/mlir/test/RTIO/VerifierTest.mlir index 9b4a39549c..fdb01b3a88 100644 --- a/mlir/test/RTIO/VerifierTest.mlir +++ b/mlir/test/RTIO/VerifierTest.mlir @@ -12,23 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -// RUN: quantum-opt %s --split-input-file --verify-diagnostics +// RUN: quantum-opt %s --split-input-file --verify-diagnostics | FileCheck %s //////////////////////// // Channel Operations // //////////////////////// +// CHECK-LABEL: func.func @channel_good() func.func @channel_good() { // Smoke test for valid channel operations + // CHECK: rtio.channel : !rtio.channel<"dds", 0> %ch0 = rtio.channel : !rtio.channel<"dds", 0> + // CHECK: rtio.channel : !rtio.channel<"dds", [0 : i64], 0> %ch0_q1 = rtio.channel : !rtio.channel<"dds", [0], 0> + // CHECK: rtio.channel : !rtio.channel<"dds", [0 : i64, "t0"], 0> %ch0_q2 = rtio.channel : !rtio.channel<"dds", [0, "t0"], 0> + // CHECK: rtio.channel : !rtio.channel<"dds", ?> %ch_explict_dyn = rtio.channel : !rtio.channel<"dds", ?> + // CHECK: rtio.channel : !rtio.channel<"dds", [0 : i64], ?> %ch_explict_dyn_q1 = rtio.channel : !rtio.channel<"dds", [0], ?> + // CHECK: rtio.channel : !rtio.channel<"dds", [0 : i64, "t0"], ?> %ch_explict_dyn_q2 = rtio.channel : !rtio.channel<"dds", [0, "t0"], ?> + // Implicit dynamic channel ID (omitted = dynamic) + // CHECK: rtio.channel : !rtio.channel<"dds", ?> %ch0_implicit_dyn = rtio.channel : !rtio.channel<"dds"> + // CHECK: rtio.channel : !rtio.channel<"dds", [0 : i64], ?> %ch_implicit_dyn_q1 = rtio.channel : !rtio.channel<"dds", [0]> + // CHECK: rtio.channel : !rtio.channel<"dds", [0 : i64, "t0"], ?> %ch_implicit_dyn_q2 = rtio.channel : !rtio.channel<"dds", [0, "t0"]> return } @@ -42,8 +53,10 @@ func.func @channel_negative_id() { // ----- +// CHECK-LABEL: func.func @qubit_to_channel_good func.func @qubit_to_channel_good(%qubit: !ion.qubit) { // Smoke test for qubit_to_channel + // CHECK: rtio.qubit_to_channel %{{.*}} : !ion.qubit -> !rtio.channel<"dds", ["transition_0"], ?> %ch0 = rtio.qubit_to_channel %qubit : !ion.qubit -> !rtio.channel<"dds", ["transition_0"], ?> return } @@ -54,9 +67,12 @@ func.func @qubit_to_channel_good(%qubit: !ion.qubit) { // Event-Based Operations // //////////////////////////// +// CHECK-LABEL: func.func @pulse_basic func.func @pulse_basic(%dur: f64, %freq: f64, %phase: f64) { + // CHECK: %[[CH:.*]] = rtio.channel : !rtio.channel<"dds", 0> %ch0 = rtio.channel : !rtio.channel<"dds", 0> + // CHECK: rtio.pulse %[[CH]] duration(%{{.*}}) frequency(%{{.*}}) phase(%{{.*}}) : <"dds", 0> -> !rtio.event %event = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) : !rtio.channel<"dds", 0> -> !rtio.event @@ -65,12 +81,16 @@ func.func @pulse_basic(%dur: f64, %freq: f64, %phase: f64) { // ----- +// CHECK-LABEL: func.func @pulse_with_wait func.func @pulse_with_wait(%dur: f64, %freq: f64, %phase: f64) { + // CHECK: %[[CH:.*]] = rtio.channel : !rtio.channel<"dds", 0> %ch0 = rtio.channel : !rtio.channel<"dds", 0> + // CHECK: %[[E0:.*]] = rtio.pulse %[[CH]] duration(%{{.*}}) frequency(%{{.*}}) phase(%{{.*}}) : <"dds", 0> -> !rtio.event %event0 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) : !rtio.channel<"dds", 0> -> !rtio.event + // CHECK: rtio.pulse %[[CH]] duration(%{{.*}}) frequency(%{{.*}}) phase(%{{.*}}) wait(%[[E0]]) : <"dds", 0> -> !rtio.event %event1 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) wait(%event0) : !rtio.channel<"dds", 0> -> !rtio.event @@ -80,28 +100,39 @@ func.func @pulse_with_wait(%dur: f64, %freq: f64, %phase: f64) { // ----- +// CHECK-LABEL: func.func @sync_basic func.func @sync_basic(%dur: f64, %freq: f64, %phase: f64) { + // CHECK: %[[CH0:.*]] = rtio.channel : !rtio.channel<"dds", 0> %ch0 = rtio.channel : !rtio.channel<"dds", 0> + // CHECK: %[[CH1:.*]] = rtio.channel : !rtio.channel<"dds", 1> %ch1 = rtio.channel : !rtio.channel<"dds", 1> + // CHECK: %[[CH2:.*]] = rtio.channel : !rtio.channel<"dds", 2> %ch2 = rtio.channel : !rtio.channel<"dds", 2> + // CHECK: %[[CH3:.*]] = rtio.channel : !rtio.channel<"dds", 3> %ch3 = rtio.channel : !rtio.channel<"dds", 3> // sync single event + // CHECK: %[[E0:.*]] = rtio.pulse %[[CH0]] duration(%{{.*}}) frequency(%{{.*}}) phase(%{{.*}}) : <"dds", 0> -> !rtio.event %event0 = rtio.pulse %ch0 duration(%dur) frequency(%freq) phase(%phase) : !rtio.channel<"dds", 0> -> !rtio.event + // CHECK: %[[E1:.*]] = rtio.pulse %[[CH1]] duration(%{{.*}}) frequency(%{{.*}}) phase(%{{.*}}) : <"dds", 1> -> !rtio.event %event1 = rtio.pulse %ch1 duration(%dur) frequency(%freq) phase(%phase) : !rtio.channel<"dds", 1> -> !rtio.event // sync multiple events + // CHECK: %[[SYNC1:.*]] = rtio.sync %[[E0]], %[[E1]] : !rtio.event %sync1 = rtio.sync %event0, %event1 : !rtio.event + // CHECK: %[[E2:.*]] = rtio.pulse %[[CH2]] duration(%{{.*}}) frequency(%{{.*}}) phase(%{{.*}}) : <"dds", 2> -> !rtio.event %event2 = rtio.pulse %ch2 duration(%dur) frequency(%freq) phase(%phase) : !rtio.channel<"dds", 2> -> !rtio.event + // CHECK: %[[E3:.*]] = rtio.pulse %[[CH3]] duration(%{{.*}}) frequency(%{{.*}}) phase(%{{.*}}) : <"dds", 3> -> !rtio.event %event3 = rtio.pulse %ch3 duration(%dur) frequency(%freq) phase(%phase) : !rtio.channel<"dds", 3> -> !rtio.event + // CHECK: rtio.sync %[[SYNC1]], %[[E2]], %[[E3]] : !rtio.event %sync2 = rtio.sync %sync1, %event2, %event3 : !rtio.event return @@ -118,13 +149,16 @@ func.func @sync_no_events() { // ----- +// CHECK-LABEL: func.func @empty_good() func.func @empty_good() { + // CHECK: rtio.empty : !rtio.event %empty = rtio.empty : !rtio.event return } // ----- +// CHECK: module @config_test attributes {rtio.config = #rtio.config<{config1 = 1 : i32, config2 = "test", nested = {a = 0 : i32, b = "test"}}>} module @config_test attributes { rtio.config = #rtio.config<{ config1 = 1 : i32, @@ -132,6 +166,7 @@ module @config_test attributes { nested = {a = 0 : i32, b = "test"} }> } { + // CHECK: func.func @kernel() func.func @kernel() { return } @@ -139,9 +174,11 @@ module @config_test attributes { // ----- +// CHECK: module @empty_config attributes {rtio.config = #rtio.config<{}>} module @empty_config attributes { rtio.config = #rtio.config<{}> } { + // CHECK: func.func @kernel() func.func @kernel() { return } From 3937ecf6da31d07e0227a4b842d6b88eec2606f0 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 3 Dec 2025 16:13:29 -0500 Subject: [PATCH 41/51] revert --- doc/releases/changelog-dev.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index eda413fd22..a39b93344e 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -70,13 +70,6 @@

Improvements 🛠

-* Catalyst can now use the new `pass_name` property of pennylane transform objects. Passes can now - be created using `qml.transform(pass_name=pass_name)` instead of `PassPipelineWrapper`. - [(#2149](https://github.com/PennyLaneAI/catalyst/pull/2149) - -* An error is now raised if a transform is applied inside a QNode when program capture is enabled. - [(#2256)](https://github.com/PennyLaneAI/catalyst/pull/2256) - * A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of :func:`~.qjit`. This option saves intermediate IR files after each pass, but only when the IR is actually modified by the pass. From 5559870038af8d16c6a1c98e686a336d6d841bd9 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 3 Dec 2025 16:14:13 -0500 Subject: [PATCH 42/51] merge --- doc/releases/changelog-dev.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index a39b93344e..eda413fd22 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -70,6 +70,13 @@

Improvements 🛠

+* Catalyst can now use the new `pass_name` property of pennylane transform objects. Passes can now + be created using `qml.transform(pass_name=pass_name)` instead of `PassPipelineWrapper`. + [(#2149](https://github.com/PennyLaneAI/catalyst/pull/2149) + +* An error is now raised if a transform is applied inside a QNode when program capture is enabled. + [(#2256)](https://github.com/PennyLaneAI/catalyst/pull/2256) + * A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of :func:`~.qjit`. This option saves intermediate IR files after each pass, but only when the IR is actually modified by the pass. From e92d58ed99aa0c44effb801ccbec51407b93d7fa Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 3 Dec 2025 16:19:03 -0500 Subject: [PATCH 43/51] Add changelog --- doc/releases/changelog-dev.md | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index eba451c271..f97ea81ee4 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -2,6 +2,9 @@

New features since last release

+* RTIO dialect is added to bypass the compilation flow from OpenAPL to ARTIQ’s LLVM IR. It is introduced to bridge the gap between ION dialect and ARTIQ’s LLVM IR. The design philosophy of RTIO dialect is primarily event-based. Every operation is asynchronous; sync behaviour occurs only via `rtio.sync` or `wait operand` in event operation. + [(#2185)](https://github.com/PennyLaneAI/catalyst/pull/2185) + * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) @@ -77,7 +80,7 @@ * An error is now raised if a transform is applied inside a QNode when program capture is enabled. [(#2256)](https://github.com/PennyLaneAI/catalyst/pull/2256) -* A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of +* A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of :func:`~.qjit`. This option saves intermediate IR files after each pass, but only when the IR is actually modified by the pass. [(#2186)](https://github.com/PennyLaneAI/catalyst/pull/2186) @@ -139,9 +142,9 @@ * The `--adjoint-lowering` pass can now handle PPR operations. [(#2227)](https://github.com/PennyLaneAI/catalyst/pull/2227) -* Catalyst now supports Pauli product rotations with arbitrary or dynamic angles in the - QEC dialect. This will allow :class:`qml.PauliRot` with arbitrary or dynamic angles, - angles not known at compile time, to be lowered to the QEC dialect. This is implemented +* Catalyst now supports Pauli product rotations with arbitrary or dynamic angles in the + QEC dialect. This will allow :class:`qml.PauliRot` with arbitrary or dynamic angles, + angles not known at compile time, to be lowered to the QEC dialect. This is implemented as a new `qec.ppr.arbitrary` operation, which takes a Pauli product and an arbitrary or dynamic angle as input. The arbitrary angles are specified as a double in terms of radian. [(#2232)](https://github.com/PennyLaneAI/catalyst/pull/2232) @@ -158,7 +161,7 @@ * The MLIR pipeline ``enforce-runtime-invariants-pipeline`` has been renamed to ``quantum-compilation-pipeline`` and the old ``quantum-compilation-pipeline`` has been renamed to - ``gradient-lowering-pipeline``. Users who referenced these pipeline names directly would need to + ``gradient-lowering-pipeline``. Users who referenced these pipeline names directly would need to update their code to use the new names. [(#2186)](https://github.com/PennyLaneAI/catalyst/pull/2186) @@ -309,24 +312,24 @@ of identities. [(#2192)](https://github.com/PennyLaneAI/catalyst/pull/2192) - * Renamed `annotate-function` pass to `annotate-invalid-gradient-functions` and move it to the + * Renamed `annotate-function` pass to `annotate-invalid-gradient-functions` and move it to the gradient dialect and the `lower-gradients` compilation stage. [(#2241)](https://github.com/PennyLaneAI/catalyst/pull/2241) * Added support for PPRs to the :func:`~.passes.merge_rotations` pass to merge PPRs with equivalent angles, and cancelling of PPRs with opposite angles, or angles - that sum to identity. Also supports conditions on PPRs, merging when conditions are + that sum to identity. Also supports conditions on PPRs, merging when conditions are identical and not merging otherwise. - [(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224) + [(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224) [(#2245)](https://github.com/PennyLaneAI/catalyst/pull/2245) [(#2254)](https://github.com/PennyLaneAI/catalyst/pull/2254) * Refactor QEC tablegen files to separate QEC operations into a new `QECOp.td` file - [(#2253](https://github.com/PennyLaneAI/catalyst/pull/2253) + [(#2253](https://github.com/PennyLaneAI/catalyst/pull/2253) - * Removed the `getRotationKind` and `setRotationKind` methods from + * Removed the `getRotationKind` and `setRotationKind` methods from the QEC interface `QECOpInterface` to simplify the interface. [(#2250)](https://github.com/PennyLaneAI/catalyst/pull/2250) From f77a12b859136610d7ecf0c66d8f6c0e237ed9b6 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 3 Dec 2025 16:20:40 -0500 Subject: [PATCH 44/51] update changelog --- doc/releases/changelog-dev.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 348a6baf21..6a4deded7f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -2,8 +2,9 @@

New features since last release

-* RTIO dialect is added to bypass the compilation flow from OpenAPL to ARTIQ’s LLVM IR. It is introduced to bridge the gap between ION dialect and ARTIQ’s LLVM IR. The design philosophy of RTIO dialect is primarily event-based. Every operation is asynchronous; sync behaviour occurs only via `rtio.sync` or `wait operand` in event operation. +* RTIO dialect is added to bypass the compilation flow from OpenAPL to ARTIQ’s LLVM IR. It is introduced to bridge the gap between ION dialect and ARTIQ’s LLVM IR. The design philosophy of RTIO dialect is primarily event-based. Every operation is asynchronous; sync behaviour occurs only via `rtio.sync` or `wait operand` in event operation. And we now support the compiling from ION dialect to RTIO dilalect. [(#2185)](https://github.com/PennyLaneAI/catalyst/pull/2185) + [(#2204)](https://github.com/PennyLaneAI/catalyst/pull/2204) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) @@ -287,9 +288,7 @@ * Decouple the ion dialect from the quantum dialect to support the new RTIO compilation flow. The ion dialect now uses its own `!ion.qubit` type instead of depending on `!quantum.bit`. Conversion between qubits of quantum and ion dialects is handled via unrealized conversion casts. - And we support the compiling from ION dialect to RTIO dilalect. [(#2163)](https://github.com/PennyLaneAI/catalyst/pull/2163) - [(#2204)](https://github.com/PennyLaneAI/catalyst/pull/2204) For an example, quantum qubits are converted to ion qubits as follows: ```mlir From d0eba303215db95d16bfa084408bd7e94902c2c4 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 4 Dec 2025 16:58:45 -0500 Subject: [PATCH 45/51] udpate --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 74 +++++++++++-------------- 1 file changed, 32 insertions(+), 42 deletions(-) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 7f78a166ff..7fe5748c62 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -706,36 +706,6 @@ LogicalResult CleanQuantumOps(func::FuncOp funcOp, MLIRContext *ctx) return success(); } -LogicalResult CanonicalizeKernelFunction(func::FuncOp funcOp, MLIRContext *ctx) -{ - RewritePatternSet patterns(ctx); - for (auto *dialect : ctx->getLoadedDialects()) { - dialect->getCanonicalizationPatterns(patterns); - } - for (RegisteredOperationName op : ctx->getRegisteredOperations()) { - op.getCanonicalizationPatterns(patterns, ctx); - } - if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - - IRRewriter rewriter(ctx); - DominanceInfo domInfo(funcOp); - eliminateCommonSubExpressions(rewriter, domInfo, funcOp); - - return success(); -} - -LogicalResult ResolveChannelMapping(func::FuncOp funcOp, MLIRContext *ctx) -{ - RewritePatternSet patterns(ctx); - patterns.add(ctx); - if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - return success(); -} - //===----------------------------------------------------------------------===// // JSON to MLIR Attribute Conversion //===----------------------------------------------------------------------===// @@ -932,10 +902,17 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { return success(); } - LogicalResult PropagateEvents(func::FuncOp funcOp, MLIRContext *ctx) + LogicalResult FinalizeKernelFunction(func::FuncOp funcOp, MLIRContext *ctx) { RewritePatternSet patterns(ctx); + for (auto *dialect : ctx->getLoadedDialects()) { + dialect->getCanonicalizationPatterns(patterns); + } + for (RegisteredOperationName op : ctx->getRegisteredOperations()) { + op.getCanonicalizationPatterns(patterns, ctx); + } patterns.add(ctx); + patterns.add(ctx); if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return failure(); } @@ -978,6 +955,8 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { OpBuilder builder(ctx); int globalCounter = 0; + + // create a global memref for each quantum.alloc op funcOp.walk([&](quantum::AllocOp allocOp) { size_t numQubits = allocOp.getNqubitsAttr().value(); auto memrefType = @@ -1048,6 +1027,20 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { }); } + // In ARTIQ's compilation flow, we need to drop the pulse with transition 0 from the protocol + void dropOnePulseFromProtocol(func::FuncOp funcOp) + { + SmallVector pulsesToErase; + funcOp.walk([&](ion::PulseOp pulseOp) { + if (pulseOp.getBeamAttr().getTransitionIndex().getInt() == 0) { + pulsesToErase.push_back(pulseOp); + } + }); + for (auto pulseOp : pulsesToErase) { + pulseOp.erase(); + } + } + void runOnOperation() override { MLIRContext *ctx = &getContext(); @@ -1092,13 +1085,7 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { // drop one of the pulse from the certain protocol // the way we handle the dropped pulse will be updated in the future - SmallVector pulsesToErase; - newQnodeFunc.walk([&](ion::PulseOp pulseOp) { - if (pulseOp.getBeamAttr().getTransitionIndex().getInt() == 0) - pulsesToErase.push_back(pulseOp); - }); - for (auto pulseOp : pulsesToErase) - pulseOp.erase(); + dropOnePulseFromProtocol(newQnodeFunc); // Construct mapping from qreg alloc and qreg extract to memref // In the later conversion, we use the mapping to construct the channel for rtio.pulse @@ -1119,14 +1106,17 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { qextractToMemrefMap, ctx)) || failed(ParallelProtocolConversion(newQnodeFunc, target, typeConverter, ctx)) || failed(SCFStructuralConversion(newQnodeFunc, target, typeConverter, ctx)) || - failed(PropagateEvents(newQnodeFunc, ctx)) || - failed(CleanQuantumOps(newQnodeFunc, ctx)) || - failed(ResolveChannelMapping(newQnodeFunc, ctx)) || - failed(CanonicalizeKernelFunction(newQnodeFunc, ctx))) { + failed(FinalizeKernelFunction(newQnodeFunc, ctx))) { newQnodeFunc->emitError("Failed to convert to rtio dialect"); return signalPassFailure(); } + if (failed(CleanQuantumOps(newQnodeFunc, ctx))) { + newQnodeFunc->emitError("Failed to clean quantum ops"); + return signalPassFailure(); + } + + // remove other unused functions, only keep the kernel function for (auto funcOp : llvm::make_early_inc_range(module.getOps())) { if (funcOp.getName().str() != newQnodeFunc.getName().str()) { funcOp.erase(); From 983733420eccd3808f36aa0268af3f2bb698140d Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 4 Dec 2025 17:12:08 -0500 Subject: [PATCH 46/51] update --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 7fe5748c62..7ce57fc252 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -916,6 +916,11 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return failure(); } + + IRRewriter rewriter(ctx); + DominanceInfo domInfo(funcOp); + eliminateCommonSubExpressions(rewriter, domInfo, funcOp); + return success(); } From 6314072373fc36d7b3c301aada25130be4d76fdd Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Fri, 5 Dec 2025 11:07:41 -0500 Subject: [PATCH 47/51] add test --- mlir/lib/Ion/Transforms/ion-to-rtio.cpp | 3 +- mlir/test/Ion/IonToRTIO.mlir | 319 ++++++++++++++++++++++++ 2 files changed, 321 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Ion/IonToRTIO.mlir diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 7ce57fc252..7943bcf223 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -26,13 +26,13 @@ #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/CSE.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "Ion/IR/IonDialect.h" #include "Ion/IR/IonOps.h" -#include "Ion/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/IR/QuantumOps.h" #include "RTIO/IR/RTIODialect.h" @@ -786,6 +786,7 @@ FailureOr loadDeviceDbAsConfig(MLIRContext *ctx, StringRef fil // Pass Implementation //===----------------------------------------------------------------------===// +#define GEN_PASS_DECL_IONTORTIOPASS #define GEN_PASS_DEF_IONTORTIOPASS #include "Ion/Transforms/Passes.h.inc" diff --git a/mlir/test/Ion/IonToRTIO.mlir b/mlir/test/Ion/IonToRTIO.mlir new file mode 100644 index 0000000000..e7bc5ccb1e --- /dev/null +++ b/mlir/test/Ion/IonToRTIO.mlir @@ -0,0 +1,319 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt %s --convert-ion-to-rtio --split-input-file -verify-diagnostics | FileCheck %s + +// RX(1) + +// CHECK: memref.global "private" constant @__qubit_map_0 : memref<2xindex> = dense<[0, 1]> +// CHECK-LABEL: func.func @__kernel__() +// CHECK-SAME: attributes {diff_method = "parameter-shift", qnode} +module @circuit { + func.func public @circuit_0() -> (memref<4xi64>, memref<4xi64>) attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage, qnode} { + %0 = ion.ion {charge = -1.000000e+00 : f64, levels = [#ion.level