@@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
891891 }
892892};
893893
894- // / A simple pattern that tests the undo mechanism when replacing the uses of a
895- // / block argument.
896- struct TestUndoBlockArgReplace : public ConversionPattern {
897- TestUndoBlockArgReplace (MLIRContext *ctx)
898- : ConversionPattern( " test.undo_block_arg_replace " , /* benefit= */ 1 , ctx) {}
894+ // / A simple pattern that tests the "replaceUsesOfBlockArgument" API.
895+ struct TestBlockArgReplace : public ConversionPattern {
896+ TestBlockArgReplace (MLIRContext *ctx, const TypeConverter &converter)
897+ : ConversionPattern(converter, " test.block_arg_replace " , /* benefit= */ 1 ,
898+ ctx) {}
899899
900900 LogicalResult
901901 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
902902 ConversionPatternRewriter &rewriter) const final {
903- auto illegalOp =
904- rewriter. create <ILLegalOpF>( op->getLoc (), rewriter. getF32Type () );
903+ // Replace the first block argument with 2x the second block argument.
904+ Value repl = op->getRegion ( 0 ). getArgument ( 1 );
905905 rewriter.replaceUsesOfBlockArgument (op->getRegion (0 ).getArgument (0 ),
906- illegalOp->getResult (0 ));
907- rewriter.modifyOpInPlace (op, [] {});
906+ {repl, repl});
907+ rewriter.modifyOpInPlace (op, [&] {
908+ // If the "trigger_rollback" attribute is set, keep the op illegal, so
909+ // that a rollback is triggered.
910+ if (!op->hasAttr (" trigger_rollback" ))
911+ op->setAttr (" is_legal" , rewriter.getUnitAttr ());
912+ });
908913 return success ();
909914 }
910915};
@@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
13751380 TestTypeConverter converter;
13761381 mlir::RewritePatternSet patterns (&getContext ());
13771382 populateWithGenerated (patterns);
1378- patterns
1379- .add <TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1380- TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1381- TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1382- TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1383- TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1384- TestNonRootReplacement, TestBoundedRecursiveRewrite,
1385- TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1386- TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1387- TestUndoPropertiesModification, TestEraseOp,
1388- TestRepetitive1ToNConsumer>(&getContext ());
1383+ patterns.add <
1384+ TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1385+ TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1386+ TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1387+ TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1388+ TestUpdateConsumerType, TestNonRootReplacement,
1389+ TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1390+ TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1391+ TestUndoPropertiesModification, TestEraseOp,
1392+ TestRepetitive1ToNConsumer>(&getContext ());
13891393 patterns.add <TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1390- TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
1391- &getContext (), converter);
1394+ TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
1395+ TestBlockArgReplace>( &getContext (), converter);
13921396 patterns.add <TestConvertBlockArgs>(converter, &getContext ());
13931397 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern (patterns,
13941398 converter);
@@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
14131417 });
14141418 target.addDynamicallyLegalOp <func::CallOp>(
14151419 [&](func::CallOp op) { return converter.isLegal (op); });
1420+ target.addDynamicallyLegalOp (
1421+ OperationName (" test.block_arg_replace" , &getContext ()),
1422+ [](Operation *op) { return op->hasAttr (" is_legal" ); });
14161423
14171424 // TestCreateUnregisteredOp creates `arith.constant` operation,
14181425 // which was not added to target intentionally to test
0 commit comments