Skip to content

Commit 2ed6e7f

Browse files
authored
[NFC] Separate back memdesc reshape encoding inference (#7544)
This fixes an extra error message being emitted for a non-error kind of failure. It avoids emitting op errors when the memdesc pattern fails.
1 parent 815b2a4 commit 2ed6e7f

File tree

2 files changed

+42
-36
lines changed

2 files changed

+42
-36
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2469,39 +2469,6 @@ struct TritonGPUInferLayoutInterface
24692469
Attribute srcEnc,
24702470
ArrayRef<int64_t> dstShape,
24712471
Attribute &dstEnc) const {
2472-
if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
2473-
// TODO: supporting reshape of CTA layouts is non-trivial.
2474-
if (getNumCTAs(mmaEncoding) > 1)
2475-
return failure();
2476-
int innerDimDst =
2477-
mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
2478-
int innerDimSrc =
2479-
mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
2480-
// For now disallow reshape of the inner dimension.
2481-
if (innerDimDst != innerDimSrc)
2482-
return failure();
2483-
auto *ctx = srcEnc.getContext();
2484-
2485-
// CTALayout can be all 1's because we bailed on multi-CTA layouts above.
2486-
auto CTALayout = CTALayoutAttr::get(
2487-
ctx,
2488-
/*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
2489-
/*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
2490-
/*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
2491-
dstEnc = NVMMASharedEncodingAttr::get(
2492-
ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
2493-
mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
2494-
CTALayout);
2495-
// Big guns, check linear layouts are equivalent
2496-
// We disallow reshaping memdesc_subviews in the verifier
2497-
// We disallow reshaping memdesc_subviews in the verifier
2498-
auto srcLL = toLinearLayout(srcShape, srcEnc, srcShape);
2499-
auto dstLL = toLinearLayout(dstShape, dstEnc, dstShape);
2500-
if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
2501-
return failure();
2502-
}
2503-
return success();
2504-
}
25052472
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
25062473
if (!src) {
25072474
return failure();

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,46 @@ LogicalResult MemDescReshapeOp::verify() {
494494
return success();
495495
}
496496

497+
/// Infers the encoding of the result of a memdesc reshape from the encoding of
/// its source.
///
/// Only `NVMMASharedEncodingAttr` sources are handled; every other encoding
/// yields `failure()`. This function emits no diagnostics itself: it only
/// returns `failure()` and leaves any reporting to the caller.
///
/// \param srcShape Shape of the source memdesc.
/// \param srcEnc   Encoding of the source memdesc.
/// \param dstShape Requested shape of the reshaped memdesc.
/// \param dstEnc   Receives the inferred encoding. Note that it is assigned
///                 before the final linear-layout equivalence check, so it may
///                 hold a candidate encoding even when `failure()` is returned.
/// \returns `success()` if an equivalent destination encoding was found,
///          `failure()` otherwise.
static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
                                                   Attribute srcEnc,
                                                   ArrayRef<int64_t> dstShape,
                                                   Attribute &dstEnc) {
  auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc);
  if (!mmaEncoding)
    return failure();

  // TODO: supporting reshape of CTA layouts is non-trivial.
  if (getNumCTAs(mmaEncoding) > 1)
    return failure();

  // front()/back() below require non-empty shapes.
  if (srcShape.empty() || dstShape.empty())
    return failure();

  // Keep the dimensions as int64_t: narrowing to int could make two different
  // extents compare equal after truncation.
  const bool transposed = mmaEncoding.getTransposed();
  const int64_t innerDimDst = transposed ? dstShape.front() : dstShape.back();
  const int64_t innerDimSrc = transposed ? srcShape.front() : srcShape.back();
  // For now disallow reshape of the inner dimension.
  if (innerDimDst != innerDimSrc)
    return failure();

  auto *ctx = srcEnc.getContext();

  // CTALayout can be all 1's because we bailed on multi-CTA layouts above.
  auto CTALayout = CTALayoutAttr::get(
      ctx,
      /*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
      /*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
      /*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
  dstEnc = NVMMASharedEncodingAttr::get(
      ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
      mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
      CTALayout);

  // Big guns, check linear layouts are equivalent.
  // We disallow reshaping memdesc_subviews in the verifier.
  auto srcLL = toLinearLayout(srcShape, srcEnc, srcShape);
  auto dstLL = toLinearLayout(dstShape, dstEnc, dstShape);
  if (reshapeLayout(ctx, srcLL, dstShape) != dstLL)
    return failure();

  return success();
}
536+
497537
LogicalResult MemDescReshapeOp::inferReturnTypes(
498538
MLIRContext *context, std::optional<Location> loc, MemDescType srcTy,
499539
ArrayRef<int64_t> dstShape, MemDescType &inferredReturnType) {
@@ -503,9 +543,8 @@ LogicalResult MemDescReshapeOp::inferReturnTypes(
503543

504544
Attribute dstEncoding;
505545
if (Attribute srcEnc = srcTy.getEncoding()) {
506-
auto *inferLayout = cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
507-
if (failed(inferLayout->inferReshapeOpEncoding(srcTy.getShape(), srcEnc,
508-
dstShape, dstEncoding, loc)))
546+
if (failed(inferMemDescReshapeOpEncoding(srcTy.getShape(), srcEnc, dstShape,
547+
dstEncoding)))
509548
return failure();
510549
}
511550

0 commit comments

Comments
 (0)