@@ -235,7 +235,8 @@ LogicalResult MaskState::minStateScalar(const MaskState &lhsState,
235235 }
236236 } else {
237237 InFlightDiagnostic diag =
238- emitError (loc) << " Unexpected case where both lhs and rhs are not scalars" ;
238+ emitError (loc)
239+ << " Unexpected case where both lhs and rhs are not scalars" ;
239240 return failure ();
240241 }
241242 return success ();
@@ -329,7 +330,7 @@ LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc,
329330 if (failed (rhsState.parse (andOp.getRhs (), loc, builder)))
330331 return failure ();
331332
332- if (!lhsState.isMask () || !rhsState.isMask ()) {
333+ if (!lhsState.isMask () || !rhsState.isMask ()) {
333334 return this ->minStateScalar (lhsState, rhsState, loc, builder);
334335 }
335336 return this ->minStates (lhsState, rhsState, loc, builder);
@@ -363,8 +364,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
363364 // We only support sge against 0 for lower bounds. Dims already has an
364365 // implicit assumption that the lower bound is 0, so if we see this, assume
365366 // the comparison evaluates to true.
366- if (cmpOp.getPredicate () == arith::CmpIPredicate::sge
367- && !(rhsState.scalar && hasConstZero (rhsState.scalar ))) {
367+ if (cmpOp.getPredicate () == arith::CmpIPredicate::sge &&
368+ !(rhsState.scalar && hasConstZero (rhsState.scalar ))) {
368369 InFlightDiagnostic diag = emitError (loc)
369370 << " Unsupported cmpi with rhs not equal to 0" ;
370371 return failure ();
@@ -383,8 +384,11 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
383384 cmpDim = i;
384385 }
385386 }
386- assert (cmpDim != -1 &&
387- " Unexpected case where no dimension has size larger than 1" );
387+ assert (
388+ cmpDim != -1 ||
389+ (!lhsState.scalar && cmpOp.getPredicate () == arith::CmpIPredicate::slt ||
390+ cmpOp.getPredicate () == arith::CmpIPredicate::ult) &&
391+ " Unexpected case where no dimension has size larger than 1" );
388392
389393 OpFoldResult newDim;
390394 if (lhsState.scalar ) {
@@ -397,10 +401,10 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
397401 // should be loaded/stored by inserting a comparison + select:
398402 // dim = lhs < rhs ? lhs.dim : 0
399403 newDim = compareOFRs (lhsState.scalar , rhsState.scalar , cmpOp.getPredicate (),
400- lhsState.dims [cmpDim], builder.getIndexAttr (0 ),
401- loc, builder);
404+ lhsState.dims [cmpDim], builder.getIndexAttr (0 ), loc ,
405+ builder);
402406 } else if (cmpOp.getPredicate () == arith::CmpIPredicate::slt ||
403- cmpOp.getPredicate () == arith::CmpIPredicate::ult) {
407+ cmpOp.getPredicate () == arith::CmpIPredicate::ult) {
404408 // Important:
405409 // In the case where the values we are loading are entirely masked off like
406410 // the following:
@@ -418,8 +422,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
418422 newEnd = maxOFRs (newEnd, lhsState.start , loc, builder);
419423 newDim = subOFRs (newEnd, lhsState.start , loc, builder);
420424 } else {
421- assert (cmpOp.getPredicate () == arith::CmpIPredicate::sge && rhsState. scalar
422- && hasConstZero (rhsState.scalar ));
425+ assert (cmpOp.getPredicate () == arith::CmpIPredicate::sge &&
426+ rhsState. scalar && hasConstZero (rhsState.scalar ));
423427 newDim = lhsState.dims [cmpDim];
424428 }
425429
@@ -507,6 +511,12 @@ LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc,
507511 }
508512 }
509513
514+ if (!lhsState.start && !lhsState.end ) {
515+ assert (lhsState.scalar && " MaskState must have a scalar" );
516+ lhsState.start = builder.getIndexAttr (0 );
517+ lhsState.end = lhsState.scalar ;
518+ }
519+
510520 auto dist = subOFRs (lhsState.end , lhsState.start , loc, builder);
511521 this ->start = forOp.getRegionIterArg (argIndex + 1 );
512522 this ->end = addOFRs (this ->start , dist, loc, builder);
0 commit comments