
Commit 63aca2f

De-tensorize function boundary pass (#1904)
**Context:** Scalar tensors across function boundaries are prevalent due to JAX, requiring extra instructions to extract elements from tensor arguments and to construct tensors from scalar elements when passing values to functions.

**Description of the Change:** A new MLIR pass detensorizes function boundaries. Bridging `tensor.extract` and `tensor.from_elements` ops are inserted and subsequently folded. An illustrative before/after IR sketch is included at the end of this message.

**Benefits:** Reduced instruction count.

### Time score

|Workflow |Catalyst (qjit+exec)| |Catalyst (exec)| |
| :--- | ---: | :---: | ---: | :---: |
|QPE |0.99 | |0.99 | |
|QSVT |0.99 | |1.00 | |
|XAS |1.40 |🟢|0.99 | |
|shor |0.98 | |0.98 | |
|molecular_hamiltonian| - | | - | |
|sampling |1.00 | |1.02 | |
|stateprep |0.98 | |0.68 |🔴|
|grover |1.02 | |0.96 |🔴|
|QAOA_layers_scaling |1.00 | |1.01 | |
|QML |0.98 | |1.00 | |
|QML_jaxjit | - | | - | |
|UCCSD |0.98 | |1.02 | |
|VQE | - | |1.00 | |

## Detailed Per-Workflow Results

- The assumed noise level for runtime improvements/regressions is 3.0%
- ⚠️ marks workflows with runtime fluctuations greater than 5.0% (std/mean)

### Catalyst (compilation + execution)

|Workflow |time [s]|std/mean| |time score| |virt mem [MB]|virt mem score| |phys mem [MB]|phys mem score| |
| :--- | ---: | ---: | :---: | ---: | :---: | ---: | ---: | :---: | ---: | ---: | :---: |
|QPE[11-12] |1.196 |7.1% |⚠️|0.97 | |365.729 | - | |116.785 | - | |
|QPE[12-12] |2.222 |6.5% |⚠️|1.00 | |499.972 | - | |106.713 | - | |
|QSVT[9] |26.464 |18.0% |⚠️|0.99 | |4,279.710 | - | |2,471.539 | - | |
|XAS[2-1-9] |28.783 |0.8% | |1.46 |🟢|4,424.084 | - | |2,501.321 | - | |
|XAS[2-2-9] |38.110 |0.6% | |1.33 |🟢|4,424.159 | - | |2,500.624 | - | |
|shor[15] |1.288 |7.7% |⚠️|0.98 | |237.171 | - | |138.236 | - | |
|shor[33] |3.168 |3.2% | |0.99 | |237.761 | - | |141.140 | - | |
|sampling[24-2] |1.563 |5.0% | |1.02 | |1,046.111 | - | |116.871 | - | |
|sampling[25-2] |3.411 |3.1% | |0.99 | |1,852.645 | - | |106.881 | - | |
|stateprep[12-MottonenStatePreparation]|7.099 |2.8% | |0.98 | |404.347 | - | |229.323 | - | |
|grover[18] |9.678 |0.9% | |1.02 | |459.678 | - | |291.979 | - | |
|QAOA_layers_scaling[19-4] |111.156 |0.4% | |1.00 | |584.816 | - | |5,387.448 | - | |
|QML[IQPKernelClassifier-12-10] |121.986 |1.1% | |0.98 | |2,122.109 | - | |912.646 | - | |
|UCCSD[H2O-STO-3G] |12.956 |1.4% | |0.99 | |532.305 | - | |286.573 | - | |
|UCCSD[NH3-STO-3G] |33.076 |2.3% | |0.97 | |1,025.818 | - | |537.739 | - | |
|VQE[H2O-STO-3G] |21.477 |0.8% | |0.99 | |475.394 | - | |271.380 | - | |
|VQE[NH3-STO-3G] |72.532 |0.7% | | - | |1,224.448 | - | |785.674 | - | |

### Catalyst (execution only)

|Workflow |time [s]|std/mean| |time score| |virt mem [MB]|virt mem score| |phys mem [MB]|phys mem score| |
| :--- | ---: | ---: | :---: | ---: | :---: | ---: | ---: | :---: | ---: | ---: | :---: |
|QPE[11-12] |0.896 |1.0% | |0.98 | |365.954 | - | |117.281 | - | |
|QPE[12-12] |2.060 |0.7% | |0.99 | |500.201 | - | |116.912 | - | |
|QSVT[9] |0.143 |6.2% |⚠️|1.00 | |4,279.643 | - | |2,471.780 | - | |
|XAS[2-1-9] |9.320 |0.7% | |1.00 | |4,424.198 | - | |2,502.304 | - | |
|XAS[2-2-9] |18.713 |0.4% | |0.99 | |4,424.159 | - | |2,496.918 | - | |
|shor[15] |0.051 |6.1% |⚠️|0.96 |🔴|237.228 | - | |138.162 | - | |
|shor[33] |1.817 |1.5% | |1.00 | |237.699 | - | |140.599 | - | |
|sampling[24-2] |1.368 |1.6% | |1.03 |🟢|1,046.060 | - | |103.055 | - | |
|sampling[25-2] |3.186 |1.3% | |1.00 | |1,856.906 | - | |123.036 | - | |
|stateprep[12-MottonenStatePreparation]|0.033 |19.5% |⚠️|0.68 |🔴|404.091 | - | |242.000 | - | |
|grover[18] |1.561 |1.6% | |0.96 |🔴|468.432 | - | |308.101 | - | |
|QAOA_layers_scaling[19-4] |106.051 |0.5% | |1.01 | |584.730 | - | |5,745.373 | - | |
|QML[IQPKernelClassifier-12-10] |97.194 |3.2% | |1.00 | |1,184.518 | - | |623.976 | - | |
|UCCSD[H2O-STO-3G] |0.113 |1.9% | |1.05 |🟢|530.752 | - | |292.098 | - | |
|UCCSD[NH3-STO-3G] |1.153 |4.2% | |0.99 | |1,025.717 | - | |550.011 | - | |
|VQE[H2O-STO-3G] |2.067 |1.8% | |1.00 | |475.332 | - | |272.044 | - | |
|VQE[NH3-STO-3G] |16.754 |1.0% | |1.01 | |1,224.450 | - | |785.883 | - | |

**Possible Drawbacks:**

**Related GitHub Issues:** [sc-95476]
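**Illustrative IR sketch (not part of this diff):** a minimal hand-written example of the rewrite the pass performs; function and value names are made up. A rank-0 `tensor<f64>` crossing a call boundary becomes a plain `f64`, bridged by `tensor.extract`/`tensor.from_elements` ops that the subsequent `canonicalize` pass can fold, after which `symbol-dce` removes the dead original callee.

```mlir
// Before: a scalar wrapped in a rank-0 tensor crosses the boundary of @callee.
func.func private @callee(%arg0: tensor<f64>) -> tensor<f64> {
  return %arg0 : tensor<f64>
}

func.func @caller(%arg0: tensor<f64>) -> tensor<f64> {
  %0 = call @callee(%arg0) : (tensor<f64>) -> tensor<f64>
  return %0 : tensor<f64>
}

// After detensorize-function-boundary: a clone with a scalar signature is created
// and the call site is bridged. The original @callee becomes dead (later removed
// by symbol-dce), and matching extract/from_elements pairs fold under canonicalize.
func.func private @callee.detensorized(%arg0: f64) -> f64 {
  %0 = tensor.from_elements %arg0 : tensor<f64>
  %1 = tensor.extract %0[] : tensor<f64>
  return %1 : f64
}

func.func @caller(%arg0: tensor<f64>) -> tensor<f64> {
  %0 = tensor.extract %arg0[] : tensor<f64>
  %1 = call @callee.detensorized(%0) : (f64) -> f64
  %2 = tensor.from_elements %1 : tensor<f64>
  return %2 : tensor<f64>
}
```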
1 parent 6b2db96 commit 63aca2f

9 files changed: +402, -1 lines changed


doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@

 <h3>Improvements 🛠</h3>

+* Added `detensorize-function-boundary` pass to remove scalar tensors across function boundaries and enabled the `symbol-dce` pass to remove dead functions, reducing the number of instructions for compilation.
+  [(#1904)](https://github.com/PennyLaneAI/catalyst/pull/1904)
+
 * Workflows `for_loop`, `while_loop` and `cond` now error out if `qml.capture` is enabled.
   [(#1945)](https://github.com/PennyLaneAI/catalyst/pull/1945)

frontend/catalyst/pipelines.py

Lines changed: 2 additions & 0 deletions
@@ -239,7 +239,9 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]:
         "cse",
         "func.func(linalg-detensorize{aggressive-mode})",
         "detensorize-scf",
+        "detensorize-function-boundary",
         "canonicalize",
+        "symbol-dce",
     ]
     return hlo_lowering
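The pass ordering above is deliberate: `canonicalize` runs immediately after `detensorize-function-boundary` so the inserted bridge ops fold away, and `symbol-dce` then deletes the original tensorized functions once nothing calls them. A rough sketch of the fold, relying on the upstream `tensor.extract` folder (the function name is illustrative, not from this PR):

```mlir
func.func @fold_example(%x: f64) -> f64 {
  // Bridge ops of the kind inserted at a detensorized boundary:
  %t = tensor.from_elements %x : tensor<f64>
  %y = tensor.extract %t[] : tensor<f64>
  // After canonicalize, %y folds to %x and both bridge ops are erased.
  return %y : f64
}
```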

mlir/include/Catalyst/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ std::unique_ptr<mlir::Pass> createArrayListToMemRefPass();
 std::unique_ptr<mlir::Pass> createBufferDeallocationPass();
 std::unique_ptr<mlir::Pass> createCatalystBufferizationPass();
 std::unique_ptr<mlir::Pass> createCatalystConversionPass();
+std::unique_ptr<mlir::Pass> createDetensorizeFunctionBoundaryPass();
 std::unique_ptr<mlir::Pass> createDetensorizeSCFPass();
 std::unique_ptr<mlir::Pass> createDisableAssertionPass();
 std::unique_ptr<mlir::Pass> createGEPInboundsPass();

mlir/include/Catalyst/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
@@ -17,6 +17,16 @@

 include "mlir/Pass/PassBase.td"

+def DetensorizeFunctionBoundaryPass : Pass<"detensorize-function-boundary"> {
+    let summary = "Detensorize across function boundary.";
+
+    let dependentDialects = [
+        "tensor::TensorDialect",
+    ];
+
+    let constructor = "catalyst::createDetensorizeFunctionBoundaryPass()";
+}
+
 def DetensorizeSCFPass : Pass<"detensorize-scf"> {
     let summary = "Detensorize for, if, while operations from the SCF dialect.";

mlir/lib/Catalyst/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ file(GLOB SRC
     BufferizableOpInterfaceImpl.cpp
     catalyst_to_llvm.cpp
     DetectQNodes.cpp
+    DetensorizeFunctionBoundaryPass.cpp
     DetensorizeSCFPass.cpp
     disable_assertion.cpp
     DisableAssertionPatterns.cpp

mlir/lib/Catalyst/Transforms/DetensorizeFunctionBoundaryPass.cpp

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
#define DEBUG_TYPE "detensorize-func-boundary"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "Catalyst/IR/CatalystDialect.h"

using namespace llvm;
using namespace mlir;
using namespace catalyst;

namespace {
bool isScalarTensor(Type type)
{
    if (auto rankedType = dyn_cast<RankedTensorType>(type)) {
        return rankedType.getRank() == 0;
    }
    return false;
}

Type getScalarOrOriginalType(Type type)
{
    if (isScalarTensor(type)) {
        return dyn_cast<RankedTensorType>(type).getElementType();
    }
    else {
        return type;
    }
}

bool hasScalarTensorSignature(func::FuncOp funcOp)
{
    for (Type type : funcOp.getFunctionType().getInputs()) {
        if (isScalarTensor(type)) {
            return true;
        }
    }
    for (Type type : funcOp.getFunctionType().getResults()) {
        if (isScalarTensor(type)) {
            return true;
        }
    }
    return false;
}

struct DetensorizeCallSitePattern : public OpRewritePattern<func::CallOp> {
    using OpRewritePattern<func::CallOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(func::CallOp callOp, PatternRewriter &rewriter) const override
    {
        auto funcOp =
            SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(callOp, callOp.getCalleeAttr());

        // Skip for main function
        if (!funcOp || funcOp->hasAttr("llvm.emit_c_interface")) {
            return failure();
        }

        if (!hasScalarTensorSignature(funcOp)) {
            return failure();
        }

        // Skip for QNodes
        // Some Gradient boundaries only work for Tensor signatures
        // and not scalar ones, hence we skip them here.
        if (funcOp->hasAttr("qnode")) {
            return failure();
        }

        // Create detensorized FuncOp if it does not already exist
        auto module = callOp->getParentOfType<ModuleOp>();
        std::string newFuncName = funcOp.getName().str() + ".detensorized";
        auto newFuncOp = module.lookupSymbol<func::FuncOp>(newFuncName);

        if (!newFuncOp) {
            OpBuilder::InsertionGuard guard(rewriter);
            rewriter.setInsertionPointToEnd(module.getBody());

            // Create the new function with a detensorized signature
            FunctionType funcType = funcOp.getFunctionType();
            SmallVector<Type> newArgTypes, newResultTypes;
            SmallVector<NamedAttribute> newAttrs;
            extractDetensorizedOpSignature(funcType, funcOp, newArgTypes, newResultTypes, newAttrs);

            // Create the new function, passing the collected signature
            auto newFuncType = FunctionType::get(getContext(), newArgTypes, newResultTypes);
            newFuncOp =
                rewriter.create<func::FuncOp>(funcOp.getLoc(), newFuncName, newFuncType, newAttrs);

            // Map FuncOp body and return operation
            Block *newEntryBlock = newFuncOp.addEntryBlock();
            IRMapping mapper;
            mapFuncOpBodyAndReturnOp(rewriter, newEntryBlock, funcOp, mapper);
        }

        // Rewrite the original call site to use the new detensorized function
        replaceCallOp(rewriter, callOp, newFuncOp);
        return success();
    }

    void extractDetensorizedOpSignature(FunctionType &funcType, func::FuncOp &funcOp,
                                        SmallVector<Type> &newArgTypes,
                                        SmallVector<Type> &newResultTypes,
                                        SmallVector<NamedAttribute> &newAttrs) const
    {
        for (Type type : funcType.getInputs()) {
            newArgTypes.push_back(getScalarOrOriginalType(type));
        }
        for (Type type : funcType.getResults()) {
            newResultTypes.push_back(getScalarOrOriginalType(type));
        }

        // Collect all attributes from the original function
        for (const NamedAttribute &attr : funcOp->getAttrs()) {
            if (attr.getName() == funcOp.getSymNameAttrName() ||
                attr.getName() == funcOp.getFunctionTypeAttrName()) {
                continue;
            }
            newAttrs.push_back(attr);
        }
    }

    void mapFuncOpBodyAndReturnOp(PatternRewriter &rewriter, Block *newEntryBlock,
                                  func::FuncOp &funcOp, IRMapping &mapper) const
    {
        rewriter.setInsertionPointToStart(newEntryBlock);
        for (const auto &it : llvm::enumerate(funcOp.getArguments())) {
            Value oldArg = it.value();
            Value newArg = newEntryBlock->getArgument(it.index());

            if (isScalarTensor(oldArg.getType())) {
                // Insert a FromElementsOp if the old argument is a scalar tensor
                auto fromElementsOp = rewriter.create<tensor::FromElementsOp>(
                    newArg.getLoc(), oldArg.getType(), newArg);
                mapper.map(oldArg, fromElementsOp.getResult());
            }
            else {
                mapper.map(oldArg, newArg);
            }
        }

        // Clone the operations from the body of old function (excluding the old return)
        rewriter.setInsertionPointToEnd(newEntryBlock);
        for (Operation &op : funcOp.front().without_terminator()) {
            rewriter.clone(op, mapper);
        }

        // Create a new return operation with the mapped results
        auto oldReturnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
        SmallVector<Value> newReturnOperands;
        newReturnOperands.reserve(oldReturnOp.getNumOperands());
        for (Value operand : oldReturnOp.getOperands()) {
            Value newOperand = mapper.lookup(operand);
            if (isScalarTensor(newOperand.getType())) {
                // Insert ExtractOp if the operand is a scalar tensor
                auto extractOp = rewriter.create<tensor::ExtractOp>(oldReturnOp.getLoc(),
                                                                    newOperand, ValueRange{});
                newReturnOperands.push_back(extractOp.getResult());
            }
            else {
                newReturnOperands.push_back(newOperand);
            }
        }
        rewriter.create<func::ReturnOp>(oldReturnOp.getLoc(), newReturnOperands);
    }

    void replaceCallOp(PatternRewriter &rewriter, func::CallOp &callOp,
                       func::FuncOp &newFuncOp) const
    {
        rewriter.setInsertionPoint(callOp);
        SmallVector<Value> newOperands;
        for (Value operand : callOp.getOperands()) {
            // Insert ExtractOp if the old operand is a scalar tensor to bridge the detensorized
            // function
            if (isScalarTensor(operand.getType())) {
                auto extractOp =
                    rewriter.create<tensor::ExtractOp>(callOp.getLoc(), operand, ValueRange{});
                newOperands.push_back(extractOp.getResult());
            }
            else {
                newOperands.push_back(operand);
            }
        }

        auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), newFuncOp, newOperands);

        SmallVector<Value> newResults;
        for (size_t i = 0; i < callOp.getNumResults(); ++i) {
            Value oldResult = callOp.getResult(i);
            Value newResult = newCallOp.getResult(i);
            if (isScalarTensor(oldResult.getType())) {
                // Insert a FromElementsOp if the old result is a scalar tensor to bridge the
                // detensorized function
                auto fromElementsOp = rewriter.create<tensor::FromElementsOp>(
                    callOp.getLoc(), oldResult.getType(), newResult);
                newResults.push_back(fromElementsOp.getResult());
            }
            else {
                newResults.push_back(newResult);
            }
        }

        rewriter.replaceOp(callOp, newResults);
    }
};
} // namespace

namespace catalyst {
#define GEN_PASS_DEF_DETENSORIZEFUNCTIONBOUNDARYPASS
#include "Catalyst/Transforms/Passes.h.inc"

struct DetensorizeFunctionBoundaryPass
    : public impl::DetensorizeFunctionBoundaryPassBase<DetensorizeFunctionBoundaryPass> {
    using impl::DetensorizeFunctionBoundaryPassBase<
        DetensorizeFunctionBoundaryPass>::DetensorizeFunctionBoundaryPassBase;
    void runOnOperation() override
    {
        MLIRContext *context = &getContext();
        RewritePatternSet patterns(context);

        patterns.add<DetensorizeCallSitePattern>(context);

        GreedyRewriteConfig config;
        if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), config))) {
            signalPassFailure();
        }
    }
};

std::unique_ptr<Pass> createDetensorizeFunctionBoundaryPass()
{
    return std::make_unique<DetensorizeFunctionBoundaryPass>();
}

} // namespace catalyst
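No lit test for the new pass appears in this diff. As a rough sketch only, a FileCheck-style case exercising the QNode skip rule implemented above could look like the following; the `quantum-opt` RUN line, file placement, and CHECK lines are assumptions, not part of this commit.

```mlir
// RUN: quantum-opt --detensorize-function-boundary %s | FileCheck %s

// A callee carrying the qnode attribute is skipped: no ".detensorized" clone is
// created and the call keeps its scalar-tensor signature.
func.func private @circuit(%arg0: tensor<f64>) -> tensor<f64> attributes {qnode} {
  return %arg0 : tensor<f64>
}

// CHECK-LABEL: func.func @caller(
func.func @caller(%arg0: tensor<f64>) -> tensor<f64> {
  // CHECK: call @circuit(%{{.*}}) : (tensor<f64>) -> tensor<f64>
  %0 = call @circuit(%arg0) : (tensor<f64>) -> tensor<f64>
  return %0 : tensor<f64>
}
// CHECK-NOT: @circuit.detensorized
```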

mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-#define DEBUG_TYPE "myhelloworld"
+#define DEBUG_TYPE "detensorize-scf"

 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"

mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ void catalyst::registerAllCatalystPasses()
     mlir::registerPass(catalyst::createDecomposeNonCliffordPPRPass);
     mlir::registerPass(catalyst::createDecomposeCliffordPPRPass);
     mlir::registerPass(catalyst::createCountPPMSpecsPass);
+    mlir::registerPass(catalyst::createDetensorizeFunctionBoundaryPass);
     mlir::registerPass(catalyst::createDetensorizeSCFPass);
     mlir::registerPass(catalyst::createDisableAssertionPass);
     mlir::registerPass(catalyst::createDisentangleCNOTPass);
