
Commit 2b728ad

python3kgae and Xiang Li authored
Add support for unstructured masks in MaskAnalysis (#333)
- Enhanced `MaskAnalysis` to handle unstructured masks. MaskState will save unstructured mask for a dimension if that dimension failed MaskAnalysis. - Updated `TritonStructuredDialect` to include `gather_scatter_mask` in `MakeGatherScatterTensorPtrOp`. - Modified `PtrAnalysis` to apply unstructured masks during pointer analysis. - Extended `StructuredToMemref` conversion to support generic masks in load/store operations. - Added verification logic for `MakeGatherScatterTensorPtrOp` to ensure compatibility with masks. --------- Co-authored-by: Xiang Li <[email protected]>
1 parent 9a32209 commit 2b728ad

File tree

13 files changed: +1523 -17 lines

include/triton-shared/Analysis/MaskAnalysis.h

Lines changed: 19 additions & 0 deletions
@@ -44,17 +44,36 @@ namespace triton {
 //
 // Example of creating 2D mask:
 //   mask = (rows[:, None] < M) & (cols[None, :] < N)
+//
+// A bool tensor mask can be saved into masks when a dimension fails
+// MaskAnalysis. This allows the case where only one dimension fails while the
+// others pass; a MakeGatherScatterTensorPtrOp operation can then be generated
+// for the failed dimension. Only 3 patterns are supported for this:
+// 1. offsets[:, None] < n, where offsets is a 1D tensor.
+//    It comes in the pattern expandDims -> broadcast -> cmp.
+// 2. mask[:, None], where mask is a 1D bool tensor.
+//    It comes in the pattern cmp -> expandDims -> broadcast.
+// 3. scalar_mask[:, None], where scalar_mask is a scalar bool.
+//    It comes in the pattern splat -> expandDims -> broadcast.
+// These 3 patterns only describe how the bool tensor is built from a 1D or
+// scalar bool; how the 1D or scalar bool itself was created does not matter
+// for the unstructured mask.
+// Only one tensor mask is allowed. If multiple dimensions have failed
+// MaskAnalysis, then MaskAnalysis will still fail on the current operation.
 struct MaskState {
   OpFoldResult start;
   OpFoldResult end;
   SmallVector<OpFoldResult> dims;
+  SmallVector<Value> masks;
   OpFoldResult scalar;
   const bool useUnsafeMask;
 
   void dump() const;
 
   MaskState(bool useUnsafeMask = false) : useUnsafeMask(useUnsafeMask) {}
 
+  SmallVector<std::pair<unsigned, Value>> getUnstructuredMasks();
+
   int64_t getRank() const { return dims.size(); }
 
   bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; }
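The doc comment added above boils down to one invariant: `masks` holds one entry per dimension, `nullptr` for dimensions whose bounds MaskAnalysis recovered (those live in `dims`), and the boolean tensor for the single dimension that failed; more than one failed dimension makes the whole analysis fail. The following is a minimal, self-contained sketch of that bookkeeping for illustration only; it uses `std::optional<std::string>` as a stand-in for `mlir::Value` and is not the repository's implementation.

```cpp
#include <cstdio>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Stand-in for mlir::Value: nullopt means "dimension passed MaskAnalysis"
// (its structured bound lives in dims); a string names the unstructured mask.
using FakeMask = std::optional<std::string>;

// Mirrors MaskState::getUnstructuredMasks(): collect (dim, mask) pairs for
// every dimension that carries an unstructured mask.
static std::vector<std::pair<unsigned, std::string>>
getUnstructuredMasks(const std::vector<FakeMask> &masks) {
  std::vector<std::pair<unsigned, std::string>> result;
  for (unsigned i = 0; i < masks.size(); ++i)
    if (masks[i])
      result.push_back({i, *masks[i]});
  return result;
}

int main() {
  // 2D load mask where rows[:, None] < M passed MaskAnalysis but the column
  // predicate failed, so dimension 1 keeps a boolean tensor as its mask.
  std::vector<FakeMask> masks = {std::nullopt, std::string("col_mask")};

  auto unstructured = getUnstructuredMasks(masks);
  if (unstructured.size() > 1) {
    std::puts("more than one unstructured mask: MaskAnalysis fails");
    return 1;
  }
  for (auto &[dim, name] : unstructured)
    std::printf("dim %u uses unstructured mask %s\n", dim, name.c_str());
  return 0;
}
```

The real `getUnstructuredMasks()` in `MaskAnalysis.cpp` (shown further below) performs the same scan over `SmallVector<Value> masks`.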

include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td

Lines changed: 16 additions & 5 deletions
@@ -159,6 +159,7 @@ def TTS_MakeGatherScatterTensorPtrOp
   // strides: The strides of the parent tensor, which means how much to increase the pointer
   //          by when moving by 1 element in a specific axis.
   // offsets: Offset of the block along each dimension from base.
+  // gather_scatter_mask: Optional bool mask for a dimension whose mask failed MaskAnalysis.
   // result: A tensor of pointers.
 
   let arguments = (ins TT_Ptr:$base,
@@ -168,19 +169,21 @@ def TTS_MakeGatherScatterTensorPtrOp
                        Variadic<Index>:$strides,
                        Variadic<Index>:$offsets,
                        DenseI64ArrayAttr:$static_strides,
-                       DenseI64ArrayAttr:$static_offsets);
+                       DenseI64ArrayAttr:$static_offsets,
+                       Optional<TT_BoolLike>:$gather_scatter_mask);
 
   let results = (outs TT_PtrLike:$result);
 
   let assemblyFormat = [{
     $base `to` `sizes` `` `:` $sizes
     `gather_scatter_dim` `` `:` $gather_scatter_dim
     `gather_scatter_offset` `` `:` $gather_scatter_offset
+    (`gather_scatter_mask` `` `:` $gather_scatter_mask^)?
     `` `,` `strides` `` `:`
     custom<DynamicIndexList>($strides, $static_strides)
     `` `,` `offsets` `` `:`
     custom<DynamicIndexList>($offsets, $static_offsets)
-    attr-dict `:` type($gather_scatter_offset) type($base) `to` type($result)
+    attr-dict `:` type($gather_scatter_offset) type($gather_scatter_mask) type($base) `to` type($result)
   }];
 
 
@@ -193,6 +196,15 @@ def TTS_MakeGatherScatterTensorPtrOp
       "ArrayRef<int64_t>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       "ArrayRef<OpFoldResult>":$offsets)>,
+
+    OpBuilder<(ins
+      "Value":$base,
+      "Value":$gather_scatter_offset,
+      "Value":$gather_scatter_mask,
+      "int":$gather_scatter_dim,
+      "ArrayRef<int64_t>":$sizes,
+      "ArrayRef<OpFoldResult>":$strides,
+      "ArrayRef<OpFoldResult>":$offsets)>,
   ];
 
   let extraClassDeclaration = [{
@@ -213,9 +225,8 @@ def TTS_MakeGatherScatterTensorPtrOp
     }
   }];
 
-  // TODO
-  //let hasVerifier = 1;
-  //let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let hasCanonicalizer = 0;
 }
 
 def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> {
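For orientation, here is a hedged sketch of how lowering code might invoke the new builder overload declared above (base, gather_scatter_offset, gather_scatter_mask, gather_scatter_dim, sizes, strides, offsets). Everything about it is illustrative: the helper name and variable names are hypothetical, it assumes the op class is exposed as `tts::MakeGatherScatterTensorPtrOp` like the other TritonStructured ops, and it is meant to run inside a pass where the dialect headers and an `OpBuilder` are already in scope.

```cpp
// Hypothetical helper in PtrAnalysis-style lowering code; names are
// illustrative, not taken from the repository.
mlir::Value buildGatherScatterPtr(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::Value base, mlir::Value offsetTensor,
                                  mlir::Value boolMask, int gatherDim,
                                  llvm::ArrayRef<int64_t> sizes,
                                  llvm::ArrayRef<mlir::OpFoldResult> strides,
                                  llvm::ArrayRef<mlir::OpFoldResult> offsets) {
  // Uses the new OpBuilder overload added in this commit; boolMask is the
  // unstructured mask recovered by MaskAnalysis for the gather/scatter dim.
  auto ptrOp = builder.create<mlir::tts::MakeGatherScatterTensorPtrOp>(
      loc, base, offsetTensor, boolMask, gatherDim, sizes, strides, offsets);
  return ptrOp.getResult();
}
```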

lib/Analysis/MaskAnalysis.cpp

Lines changed: 206 additions & 9 deletions
@@ -320,6 +320,11 @@ void MaskState::dump() const {
   llvm::dbgs() << "dims: ";
   for (auto dim : dims)
     llvm::dbgs() << "\t" << dim << "\n";
+  if (!masks.empty()) {
+    llvm::dbgs() << "masks: ";
+    for (auto mask : masks)
+      llvm::dbgs() << "\t" << mask << "\n";
+  }
   llvm::dbgs() << "\n";
 }
 
@@ -341,14 +346,96 @@ LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc,
 LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc,
                                   OpBuilder &builder) {
   assert(this->isEmpty());
-
+  bool isBoolOp = false;
+  unsigned rank = 1;
+  if (auto shapedType = dyn_cast<ShapedType>(andOp.getType())) {
+    isBoolOp = shapedType.getElementType().isInteger(1);
+    rank = shapedType.getRank();
+  }
   MaskState lhsState;
-  if (failed(lhsState.parse(andOp.getLhs(), loc, builder)))
+  LogicalResult lResult = lhsState.parse(andOp.getLhs(), loc, builder);
+  if (failed(lResult) && !isBoolOp) {
     return failure();
+  }
 
   MaskState rhsState;
-  if (failed(rhsState.parse(andOp.getRhs(), loc, builder)))
+  LogicalResult rResult = rhsState.parse(andOp.getRhs(), loc, builder);
+  if (failed(rResult) && !isBoolOp) {
     return failure();
+  }
+
+  if (isBoolOp) {
+    if (lhsState.masks.size() != rank) {
+      return failure();
+    }
+
+    if (lhsState.masks.size() != rhsState.masks.size()) {
+      return failure();
+    }
+
+    // Merge the masks.
+    if (lhsState.masks.size() == rhsState.masks.size()) {
+      auto shapedType = cast<ShapedType>(andOp.getType());
+      assert(shapedType.hasStaticShape());
+      for (size_t i = 0; i < lhsState.masks.size(); i++) {
+        Value lhsV = lhsState.masks[i];
+        Value rhsV = rhsState.masks[i];
+        if (!lhsV && !rhsV) {
+          masks.push_back(nullptr);
+        } else {
+          uint32_t size = shapedType.getShape()[i];
+          auto structuredMaskToUnstructuredMask = [](MaskState state,
+                                                     unsigned dim,
+                                                     uint32_t size,
+                                                     OpBuilder &builder,
+                                                     Location loc) {
+            OpFoldResult ofr = state.isMask() ? state.dims[dim] : state.scalar;
+            if (auto intV = getIntAttr(ofr)) {
+              if (intV == size) {
+                // Full mask.
+                return Value();
+              }
+            }
+            auto targetTensorType =
+                RankedTensorType::get({size}, builder.getI32Type());
+            Value range =
+                builder
+                    .create<triton::MakeRangeOp>(loc, targetTensorType, 0, size)
+                    .getResult();
+            Value v = ofrToIndexValue(ofr, loc, builder);
+            v = builder
+                    .create<arith::IndexCastUIOp>(loc, builder.getI32Type(), v)
+                    .getResult();
+            v = builder.create<triton::SplatOp>(loc, targetTensorType, v)
+                    .getResult();
+            return builder
+                .create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, range, v)
+                .getResult();
+          };
+          if (!lhsV) {
+            lhsV = structuredMaskToUnstructuredMask(lhsState, i, size, builder,
+                                                    loc);
+          } else if (!rhsV) {
+            rhsV = structuredMaskToUnstructuredMask(rhsState, i, size, builder,
+                                                    loc);
+          }
+          if (!lhsV) {
+            masks.push_back(rhsV);
+            continue;
+          } else if (!rhsV) {
+            masks.push_back(lhsV);
+            continue;
+          }
+          // And the masks together.
+          masks.push_back(builder.create<arith::AndIOp>(loc, lhsV, rhsV));
+        }
+      }
+      // Only support one unstructured mask.
+      if (getUnstructuredMasks().size() > 1) {
+        return failure();
+      }
+    }
+  }
 
   if (!lhsState.isMask() || !rhsState.isMask()) {
     return this->minStateScalar(lhsState, rhsState, loc, builder);
@@ -365,7 +452,48 @@ LogicalResult MaskState::parseExtSI(arith::ExtSIOp op, const Location loc,
 LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
                                   OpBuilder &builder) {
   assert(this->isEmpty());
-
+  int cmpOpDim = -1;
+  if (auto shapedType = dyn_cast<ShapedType>(cmpOp.getType())) {
+    for (unsigned r = 0; r < shapedType.getRank(); r++) {
+      if (shapedType.getShape()[r] != 1) {
+        if (cmpOpDim != -1) {
+          // This happens when the cmp has more than one dimension with size
+          // larger than 1, e.g. a < b where both a and b are tensors with
+          // shape 2x2.
+          cmpOpDim = -1;
+          break;
+        }
+        cmpOpDim = r;
+      }
+    }
+    masks.clear();
+    for (unsigned r = 0; r < shapedType.getRank(); r++) {
+      masks.push_back(nullptr);
+    }
+    // If cmpOpDim == -1, parseCmp must fail later.
+    // Here, just set up unstructured masks when cmpOpDim != -1.
+    if (cmpOpDim != -1) {
+      // Save cmpOp as the unstructured mask for the failure case; it is reset
+      // to nullptr later on success.
+      Value unstructuredMask = cmpOp;
+      if (shapedType.getRank() > 1) {
+        // If cmpOp is not 1D, collapse it to 1D.
+        auto flatType = RankedTensorType::get({shapedType.getShape()[cmpOpDim]},
                                              shapedType.getElementType());
+        auto maybeReassociationMap =
+            getReassociationIndicesForReshape(shapedType, flatType);
+        SmallVector<ReassociationIndices> reassociation =
+            *maybeReassociationMap;
+        // Set masks.
+        unstructuredMask = builder.create<tensor::CollapseShapeOp>(
+            loc, flatType, cmpOp, reassociation);
+      }
+      masks[cmpOpDim] = unstructuredMask;
+    }
+  } else {
+    cmpOpDim = 0;
+    masks.push_back(cmpOp);
+  }
   if (cmpOp.getPredicate() != arith::CmpIPredicate::slt &&
       cmpOp.getPredicate() != arith::CmpIPredicate::ult &&
       cmpOp.getPredicate() != arith::CmpIPredicate::sge) {
@@ -453,7 +581,10 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
     else
       this->dims.push_back(lhsState.dims[i]);
   }
-
+  if (cmpOpDim != -1) {
+    // Clear the mask on success.
+    masks[cmpOpDim] = nullptr;
+  }
   return success();
 }
 
@@ -623,7 +754,15 @@ LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc,
 
   for (auto s : dstShape)
     this->dims.push_back(builder.getIndexAttr(s));
-
+  bool isBool = src.getType().isInteger(1);
+  if (isBool) {
+    // src is a scalar bool and the parse succeeded.
+    // Create empty mask entries, one per dimension.
+    masks.clear();
+    for (unsigned i = 0; i < dstShape.size(); i++) {
+      masks.push_back(nullptr);
+    }
+  }
   return success();
 }
 
@@ -632,18 +771,76 @@ LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp,
                                          OpBuilder &builder) {
   assert(this->isEmpty());
 
-  if (failed(this->parse(expandDimsOp.getSrc(), loc, builder)))
-    return failure();
-
   auto dstShape =
       cast<ShapedType>(expandDimsOp.getResult().getType()).getShape();
   auto axis = expandDimsOp.getAxis();
+  Value src = expandDimsOp.getSrc();
+  auto srcType = cast<ShapedType>(src.getType());
+  bool isBoolOp = srcType.getElementType().isInteger(1);
+  LogicalResult result = parse(src, loc, builder);
+  if (failed(result)) {
+    if (isBoolOp) {
+      if (srcType.getRank() > 1 && masks.size() != srcType.getRank()) {
+        return failure();
+      }
+    } else {
+      return failure();
+    }
+  }
+
+  if (isBoolOp) {
+    // Save the mask for a 1D boolean tensor.
+    if (srcType.getRank() == 1) {
+      assert(dstShape.size() == 2);
+      masks.resize(dstShape.size());
+      masks[axis] = nullptr;
+      if (failed(result)) {
+        // Recover dims to allow the other dim to be processed.
+        dims.clear();
+        dims.push_back(builder.getIndexAttr(srcType.getShape()[0]));
+        // Save src as the unstructured mask.
+        masks[1 - axis] = src;
+      } else {
+        // Save nullptr when the parse succeeds.
+        masks[1 - axis] = nullptr;
+      }
+    } else {
+      if (failed(result)) {
+        auto unstructuredMasks = getUnstructuredMasks();
+        if (unstructuredMasks.empty()) {
+          return failure();
+        }
+        if (unstructuredMasks.size() > 1) {
+          return failure();
+        }
+        auto [dim, mask] = unstructuredMasks.front();
+        // Recover dims for the unstructured mask dim to allow the other dim
+        // to be processed.
+        dims[dim] = builder.getIndexAttr(srcType.getShape()[dim]);
+      }
+      masks.insert(masks.begin() + axis, nullptr);
+    }
+  }
+
   assert(dstShape[axis] == 1 &&
          "expect changed dimension to be 1 in expand_dims");
   this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1));
 
   return success();
 }
 
+// Return all non-nullptr masks along with their dimensions.
+SmallVector<std::pair<unsigned, Value>> MaskState::getUnstructuredMasks() {
+  SmallVector<std::pair<unsigned, Value>> result;
+
+  for (auto [i, m] : llvm::enumerate(masks)) {
+    if (m) {
+      result.push_back({i, m});
+    }
+  }
+
+  return result;
+}
+
 } // namespace triton
 } // namespace mlir
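As a plain-C++ cross-check of what `structuredMaskToUnstructuredMask` in `parseAnd` materializes: for a structured bound k on a dimension of size N, the MakeRangeOp + SplatOp + CmpIOp(ult) sequence builds the boolean vector [0 < k, 1 < k, ..., N-1 < k], which is then AND-ed elementwise with the other operand's unstructured mask. The snippet below only reproduces that arithmetic with `std::vector<bool>`; it is a model of the computation, not MLIR code, and the sample values are made up.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Equivalent of MakeRangeOp(0, size) + SplatOp(bound) + CmpIOp ult:
// element i of the result is (i < bound).
static std::vector<bool> structuredBoundToMask(uint32_t bound, uint32_t size) {
  std::vector<bool> mask(size);
  for (uint32_t i = 0; i < size; ++i)
    mask[i] = i < bound;
  return mask;
}

int main() {
  // A dimension of size 8 whose structured analysis produced the bound 5,
  // AND-ed with a hypothetical unstructured mask from the other operand.
  std::vector<bool> structured = structuredBoundToMask(5, 8);
  std::vector<bool> unstructured = {true, true, false, true,
                                    true, false, true, true};

  std::vector<bool> merged(8);
  for (int i = 0; i < 8; ++i)
    merged[i] = structured[i] && unstructured[i]; // arith::AndIOp per element

  for (int i = 0; i < 8; ++i)
    std::printf("%d", merged[i] ? 1 : 0); // prints 11011000
  std::puts("");
  return 0;
}
```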
