@@ -320,6 +320,11 @@ void MaskState::dump() const {
320320 llvm::dbgs () << " dims: " ;
321321 for (auto dim : dims)
322322 llvm::dbgs () << " \t " << dim << " \n " ;
323+ if (!masks.empty ()) {
324+ llvm::dbgs () << " masks: " ;
325+ for (auto mask : masks)
326+ llvm::dbgs () << " \t " << mask << " \n " ;
327+ }
323328 llvm::dbgs () << " \n " ;
324329}
325330
@@ -341,14 +346,96 @@ LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc,
341346LogicalResult MaskState::parseAnd (arith::AndIOp andOp, const Location loc,
342347 OpBuilder &builder) {
343348 assert (this ->isEmpty ());
344-
349+ bool isBoolOp = false ;
350+ unsigned rank = 1 ;
351+ if (auto shapedType = dyn_cast<ShapedType>(andOp.getType ())) {
352+ isBoolOp = shapedType.getElementType ().isInteger (1 );
353+ rank = shapedType.getRank ();
354+ }
345355 MaskState lhsState;
346- if (failed (lhsState.parse (andOp.getLhs (), loc, builder)))
356+ LogicalResult lResult = lhsState.parse (andOp.getLhs (), loc, builder);
357+ if (failed (lResult) && !isBoolOp) {
347358 return failure ();
359+ }
348360
349361 MaskState rhsState;
350- if (failed (rhsState.parse (andOp.getRhs (), loc, builder)))
362+ LogicalResult rResult = rhsState.parse (andOp.getRhs (), loc, builder);
363+ if (failed (rResult) && !isBoolOp) {
351364 return failure ();
365+ }
366+
367+ if (isBoolOp) {
368+ if (lhsState.masks .size () != rank) {
369+ return failure ();
370+ }
371+
372+ if (lhsState.masks .size () != rhsState.masks .size ()) {
373+ return failure ();
374+ }
375+
376+ // merge the masks.
377+ if (lhsState.masks .size () == rhsState.masks .size ()) {
378+ auto shapedType = cast<ShapedType>(andOp.getType ());
379+ assert (shapedType.hasStaticShape ());
380+ for (size_t i = 0 ; i < lhsState.masks .size (); i++) {
381+ Value lhsV = lhsState.masks [i];
382+ Value rhsV = rhsState.masks [i];
383+ if (!lhsV && !rhsV) {
384+ masks.push_back (nullptr );
385+ } else {
386+ uint32_t size = shapedType.getShape ()[i];
387+ auto structuredMaskToUnstructuredMask = [](MaskState state,
388+ unsigned dim,
389+ uint32_t size,
390+ OpBuilder &builder,
391+ Location loc) {
392+ OpFoldResult ofr = state.isMask () ? state.dims [dim] : state.scalar ;
393+ if (auto intV = getIntAttr (ofr)) {
394+ if (intV == size) {
395+ // Full mask.
396+ return Value ();
397+ }
398+ }
399+ auto targetTensorType =
400+ RankedTensorType::get ({size}, builder.getI32Type ());
401+ Value range =
402+ builder
403+ .create <triton::MakeRangeOp>(loc, targetTensorType, 0 , size)
404+ .getResult ();
405+ Value v = ofrToIndexValue (ofr, loc, builder);
406+ v = builder
407+ .create <arith::IndexCastUIOp>(loc, builder.getI32Type (), v)
408+ .getResult ();
409+ v = builder.create <triton::SplatOp>(loc, targetTensorType, v)
410+ .getResult ();
411+ return builder
412+ .create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, range, v)
413+ .getResult ();
414+ };
415+ if (!lhsV) {
416+ lhsV = structuredMaskToUnstructuredMask (lhsState, i, size, builder,
417+ loc);
418+ } else if (!rhsV) {
419+ rhsV = structuredMaskToUnstructuredMask (rhsState, i, size, builder,
420+ loc);
421+ }
422+ if (!lhsV) {
423+ masks.push_back (rhsV);
424+ continue ;
425+ } else if (!rhsV) {
426+ masks.push_back (lhsV);
427+ continue ;
428+ }
429+ // And the mask.
430+ masks.push_back (builder.create <arith::AndIOp>(loc, lhsV, rhsV));
431+ }
432+ }
433+ // Only support one unstructured mask.
434+ if (getUnstructuredMasks ().size () > 1 ) {
435+ return failure ();
436+ }
437+ }
438+ }
352439
353440 if (!lhsState.isMask () || !rhsState.isMask ()) {
354441 return this ->minStateScalar (lhsState, rhsState, loc, builder);
@@ -365,7 +452,48 @@ LogicalResult MaskState::parseExtSI(arith::ExtSIOp op, const Location loc,
365452LogicalResult MaskState::parseCmp (arith::CmpIOp cmpOp, const Location loc,
366453 OpBuilder &builder) {
367454 assert (this ->isEmpty ());
368-
455+ int cmpOpDim = -1 ;
456+ if (auto shapedType = dyn_cast<ShapedType>(cmpOp.getType ())) {
457+ for (unsigned r = 0 ; r < shapedType.getRank (); r++) {
458+ if (shapedType.getShape ()[r] != 1 ) {
459+ if (cmpOpDim != -1 ) {
460+ // This will happen when the cmp has more than one dimension with size
461+ // larger than 1.
462+ // Like a < b while both a and b are tensors with shape 2x2.
463+ cmpOpDim = -1 ;
464+ break ;
465+ }
466+ cmpOpDim = r;
467+ }
468+ }
469+ masks.clear ();
470+ for (unsigned r = 0 ; r < shapedType.getRank (); r++) {
471+ masks.push_back (nullptr );
472+ }
473+ // If cmpOpDim == -1, parseCmp must fail later.
474+ // Here just setup unstructured masks when cmpOpDim != -1.
475+ if (cmpOpDim != -1 ) {
476+ // Save cmpOp as unstructured mask for failure case, will recover it to
477+ // nullptr later if success.
478+ Value unstructuredMask = cmpOp;
479+ if (shapedType.getRank () > 1 ) {
480+ // If cmpOp is not 1D, collapse it to 1D.
481+ auto flatType = RankedTensorType::get ({shapedType.getShape ()[cmpOpDim]},
482+ shapedType.getElementType ());
483+ auto maybeReassociationMap =
484+ getReassociationIndicesForReshape (shapedType, flatType);
485+ SmallVector<ReassociationIndices> reassociation =
486+ *maybeReassociationMap;
487+ // Set masks.
488+ unstructuredMask = builder.create <tensor::CollapseShapeOp>(
489+ loc, flatType, cmpOp, reassociation);
490+ }
491+ masks[cmpOpDim] = unstructuredMask;
492+ }
493+ } else {
494+ cmpOpDim = 0 ;
495+ masks.push_back (cmpOp);
496+ }
369497 if (cmpOp.getPredicate () != arith::CmpIPredicate::slt &&
370498 cmpOp.getPredicate () != arith::CmpIPredicate::ult &&
371499 cmpOp.getPredicate () != arith::CmpIPredicate::sge) {
@@ -453,7 +581,10 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
453581 else
454582 this ->dims .push_back (lhsState.dims [i]);
455583 }
456-
584+ if (cmpOpDim != -1 ) {
585+ // Clear masks when success.
586+ masks[cmpOpDim] = nullptr ;
587+ }
457588 return success ();
458589}
459590
@@ -623,7 +754,15 @@ LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc,
623754
624755 for (auto s : dstShape)
625756 this ->dims .push_back (builder.getIndexAttr (s));
626-
757+ bool isBool = src.getType ().isInteger (1 );
758+ if (isBool) {
759+ // If src is a 1D boolean tensor and parse success.
760+ // Create masks.
761+ masks.clear ();
762+ for (unsigned i = 0 ; i < dstShape.size (); i++) {
763+ masks.push_back (nullptr );
764+ }
765+ }
627766 return success ();
628767}
629768
@@ -632,18 +771,76 @@ LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp,
632771 OpBuilder &builder) {
633772 assert (this ->isEmpty ());
634773
635- if (failed (this ->parse (expandDimsOp.getSrc (), loc, builder)))
636- return failure ();
637-
638774 auto dstShape =
639775 cast<ShapedType>(expandDimsOp.getResult ().getType ()).getShape ();
640776 auto axis = expandDimsOp.getAxis ();
777+ Value src = expandDimsOp.getSrc ();
778+ auto srcType = cast<ShapedType>(src.getType ());
779+ bool isBoolOp = srcType.getElementType ().isInteger (1 );
780+ LogicalResult result = parse (src, loc, builder);
781+ if (failed (result)) {
782+ if (isBoolOp) {
783+ if (srcType.getRank () > 1 && masks.size () != srcType.getRank ()) {
784+ return failure ();
785+ }
786+ } else {
787+ return failure ();
788+ }
789+ }
790+
791+ if (isBoolOp) {
792+ // Save mask for 1D boolean tensor
793+ if (srcType.getRank () == 1 ) {
794+ assert (dstShape.size () == 2 );
795+ masks.resize (dstShape.size ());
796+ masks[axis] = nullptr ;
797+ if (failed (result)) {
798+ // Recover dims to allow other dim to be processed.
799+ dims.clear ();
800+ dims.push_back (builder.getIndexAttr (srcType.getShape ()[0 ]));
801+ // Save src as unstructured mask.
802+ masks[1 - axis] = src;
803+ } else {
804+ // save nullptr when parse success.
805+ masks[1 - axis] = nullptr ;
806+ }
807+ } else {
808+ if (failed (result)) {
809+ auto unstructuredMasks = getUnstructuredMasks ();
810+ if (unstructuredMasks.empty ()) {
811+ return failure ();
812+ }
813+ if (unstructuredMasks.size () > 1 ) {
814+ return failure ();
815+ }
816+ auto [dim, mask] = unstructuredMasks.front ();
817+ // Recover dims for unstructured mask dim to allow other dim to be
818+ // processed.
819+ dims[dim] = builder.getIndexAttr (srcType.getShape ()[dim]);
820+ }
821+ masks.insert (masks.begin () + axis, nullptr );
822+ }
823+ }
824+
641825 assert (dstShape[axis] == 1 &&
642826 " expect changed dimension to be 1 in expand_dims" );
643827 this ->dims .insert (this ->dims .begin () + axis, builder.getIndexAttr (1 ));
644828
645829 return success ();
646830}
647831
832+ // Return all non-nullptr masks along with their dimensions.
833+ SmallVector<std::pair<unsigned , Value>> MaskState::getUnstructuredMasks () {
834+ SmallVector<std::pair<unsigned , Value>> result;
835+
836+ for (auto [i, m] : llvm::enumerate (masks)) {
837+ if (m) {
838+ result.push_back ({i, m});
839+ }
840+ }
841+
842+ return result;
843+ }
844+
648845} // namespace triton
649846} // namespace mlir
0 commit comments