1515#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1616#include " mlir/Dialect/MemRef/IR/MemRef.h"
1717#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18+ #include " mlir/Dialect/SCF/IR/SCF.h"
1819#include " mlir/Interfaces/RuntimeVerifiableOpInterface.h"
1920
2021using namespace mlir ;
@@ -296,37 +297,11 @@ struct SubViewOpInterface
296297 Value sizeIsNonZero = arith::CmpIOp::create (
297298 builder, loc, arith::CmpIPredicate::sgt, size, zero);
298299
299- /*
300- * Split the current block to create the below control flow structure:
301- *
302- * ^preCondBlock:
303- * ... // offset check already done above
304- * %size_nonzero = arith.cmpi sgt, %size, %zero
305- * cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
306- *
307- * ^sizeBoundsCheckBlock:
308- * %last_pos = ... // compute offset + (size-1) * stride
309- * %last_pos_ok = ... // last position bounds check
310- * cf.assert %last_pos_ok, "extract_slice runs out-of-bounds"
311- * cf.br ^afterCheckBlock
312- *
313- * ^afterCheckBlock:
314- * tensor.extract_slice ... // the original operation
315- */
316- Block *preCondBlock = builder.getBlock ();
317- Block *afterCheckBlock = preCondBlock->splitBlock (subView);
318-
319- // Create the block for conditional size bounds verification.
320- Block *sizeBoundsCheckBlock = builder.createBlock (
321- preCondBlock->getParent (), Region::iterator (afterCheckBlock));
322-
323- // Terminate the pre-condition block with the conditional branch.
324- builder.setInsertionPointToEnd (preCondBlock);
325- cf::CondBranchOp::create (builder, loc, sizeIsNonZero,
326- sizeBoundsCheckBlock, afterCheckBlock);
327-
328- // Populate the size bounds check block with lastPos verification.
329- builder.setInsertionPointToStart (sizeBoundsCheckBlock);
300+ auto ifOp = scf::IfOp::create (builder, loc, builder.getI1Type (),
301+ sizeIsNonZero, /* withElseRegion=*/ true );
302+
303+ // Populate the "then" region (for size > 0).
304+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
330305
331306 // Verify that slice does not run out-of-bounds.
332307 Value sizeMinusOne = arith::SubIOp::create (builder, loc, size, one);
@@ -336,12 +311,23 @@ struct SubViewOpInterface
336311 arith::AddIOp::create (builder, loc, offset, sizeMinusOneTimesStride);
337312 Value lastPosInBounds =
338313 generateInBoundsCheck (builder, loc, lastPos, zero, dimSize);
314+
315+ scf::YieldOp::create (builder, loc, lastPosInBounds);
316+
317+ // Populate the "else" region (for size == 0).
318+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
319+ Value trueVal =
320+ arith::ConstantOp::create (builder, loc, builder.getBoolAttr (true ));
321+ scf::YieldOp::create (builder, loc, trueVal);
322+
323+ builder.setInsertionPointAfter (ifOp);
324+ Value finalCondition = ifOp.getResult (0 );
325+
339326 cf::AssertOp::create (
340- builder, loc, lastPosInBounds ,
327+ builder, loc, finalCondition ,
341328 generateErrorMessage (op,
342329 " subview runs out-of-bounds along dimension " +
343330 std::to_string (i)));
344- cf::BranchOp::create (builder, loc, afterCheckBlock);
345331 }
346332 }
347333};
0 commit comments