@@ -155,13 +155,15 @@ struct ExtractSliceOpInterface
155155 RankedTensorType sourceType = extractSliceOp.getSource ().getType ();
156156
157157 // For each dimension, assert that:
158- // 0 <= offset < dim_size
159- // 0 <= offset + (size - 1) * stride < dim_size
158+ // For empty slices (size == 0) : 0 <= offset <= dim_size
159+ // For non-empty slices (size > 0): 0 <= offset < dim_size
160+ // 0 <= offset + (size - 1) * stride <
161+ // dim_size
160162 Value zero = arith::ConstantIndexOp::create (builder, loc, 0 );
161163 Value one = arith::ConstantIndexOp::create (builder, loc, 1 );
162164
163165 for (int64_t i : llvm::seq<int64_t >(0 , sourceType.getRank ())) {
164- // Reset insertion point to before the operation for each dimension
166+
165167 builder.setInsertionPoint (extractSliceOp);
166168
167169 Value offset = getValueOrCreateConstantIndexOp (
@@ -170,46 +172,63 @@ struct ExtractSliceOpInterface
170172 builder, loc, extractSliceOp.getMixedSizes ()[i]);
171173 Value stride = getValueOrCreateConstantIndexOp (
172174 builder, loc, extractSliceOp.getMixedStrides ()[i]);
173-
174- // Verify that offset is in-bounds.
175175 Value dimSize = builder.createOrFold <tensor::DimOp>(
176176 loc, extractSliceOp.getSource (), i);
177- Value offsetInBounds =
178- generateInBoundsCheck (builder, loc, offset, zero, dimSize);
179- cf::AssertOp::create (builder, loc, offsetInBounds,
177+
178+ // Verify that offset is in-bounds (conditional on slice size).
179+ Value sizeIsZero = arith::CmpIOp::create (
180+ builder, loc, arith::CmpIPredicate::eq, size, zero);
181+ auto offsetCheckIf = scf::IfOp::create (
182+ builder, loc, sizeIsZero,
183+ [&](OpBuilder &b, Location loc) {
184+ // For empty slices, offset can be at the boundary: 0 <= offset <=
185+ // dimSize.
186+ Value offsetGEZero = arith::CmpIOp::create (
187+ b, loc, arith::CmpIPredicate::sge, offset, zero);
188+ Value offsetLEDimSize = arith::CmpIOp::create (
189+ b, loc, arith::CmpIPredicate::sle, offset, dimSize);
190+ Value emptyOffsetValid =
191+ arith::AndIOp::create (b, loc, offsetGEZero, offsetLEDimSize);
192+ scf::YieldOp::create (b, loc, emptyOffsetValid);
193+ },
194+ [&](OpBuilder &b, Location loc) {
195+ // For non-empty slices, offset must be a valid index: 0 <= offset <
196+ // dimSize.
197+ Value offsetInBounds =
198+ generateInBoundsCheck (b, loc, offset, zero, dimSize);
199+ scf::YieldOp::create (b, loc, offsetInBounds);
200+ });
201+
202+ Value offsetCondition = offsetCheckIf.getResult (0 );
203+ cf::AssertOp::create (builder, loc, offsetCondition,
180204 generateErrorMessage (op, " offset " +
181205 std::to_string (i) +
182206 " is out-of-bounds" ));
183207
184- // Only verify if size > 0
208+ // Verify that the slice endpoint is in-bounds (only for non-empty
209+ // slices).
185210 Value sizeIsNonZero = arith::CmpIOp::create (
186211 builder, loc, arith::CmpIPredicate::sgt, size, zero);
212+ auto ifOp = scf::IfOp::create (
213+ builder, loc, sizeIsNonZero,
214+ [&](OpBuilder &b, Location loc) {
215+ // Verify that slice does not run out-of-bounds.
216+ Value sizeMinusOne = arith::SubIOp::create (b, loc, size, one);
217+ Value sizeMinusOneTimesStride =
218+ arith::MulIOp::create (b, loc, sizeMinusOne, stride);
219+ Value lastPos =
220+ arith::AddIOp::create (b, loc, offset, sizeMinusOneTimesStride);
221+ Value lastPosInBounds =
222+ generateInBoundsCheck (b, loc, lastPos, zero, dimSize);
223+ scf::YieldOp::create (b, loc, lastPosInBounds);
224+ },
225+ [&](OpBuilder &b, Location loc) {
226+ Value trueVal =
227+ arith::ConstantOp::create (b, loc, b.getBoolAttr (true ));
228+ scf::YieldOp::create (b, loc, trueVal);
229+ });
187230
188- auto ifOp = scf::IfOp::create (builder, loc, builder.getI1Type (),
189- sizeIsNonZero, /* withElseRegion=*/ true );
190-
191- // Populate the "then" region (for size > 0).
192- builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
193-
194- // Verify that slice does not run out-of-bounds.
195- Value sizeMinusOne = arith::SubIOp::create (builder, loc, size, one);
196- Value sizeMinusOneTimesStride =
197- arith::MulIOp::create (builder, loc, sizeMinusOne, stride);
198- Value lastPos =
199- arith::AddIOp::create (builder, loc, offset, sizeMinusOneTimesStride);
200- Value lastPosInBounds =
201- generateInBoundsCheck (builder, loc, lastPos, zero, dimSize);
202- scf::YieldOp::create (builder, loc, lastPosInBounds);
203-
204- // Populate the "else" region (for size == 0).
205- builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
206- Value trueVal =
207- arith::ConstantOp::create (builder, loc, builder.getBoolAttr (true ));
208- scf::YieldOp::create (builder, loc, trueVal);
209-
210- builder.setInsertionPointAfter (ifOp);
211231 Value finalCondition = ifOp.getResult (0 );
212-
213232 cf::AssertOp::create (
214233 builder, loc, finalCondition,
215234 generateErrorMessage (
0 commit comments