diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 8fd8e9956a65a..acac29a513a38 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -174,12 +174,15 @@ typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks; /// Creates an external `MlirPass` that calls the supplied `callbacks` using the /// supplied `userData`. If `opName` is empty, the pass is a generic operation /// pass. Otherwise it is an operation pass specific to the specified pass name. -MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass( +MLIR_CAPI_EXPORTED MlirExternalPass mlirExternalPassCreate( MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData); +// Static cast ExternalPass to Pass. +MLIR_CAPI_EXPORTED MlirPass mlirExternalPassGetPass(MlirExternalPass pass); + /// This signals that the pass has failed. This is only valid to call during /// the `run` callback of `MlirExternalPassCallbacks`. /// See Pass::signalPassFailure(). diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 883b7e8bb832d..c751d5c467e4c 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -193,19 +193,23 @@ class ExternalPass : public Pass { }; } // namespace mlir -MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, - MlirStringRef argument, - MlirStringRef description, MlirStringRef opName, - intptr_t nDependentDialects, - MlirDialectHandle *dependentDialects, - MlirExternalPassCallbacks callbacks, - void *userData) { - return wrap(static_cast(new mlir::ExternalPass( +MlirExternalPass mlirExternalPassCreate(MlirTypeID passID, MlirStringRef name, + MlirStringRef argument, + MlirStringRef description, MlirStringRef opName, + intptr_t nDependentDialects, + MlirDialectHandle *dependentDialects, + MlirExternalPassCallbacks callbacks, + void *userData) { + return wrap(new mlir::ExternalPass( unwrap(passID), unwrap(name), unwrap(argument), unwrap(description), opName.length > 0 ? std::optional(unwrap(opName)) : std::nullopt, {dependentDialects, static_cast(nDependentDialects)}, callbacks, - userData))); + userData)); +} + +MlirPass mlirExternalPassGetPass(MlirExternalPass externalPass) { + return wrap(static_cast(&externalPass)); } void mlirExternalPassSignalFailure(MlirExternalPass pass) { diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c index 3aad0016b393c..8778c945ab259 100644 --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -367,17 +367,19 @@ void testExternalPass(void) { mlirStringRefCreateFromCString("test-external-pass"); TestExternalPassUserData userData = {0}; - MlirPass externalPass = mlirCreateExternalPass( + MlirExternalPass externalPass = mlirExternalPassCreate( passID, name, argument, description, emptyOpName, 0, NULL, makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData); + MlirPass pass = mlirExternalPassGetPass(externalPass); + if (userData.constructCallCount != 1) { fprintf(stderr, "Expected constructCallCount to be 1\n"); exit(EXIT_FAILURE); } MlirPassManager pm = mlirPassManagerCreate(ctx); - mlirPassManagerAddOwnedPass(pm, externalPass); + mlirPassManagerAddOwnedPass(pm, pass); MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module); if (mlirLogicalResultIsFailure(success)) { fprintf(stderr, "Unexpected failure running external pass.\n"); @@ -408,11 +410,13 @@ void testExternalPass(void) { MlirDialectHandle funcHandle = mlirGetDialectHandle__func__(); MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func"); - MlirPass externalPass = mlirCreateExternalPass( + MlirExternalPass externalPass = mlirExternalPassCreate( passID, name, argument, description, funcOpName, 1, &funcHandle, makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass), &userData); + MlirPass pass = mlirExternalPassGetPass(externalPass); + if (userData.constructCallCount != 1) { fprintf(stderr, "Expected constructCallCount to be 1\n"); exit(EXIT_FAILURE); @@ -421,7 +425,7 @@ void testExternalPass(void) { MlirPassManager pm = mlirPassManagerCreate(ctx); MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(pm, funcOpName); - mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass); + mlirOpPassManagerAddOwnedPass(nestedFuncPm, pass); MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module); if (mlirLogicalResultIsFailure(success)) { fprintf(stderr, "Unexpected failure running external operation pass.\n"); @@ -457,19 +461,21 @@ void testExternalPass(void) { mlirStringRefCreateFromCString("test-external-pass"); TestExternalPassUserData userData = {0}; - MlirPass externalPass = mlirCreateExternalPass( + MlirExternalPass externalPass = mlirExternalPassCreate( passID, name, argument, description, emptyOpName, 0, NULL, makeTestExternalPassCallbacks(testInitializeExternalPass, testRunExternalPass), &userData); + MlirPass pass = mlirExternalPassGetPass(externalPass); + if (userData.constructCallCount != 1) { fprintf(stderr, "Expected constructCallCount to be 1\n"); exit(EXIT_FAILURE); } MlirPassManager pm = mlirPassManagerCreate(ctx); - mlirPassManagerAddOwnedPass(pm, externalPass); + mlirPassManagerAddOwnedPass(pm, pass); MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module); if (mlirLogicalResultIsFailure(success)) { fprintf(stderr, "Unexpected failure running external pass.\n"); @@ -504,19 +510,21 @@ void testExternalPass(void) { mlirStringRefCreateFromCString("test-external-failing-pass"); TestExternalPassUserData userData = {0}; - MlirPass externalPass = mlirCreateExternalPass( + MlirExternalPass externalPass = mlirExternalPassCreate( passID, name, argument, description, emptyOpName, 0, NULL, makeTestExternalPassCallbacks(testInitializeFailingExternalPass, testRunExternalPass), &userData); + MlirPass pass = mlirExternalPassGetPass(externalPass); + if (userData.constructCallCount != 1) { fprintf(stderr, "Expected constructCallCount to be 1\n"); exit(EXIT_FAILURE); } MlirPassManager pm = mlirPassManagerCreate(ctx); - mlirPassManagerAddOwnedPass(pm, externalPass); + mlirPassManagerAddOwnedPass(pm, pass); MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module); if (mlirLogicalResultIsSuccess(success)) { fprintf( @@ -553,18 +561,20 @@ void testExternalPass(void) { mlirStringRefCreateFromCString("test-external-failing-pass"); TestExternalPassUserData userData = {0}; - MlirPass externalPass = mlirCreateExternalPass( + MlirExternalPass externalPass = mlirExternalPassCreate( passID, name, argument, description, emptyOpName, 0, NULL, makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass), &userData); + MlirPass pass = mlirExternalPassGetPass(externalPass); + if (userData.constructCallCount != 1) { fprintf(stderr, "Expected constructCallCount to be 1\n"); exit(EXIT_FAILURE); } MlirPassManager pm = mlirPassManagerCreate(ctx); - mlirPassManagerAddOwnedPass(pm, externalPass); + mlirPassManagerAddOwnedPass(pm, pass); MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module); if (mlirLogicalResultIsSuccess(success)) { fprintf(