@@ -1042,6 +1042,65 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
10421042 }
10431043};
10441044
1045+ // / Remove empty acc.kernel_environment operations. If the operation has wait
1046+ // / operands, create a acc.wait operation to preserve synchronization.
1047+ struct RemoveEmptyKernelEnvironment
1048+ : public OpRewritePattern<acc::KernelEnvironmentOp> {
1049+ using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1050+
1051+ LogicalResult matchAndRewrite (acc::KernelEnvironmentOp op,
1052+ PatternRewriter &rewriter) const override {
1053+ assert (op->getNumRegions () == 1 && " expected op to have one region" );
1054+
1055+ Block &block = op.getRegion ().front ();
1056+ if (!block.empty ())
1057+ return failure ();
1058+
1059+ // Conservatively disable canonicalization of empty acc.kernel_environment
1060+ // operations if the wait operands in the kernel_environment cannot be fully
1061+ // represented by acc.wait operation.
1062+
1063+ // Disable canonicalization if device type is not the default
1064+ if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr ()) {
1065+ for (auto attr : deviceTypeAttr) {
1066+ if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1067+ if (dtAttr.getValue () != mlir::acc::DeviceType::None)
1068+ return failure ();
1069+ }
1070+ }
1071+ }
1072+
1073+ // Disable canonicalization if any wait segment has a devnum
1074+ if (auto hasDevnumAttr = op.getHasWaitDevnumAttr ()) {
1075+ for (auto attr : hasDevnumAttr) {
1076+ if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1077+ if (boolAttr.getValue ())
1078+ return failure ();
1079+ }
1080+ }
1081+ }
1082+
1083+ // Disable canonicalization if there are multiple wait segments
1084+ if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr ()) {
1085+ if (segmentsAttr.size () > 1 )
1086+ return failure ();
1087+ }
1088+
1089+ // Remove empty kernel environment.
1090+ // Preserve synchronization by creating acc.wait operation if needed.
1091+ if (!op.getWaitOperands ().empty () || op.getWaitOnlyAttr ())
1092+ rewriter.replaceOpWithNewOp <acc::WaitOp>(op, op.getWaitOperands (),
1093+ /* asyncOperand=*/ Value (),
1094+ /* waitDevnum=*/ Value (),
1095+ /* async=*/ nullptr ,
1096+ /* ifCond=*/ Value ());
1097+ else
1098+ rewriter.eraseOp (op);
1099+
1100+ return success ();
1101+ }
1102+ };
1103+
10451104// ===----------------------------------------------------------------------===//
10461105// Recipe Region Helpers
10471106// ===----------------------------------------------------------------------===//
@@ -2690,6 +2749,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
26902749 results.add <RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
26912750}
26922751
2752+ // ===----------------------------------------------------------------------===//
2753+ // KernelEnvironmentOp
2754+ // ===----------------------------------------------------------------------===//
2755+
2756+ void acc::KernelEnvironmentOp::getCanonicalizationPatterns (
2757+ RewritePatternSet &results, MLIRContext *context) {
2758+ results.add <RemoveEmptyKernelEnvironment>(context);
2759+ }
2760+
26932761// ===----------------------------------------------------------------------===//
26942762// LoopOp
26952763// ===----------------------------------------------------------------------===//
0 commit comments