Skip to content

Commit cd493ff

Browse files
authored
Set one-matrix-per-load attribute on tt.dot B operand (#4974)
When the B operand of a chained dot operation is defined by a tt.load followed by a tt.trans the current implementation fails to set the "one-matric-per-load" attribute. This PR fixes this problem. Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 52e38d1 commit cd493ff

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ namespace mlir::triton::gpu::intel {
3333

3434
namespace {
3535

36+
// FIXME: Remove once IGC can split large 2D block loads.
37+
static void setAttrOnBOperand(tt::DotOp dotOp, StringRef attrName,
38+
Attribute attr) {
39+
Operation *defOp = dotOp.getB().getDefiningOp();
40+
while (auto convOp = dyn_cast_or_null<ttg::ConvertLayoutOp>(defOp))
41+
defOp = convOp.getSrc().getDefiningOp();
42+
if (auto transOp = dyn_cast_or_null<tt::TransOp>(defOp))
43+
defOp = transOp.getOperand().getDefiningOp();
44+
if (auto loadOp = dyn_cast_or_null<tt::LoadOp>(defOp))
45+
loadOp->setAttr(attrName, attr);
46+
}
47+
3648
SmallVector<unsigned>
3749
getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
3850
const ArrayRef<int64_t> shape, unsigned numWarps) {
@@ -46,14 +58,6 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
4658
if (auto forOp = op->getParentOfType<scf::ForOp>()) {
4759
// FIXME: Remove once IGC can split large 2D block loads.
4860
MLIRContext *ctx = forOp->getContext();
49-
auto setAttrOnBOperand = [&](tt::DotOp dotOp, StringRef attrName,
50-
Attribute attr) {
51-
Operation *defOp = dotOp.getB().getDefiningOp();
52-
while (auto convOp = dyn_cast_or_null<ttg::ConvertLayoutOp>(defOp))
53-
defOp = convOp.getSrc().getDefiningOp();
54-
if (auto loadOp = dyn_cast_or_null<tt::LoadOp>(defOp))
55-
loadOp->setAttr(attrName, attr);
56-
};
5761
StringRef attrName =
5862
ttgi::TritonIntelGPUDialect::getOneMatrixPerLoadAttrName();
5963
setAttrOnBOperand(dotOp, attrName, UnitAttr::get(ctx));

0 commit comments

Comments
 (0)