Skip to content

Commit 2404d99

Browse files
PR review round 1
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 3873eda commit 2404d99

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -670,16 +670,16 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
670670
case 4:
671671
op.setScalesIdxB(val);
672672
break;
673-
default:
673+
default:
674674
break;
675675
}
676676
};
677677

678678
// Obtain flat index from offsets and shape.
679679
auto getIdxFromExtract = [](vector::ExtractOp op) {
680680
ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
681-
int cumul = 1;
682-
int idx = 0;
681+
int64_t cumul = 1;
682+
int64_t idx = 0;
683683
for (auto [offset, size] :
684684
reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
685685
idx += offset * cumul;
@@ -720,33 +720,37 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
720720
for (auto opIdx : SmallVector<int64_t>({3, 4})) {
721721
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
722722
if (!insertOp) {
723-
return failure();
723+
return rewriter.notifyMatchFailure(op,
724+
"defining op not a vector.insert");
724725
}
725726
if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
726-
return failure();
727+
return rewriter.notifyMatchFailure(op,
728+
"some scaled mfma's already packed");
727729
}
728730

729731
auto extractOp =
730732
insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
731733
if (!extractOp) {
732-
return failure();
734+
return rewriter.notifyMatchFailure(op,
735+
"defining op not a vector.extract");
733736
}
734737

735738
Value scaleSrc = extractOp.getOperand(0);
736-
auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
739+
auto stype = dyn_cast<VectorType>(scaleSrc.getType());
737740
if (!stype) {
738-
return failure();
741+
return rewriter.notifyMatchFailure(op, "not a shaped type");
739742
}
740743
// We do not handle dynamic dims yet, assume that the input is padded to
741744
// a static shape now.
742-
if (llvm::any_of(llvm::seq<int64_t>(0, stype.getRank()),
743-
[&](int64_t i) { return stype.isDynamicDim(i); })) {
744-
return failure();
745+
if (!stype.hasStaticShape()) {
746+
return rewriter.notifyMatchFailure(op,
747+
"dynamic dims not yet supported");
745748
}
746749

747750
int64_t numElements = stype.getNumElements();
748-
if (numElements <= 4) {
749-
return failure();
751+
if (numElements <= 4 || !(numElements % 4)) {
752+
return rewriter.notifyMatchFailure(
753+
op, "no packing if # of scales less than or indivisible by four");
750754
}
751755

752756
Type newSrcType = VectorType::get(
@@ -760,7 +764,8 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
760764
loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
761765
SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
762766
Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
763-
op.setOperand(opIdx, scale);
767+
rewriter.modifyOpInPlace(
768+
op, [&op, opIdx, scale] { op->setOperand(opIdx, scale); });
764769
setOpsel(opIdx, offsets[1]);
765770
}
766771
return success();

0 commit comments

Comments (0)