Skip to content

Commit 5cbe5e1

Browse files
committed
[mlir][capi] make MLIR Pass C-API type safe
- change C-API name mlirCreateExternalPass to mlirExternalPassCreate for aligning the C-API naming convension; - make mlirExternalPassCreate to return MlirExternalPass for type safety; - create new C-API MlirExternalPassGetPass to cast MlirExternalPass to MlirPass;
1 parent 1557eed commit 5cbe5e1

File tree

3 files changed

+37
-20
lines changed

3 files changed

+37
-20
lines changed

mlir/include/mlir-c/Pass.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,15 @@ typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
174174
/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
175175
/// supplied `userData`. If `opName` is empty, the pass is a generic operation
176176
/// pass. Otherwise it is an operation pass specific to the specified pass name.
177-
MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
177+
MLIR_CAPI_EXPORTED MlirExternalPass mlirExternalPassCreate(
178178
MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
179179
MlirStringRef description, MlirStringRef opName,
180180
intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
181181
MlirExternalPassCallbacks callbacks, void *userData);
182182

183+
// Static cast ExternalPass to Pass.
184+
MLIR_CAPI_EXPORTED MlirPass mlirExternalPassGetPass(MlirExternalPass pass);
185+
183186
/// This signals that the pass has failed. This is only valid to call during
184187
/// the `run` callback of `MlirExternalPassCallbacks`.
185188
/// See Pass::signalPassFailure().

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,19 +193,23 @@ class ExternalPass : public Pass {
193193
};
194194
} // namespace mlir
195195

196-
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
197-
MlirStringRef argument,
198-
MlirStringRef description, MlirStringRef opName,
199-
intptr_t nDependentDialects,
200-
MlirDialectHandle *dependentDialects,
201-
MlirExternalPassCallbacks callbacks,
202-
void *userData) {
203-
return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
196+
MlirExternalPass mlirExternalPassCreate(MlirTypeID passID, MlirStringRef name,
197+
MlirStringRef argument,
198+
MlirStringRef description, MlirStringRef opName,
199+
intptr_t nDependentDialects,
200+
MlirDialectHandle *dependentDialects,
201+
MlirExternalPassCallbacks callbacks,
202+
void *userData) {
203+
return wrap(new mlir::ExternalPass(
204204
unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
205205
opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
206206
: std::nullopt,
207207
{dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
208-
userData)));
208+
userData));
209+
}
210+
211+
MlirPass mlirExternalPassGetPass(MlirExternalPass externalPass) {
212+
return wrap(static_cast<mlir::Pass>(externalPass));
209213
}
210214

