Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 129 additions & 10 deletions mlir/lib/Catalyst/Transforms/ApplyTransformSequencePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,129 @@

#include "llvm/Support/Debug.h"

#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/Pass.h"

#include "Catalyst/IR/CatalystDialect.h"
#include "QEC/IR/QECDialect.h"
#include "mlir/Pass/PassInstrumentation.h"
#include "mlir/Pass/PassManager.h"

using namespace llvm;
using namespace mlir;
using namespace catalyst;

namespace catalyst {

/// Generate a meaningful name for a transform operation for pass instrumentation
static std::string getTransformOpName(mlir::transform::TransformOpInterface transformOp)
{
std::string baseName = transformOp->getName().getStringRef().str();

llvm::errs() << "Transform op: " << baseName << "\n";

if (auto applyPassOp =
dyn_cast<mlir::transform::ApplyRegisteredPassOp>(transformOp.getOperation())) {
if (auto passName = applyPassOp.getPassName(); !passName.empty()) {
return "transform_" + passName.str();
}
}

// convert "." to "_"
std::replace(baseName.begin(), baseName.end(), '.', '_');
return baseName;
}

/// A fake pass wrapper that represents a single transform operation. Allowing it to be tracked by
/// pass instrumentation.
class TransformOpSubPass : public OperationPass<> {
private:
mlir::transform::TransformOpInterface transformOp;
std::string opNameStr;

public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpSubPass)

TransformOpSubPass(mlir::transform::TransformOpInterface op)
: OperationPass(TypeID::get<TransformOpSubPass>()), transformOp(op),
opNameStr(getTransformOpName(op))
{
}

void runOnOperation() override
{
llvm_unreachable("TransformOpSubPass should not be executed");
}

StringRef getName() const override { return opNameStr; }
StringRef getArgument() const override { return opNameStr; }
StringRef getDescription() const override { return "Transform dialect operation"; }

std::unique_ptr<Pass> clonePass() const override
{
return std::make_unique<TransformOpSubPass>(transformOp);
}

mlir::transform::TransformOpInterface getTransformOp() const { return transformOp; }
};

/// Apply transforms with individual subpass tracking by executing each transform operation
/// individually with instrumentation hooks. This implements a custom sequence
/// execution that mirrors the logic in NamedSequenceOp::apply but with instrumentation.
LogicalResult applyTransformsWithSubpassTracking(Operation *payload, Operation *namedSequence,
PassInstrumentor *passInstrumentor)
{
auto namedSeqOp = dyn_cast<mlir::transform::NamedSequenceOp>(namedSequence);
if (!namedSeqOp || namedSeqOp.getBody().empty()) {
return success();
}

Block &sequenceBlock = namedSeqOp.getBody().front();
if (sequenceBlock.without_terminator().empty()) {
return success();
}

mlir::transform::TransformState state = mlir::transform::detail::makeTransformStateForTesting(
namedSequence->getParentRegion(), payload);

// Map the entry block argument to the list of operations.
// Note: this is the same implementation as PossibleTopLevelTransformOp but
// without attaching the interface / trait since that is tailored to a
// dangling top-level op that does not get "called".
auto scope = state.make_region_scope(namedSeqOp.getBody());
if (failed(mlir::transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
state, namedSequence, namedSeqOp.getBody()))) {
return failure();
}

for (Operation &transformOp : sequenceBlock.without_terminator()) {
if (auto transformInterface =
dyn_cast<mlir::transform::TransformOpInterface>(transformOp)) {
auto subPass = std::make_unique<TransformOpSubPass>(transformInterface);

// hook before pass
passInstrumentor->runBeforePass(subPass.get(), payload);

DiagnosedSilenceableFailure result = state.applyTransform(transformInterface);

if (result.isDefiniteFailure()) {
// hook after pass failed
passInstrumentor->runAfterPassFailed(subPass.get(), payload);
return failure();
}

if (result.isSilenceableFailure()) {
(void)result.silence();
}

// hook after pass
passInstrumentor->runAfterPass(subPass.get(), payload);
}
}

return success();
}

#define GEN_PASS_DEF_APPLYTRANSFORMSEQUENCEPASS
#include "Catalyst/Transforms/Passes.h.inc"

Expand Down Expand Up @@ -79,12 +190,20 @@ struct ApplyTransformSequencePass
}
});

// Perform the transform
if (failed(mlir::transform::applyTransforms(
payload, cast<mlir::transform::TransformOpInterface>(transformer_main_sequence), {},
mlir::transform::TransformOptions(), false))) {
return signalPassFailure();
};
if (auto *passInstrumentor = getAnalysisManager().getPassInstrumentor(); passInstrumentor) {
// Manually execute the transform sequence with individual subpass tracking
if (failed(applyTransformsWithSubpassTracking(payload, transformer_main_sequence,
passInstrumentor))) {
return signalPassFailure();
}
}
else {
if (failed(mlir::transform::applyTransforms(
payload, cast<mlir::transform::TransformOpInterface>(transformer_main_sequence),
{}, mlir::transform::TransformOptions(), false))) {
return signalPassFailure();
}
}

transformer.erase();
}
Expand Down