Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/include/mlir-c/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,15 @@ typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
/// supplied `userData`. If `opName` is empty, the pass is a generic operation
/// pass. Otherwise it is an operation pass specific to the specified pass name.
MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
MLIR_CAPI_EXPORTED MlirExternalPass mlirExternalPassCreate(
MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
MlirStringRef description, MlirStringRef opName,
intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
MlirExternalPassCallbacks callbacks, void *userData);

// Static cast ExternalPass to Pass.
MLIR_CAPI_EXPORTED MlirPass mlirExternalPassGetPass(MlirExternalPass pass);

/// This signals that the pass has failed. This is only valid to call during
/// the `run` callback of `MlirExternalPassCallbacks`.
/// See Pass::signalPassFailure().
Expand Down
22 changes: 13 additions & 9 deletions mlir/lib/CAPI/IR/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,23 @@ class ExternalPass : public Pass {
};
} // namespace mlir

MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
MlirStringRef argument,
MlirStringRef description, MlirStringRef opName,
intptr_t nDependentDialects,
MlirDialectHandle *dependentDialects,
MlirExternalPassCallbacks callbacks,
void *userData) {
return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
MlirExternalPass mlirExternalPassCreate(MlirTypeID passID, MlirStringRef name,
MlirStringRef argument,
MlirStringRef description, MlirStringRef opName,
intptr_t nDependentDialects,
MlirDialectHandle *dependentDialects,
MlirExternalPassCallbacks callbacks,
void *userData) {
return wrap(new mlir::ExternalPass(
unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
: std::nullopt,
{dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
userData)));
userData));
}

MlirPass mlirExternalPassGetPass(MlirExternalPass externalPass) {
return wrap(static_cast<mlir::Pass *>(&externalPass));
}

void mlirExternalPassSignalFailure(MlirExternalPass pass) {
Expand Down
30 changes: 20 additions & 10 deletions mlir/test/CAPI/pass.c
Original file line number Diff line number Diff line change
Expand Up @@ -367,17 +367,19 @@ void testExternalPass(void) {
mlirStringRefCreateFromCString("test-external-pass");
TestExternalPassUserData userData = {0};

MlirPass externalPass = mlirCreateExternalPass(
MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);

MlirPass pass = mlirExternalPassGetPass(externalPass);

if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}

MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
mlirPassManagerAddOwnedPass(pm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
Expand Down Expand Up @@ -408,11 +410,13 @@ void testExternalPass(void) {
MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");

MlirPass externalPass = mlirCreateExternalPass(
MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, funcOpName, 1, &funcHandle,
makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
&userData);

MlirPass pass = mlirExternalPassGetPass(externalPass);

if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
Expand All @@ -421,7 +425,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
MlirOpPassManager nestedFuncPm =
mlirPassManagerGetNestedUnder(pm, funcOpName);
mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
mlirOpPassManagerAddOwnedPass(nestedFuncPm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external operation pass.\n");
Expand Down Expand Up @@ -457,19 +461,21 @@ void testExternalPass(void) {
mlirStringRefCreateFromCString("test-external-pass");
TestExternalPassUserData userData = {0};

MlirPass externalPass = mlirCreateExternalPass(
MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(testInitializeExternalPass,
testRunExternalPass),
&userData);

MlirPass pass = mlirExternalPassGetPass(externalPass);

if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}

MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
mlirPassManagerAddOwnedPass(pm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
Expand Down Expand Up @@ -504,19 +510,21 @@ void testExternalPass(void) {
mlirStringRefCreateFromCString("test-external-failing-pass");
TestExternalPassUserData userData = {0};

MlirPass externalPass = mlirCreateExternalPass(
MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
testRunExternalPass),
&userData);

MlirPass pass = mlirExternalPassGetPass(externalPass);

if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}

MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
mlirPassManagerAddOwnedPass(pm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
Expand Down Expand Up @@ -553,18 +561,20 @@ void testExternalPass(void) {
mlirStringRefCreateFromCString("test-external-failing-pass");
TestExternalPassUserData userData = {0};

MlirPass externalPass = mlirCreateExternalPass(
MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
&userData);

MlirPass pass = mlirExternalPassGetPass(externalPass);

if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}

MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
mlirPassManagerAddOwnedPass(pm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
Expand Down