211215
void mlirExternalPassSignalFailure(MlirExternalPass pass) {

mlir/test/CAPI/pass.c

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -367,17 +367,19 @@ void testExternalPass(void) {
367367
mlirStringRefCreateFromCString("test-external-pass");
368368
TestExternalPassUserData userData = {0};
369369

370-
MlirPass externalPass = mlirCreateExternalPass(
370+
MlirExternalPass externalPass = mlirExternalPassCreate(
371371
passID, name, argument, description, emptyOpName, 0, NULL,
372372
makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
373373

374+
MlirPass pass = mlirExternalPassGetPass(externalPass);
375+
374376
if (userData.constructCallCount != 1) {
375377
fprintf(stderr, "Expected constructCallCount to be 1\n");
376378
exit(EXIT_FAILURE);
377379
}
378380

379381
MlirPassManager pm = mlirPassManagerCreate(ctx);
380-
mlirPassManagerAddOwnedPass(pm, externalPass);
382+
mlirPassManagerAddOwnedPass(pm, pass);
381383
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
382384
if (mlirLogicalResultIsFailure(success)) {
383385
fprintf(stderr, "Unexpected failure running external pass.\n");
@@ -408,11 +410,13 @@ void testExternalPass(void) {
408410
MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
409411
MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
410412

411-
MlirPass externalPass = mlirCreateExternalPass(
413+
MlirExternalPass externalPass = mlirExternalPassCreate(
412414
passID, name, argument, description, funcOpName, 1, &funcHandle,
413415
makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
414416
&userData);
415417

418+
MlirPass pass = mlirExternalPassGetPass(externalPass);
419+
416420
if (userData.constructCallCount != 1) {
417421
fprintf(stderr, "Expected constructCallCount to be 1\n");
418422
exit(EXIT_FAILURE);
@@ -421,7 +425,7 @@ void testExternalPass(void) {
421425
MlirPassManager pm = mlirPassManagerCreate(ctx);
422426
MlirOpPassManager nestedFuncPm =
423427
mlirPassManagerGetNestedUnder(pm, funcOpName);
424-
mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
428+
mlirOpPassManagerAddOwnedPass(nestedFuncPm, pass);
425429
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
426430
if (mlirLogicalResultIsFailure(success)) {
427431
fprintf(stderr, "Unexpected failure running external operation pass.\n");
@@ -457,19 +461,21 @@ void testExternalPass(void) {
457461
mlirStringRefCreateFromCString("test-external-pass");
458462
TestExternalPassUserData userData = {0};
459463

460-
MlirPass externalPass = mlirCreateExternalPass(
464+
MlirExternalPass externalPass = mlirExternalPassCreate(
461465
passID, name, argument, description, emptyOpName, 0, NULL,
462466
makeTestExternalPassCallbacks(testInitializeExternalPass,
463467
testRunExternalPass),
464468
&userData);
465469

470+
MlirPass pass = mlirExternalPassGetPass(externalPass);
471+
466472
if (userData.constructCallCount != 1) {
467473
fprintf(stderr, "Expected constructCallCount to be 1\n");
468474
exit(EXIT_FAILURE);
469475
}
470476

471477
MlirPassManager pm = mlirPassManagerCreate(ctx);
472-
mlirPassManagerAddOwnedPass(pm, externalPass);
478+
mlirPassManagerAddOwnedPass(pm, pass);
473479
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
474480
if (mlirLogicalResultIsFailure(success)) {
475481
fprintf(stderr, "Unexpected failure running external pass.\n");
@@ -504,19 +510,21 @@ void testExternalPass(void) {
504510
mlirStringRefCreateFromCString("test-external-failing-pass");
505511
TestExternalPassUserData userData = {0};
506512

507-
MlirPass externalPass = mlirCreateExternalPass(
513+
MlirExternalPass externalPass = mlirExternalPassCreate(
508514
passID, name, argument, description, emptyOpName, 0, NULL,
509515
makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
510516
testRunExternalPass),
511517
&userData);
512518

519+
MlirPass pass = mlirExternalPassGetPass(externalPass);
520+
513521
if (userData.constructCallCount != 1) {
514522
fprintf(stderr, "Expected constructCallCount to be 1\n");
515523
exit(EXIT_FAILURE);
516524
}
517525

518526
MlirPassManager pm = mlirPassManagerCreate(ctx);
519-
mlirPassManagerAddOwnedPass(pm, externalPass);
527+
mlirPassManagerAddOwnedPass(pm, pass);
520528
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
521529
if (mlirLogicalResultIsSuccess(success)) {
522530
fprintf(
@@ -553,18 +561,20 @@ void testExternalPass(void) {
553561
mlirStringRefCreateFromCString("test-external-failing-pass");
554562
TestExternalPassUserData userData = {0};
555563

556-
MlirPass externalPass = mlirCreateExternalPass(
564+
MlirExternalPass externalPass = mlirExternalPassCreate(
557565
passID, name, argument, description, emptyOpName, 0, NULL,
558566
makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
559567
&userData);
560568

569+
MlirPass pass = mlirExternalPassGetPass(externalPass);
570+
561571
if (userData.constructCallCount != 1) {
562572
fprintf(stderr, "Expected constructCallCount to be 1\n");
563573
exit(EXIT_FAILURE);
564574
}
565575

566576
MlirPassManager pm = mlirPassManagerCreate(ctx);
567-
mlirPassManagerAddOwnedPass(pm, externalPass);
577+
mlirPassManagerAddOwnedPass(pm, pass);
568578
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
569579
if (mlirLogicalResultIsSuccess(success)) {
570580
fprintf(

0 commit comments

Comments
 (0)