Skip to content

Commit 60304f3

Browse files
bcdonovanreza-j
andauthored
Add LabelPlayOpDurationsPass (#231)
This PR adds a LabelPlayOpDurationPass which labels all `pulse.play` operations with the duration of the waveform that the operation plays. This pass is to be used to provide duration information which is assumed to be present in the `SchedulePortPass`. --------- Co-authored-by: reza-j <[email protected]>
1 parent 79f55b4 commit 60304f3

File tree

6 files changed

+136
-0
lines changed

6 files changed

+136
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===- LabelPlayOpDurations.cpp - Label PlayOps with Durations --*- C++ -*-===//
2+
//
3+
// (C) Copyright IBM 2024.
4+
//
5+
// This code is part of Qiskit.
6+
//
7+
// This code is licensed under the Apache License, Version 2.0 with LLVM
8+
// Exceptions. You may obtain a copy of this license in the LICENSE.txt
9+
// file in the root directory of this source tree.
10+
//
11+
// Any modifications or derivative works of this code must retain this
12+
// copyright notice, and modified files need to carry a notice indicating
13+
// that they have been altered from the originals.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
///
17+
/// This file defines the pass for labeling pulse.play operations with the
18+
/// duration of the waveform being played.
19+
//===----------------------------------------------------------------------===//
20+
21+
#ifndef PULSE_LABEL_PLAY_DURATION_H
22+
#define PULSE_LABEL_PLAY_DURATION_H
23+
24+
#include "mlir/IR/BuiltinOps.h"
25+
#include "mlir/Pass/Pass.h"
26+
27+
namespace mlir::pulse {
28+
29+
class LabelPlayOpDurationsPass
30+
: public PassWrapper<LabelPlayOpDurationsPass, OperationPass<ModuleOp>> {
31+
public:
32+
void runOnOperation() override;
33+
34+
llvm::StringRef getArgument() const override;
35+
llvm::StringRef getDescription() const override;
36+
llvm::StringRef getName() const override;
37+
};
38+
} // namespace mlir::pulse
39+
40+
#endif // PULSE_LABEL_PLAY_DURATION_H

include/Dialect/Pulse/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "Conversion/QUIRToPulse/LoadPulseCals.h"
2525
#include "Conversion/QUIRToPulse/QUIRToPulse.h"
2626
#include "Dialect/Pulse/Transforms/ClassicalOnlyDetection.h"
27+
#include "Dialect/Pulse/Transforms/LabelPlayOpDurations.h"
2728
#include "Dialect/Pulse/Transforms/MergeDelays.h"
2829
#include "Dialect/Pulse/Transforms/RemoveUnusedArguments.h"
2930
#include "Dialect/Pulse/Transforms/SchedulePort.h"

lib/Dialect/Pulse/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
add_mlir_dialect_library(MLIRPulseTransforms
1414
ClassicalOnlyDetection.cpp
1515
InlineRegion.cpp
16+
LabelPlayOpDurations.cpp
1617
MergeDelays.cpp
1718
Passes.cpp
1819
RemoveUnusedArguments.cpp
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===- LabelPlayOpDurations.cpp - Label PlayOps with Durations --*- C++ -*-===//
2+
//
3+
// (C) Copyright IBM 2024.
4+
//
5+
// This code is part of Qiskit.
6+
//
7+
// This code is licensed under the Apache License, Version 2.0 with LLVM
8+
// Exceptions. You may obtain a copy of this license in the LICENSE.txt
9+
// file in the root directory of this source tree.
10+
//
11+
// Any modifications or derivative works of this code must retain this
12+
// copyright notice, and modified files need to carry a notice indicating
13+
// that they have been altered from the originals.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
///
17+
/// This file implements the pass for labeling pulse.play operations with the
18+
/// duration of the waveform being played.
19+
///
20+
//===----------------------------------------------------------------------===//
21+
22+
#include "Dialect/Pulse/Transforms/LabelPlayOpDurations.h"
23+
#include "Dialect/Pulse/IR/PulseInterfaces.h"
24+
#include "Dialect/Pulse/IR/PulseOps.h"
25+
26+
#include "mlir/IR/Value.h"
27+
#include "mlir/Support/LLVM.h"
28+
29+
#include "llvm/ADT/StringRef.h"
30+
31+
#include <cstdint>
32+
#include <string>
33+
#include <unordered_map>
34+
#include <vector>
35+
36+
using namespace mlir;
37+
using namespace mlir::pulse;
38+
39+
void LabelPlayOpDurationsPass::runOnOperation() {
40+
41+
// all PlayOps are assumed to be inside of a pulse.sequence
42+
// pass builds a mapping of sequence name , argument number to duration
43+
// for all play operations using call_sequences
44+
//
45+
// pass then searches for all play operations and assigns the durations using
46+
// the mapping
47+
48+
Operation *module = getOperation();
49+
50+
std::unordered_map<std::string, std::vector<uint64_t>> argumentToDuration;
51+
52+
module->walk([&](CallSequenceOp callSequenceOp) {
53+
auto callee = callSequenceOp.getCallee().str();
54+
55+
for (const auto &operand : callSequenceOp->getOperands()) {
56+
57+
uint64_t duration = 0;
58+
auto *defOp = operand.getDefiningOp();
59+
if (defOp)
60+
if (auto waveformOp = dyn_cast<Waveform_CreateOp>(defOp))
61+
duration = waveformOp.getDuration(nullptr /*callSequenceOp*/).get();
62+
63+
argumentToDuration[callee].push_back(duration);
64+
}
65+
});
66+
67+
module->walk([&](PlayOp playOp) {
68+
auto sequenceOp = playOp->getParentOfType<mlir::pulse::SequenceOp>();
69+
auto sequenceStr = sequenceOp.getSymName().str();
70+
auto wfArgNumber = playOp.getWfr().dyn_cast<BlockArgument>().getArgNumber();
71+
auto duration = argumentToDuration[sequenceStr][wfArgNumber];
72+
mlir::pulse::PulseOpSchedulingInterface::setDuration(playOp, duration);
73+
});
74+
75+
} // runOnOperation
76+
77+
llvm::StringRef LabelPlayOpDurationsPass::getArgument() const {
78+
return "pulse-label-play-op-duration";
79+
}
80+
81+
llvm::StringRef LabelPlayOpDurationsPass::getDescription() const {
82+
return "Label PlayOps with duration attributes";
83+
}
84+
85+
llvm::StringRef LabelPlayOpDurationsPass::getName() const {
86+
return "Label PlayOp Durations Pass";
87+
}

lib/Dialect/Pulse/Transforms/Passes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "Conversion/QUIRToPulse/QUIRToPulse.h"
2525

2626
#include "Dialect/Pulse/Transforms/ClassicalOnlyDetection.h"
27+
#include "Dialect/Pulse/Transforms/LabelPlayOpDurations.h"
2728
#include "Dialect/Pulse/Transforms/MergeDelays.h"
2829
#include "Dialect/Pulse/Transforms/RemoveUnusedArguments.h"
2930
#include "Dialect/Pulse/Transforms/SchedulePort.h"
@@ -37,6 +38,7 @@ namespace mlir::pulse {
3738
void pulsePassPipelineBuilder(OpPassManager &pm) {}
3839

3940
void registerPulsePasses() {
41+
PassRegistration<LabelPlayOpDurationsPass>();
4042
PassRegistration<LoadPulseCalsPass>();
4143
PassRegistration<QUIRToPulsePass>();
4244
PassRegistration<MergeDelayPass>();
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
other:
3+
- |
4+
Add LabelPlayOpsDurationPass to add a pulse.duration label to pulse.play
5+
operations based on the duration of the waveform being played.

0 commit comments

Comments
 (0)