Skip to content

Commit 7f0ecf9

Browse files
committed
Support 2D block IO to load column major for DotOp matrix
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent d1dd736 commit 7f0ecf9

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,6 +1648,15 @@ struct LoadOpConversion
16481648
usePackedType = true;
16491649
}
16501650

1651+
if (isTransposeRequired) {
1652+
if (!usePackedType) {
1653+
// use the d32 transpose 2d load.
1654+
loadResultElemType = i32_ty;
1655+
packedElemsPerLanePerDPASInst = 32 / elemSizeInBits;
1656+
usePackedType = true;
1657+
}
1658+
}
1659+
16511660
Type packedDPASOperandType =
16521661
LLVM::getVectorType(loadResultElemType, packedElemsPerLanePerDPASInst);
16531662

@@ -2105,6 +2114,10 @@ struct LoadOpConversion
21052114
rewriter.eraseOp(load2dOp);
21062115
return failure();
21072116
}
2117+
#if 0
2118+
targetInfo.printf(rewriter, "base: %p, baseWidth: %d, baseHeight:%d, pitch:%d, offset_x:%d, offset_y:%d, loadVal: %d",
2119+
{base, base_width, baseHeight, base_pitch, offsetX, offsetY, load2dOp.getResult()});
2120+
#endif
21082121
LLVM_DEBUG(llvm::dbgs() << "Generated load op: " << load2dOp << "\n");
21092122

21102123
unsigned packedRowNum = opIdx == DpasEncodingAttr::OpIdx::OperandA
@@ -2166,11 +2179,14 @@ struct LoadOpConversion
21662179
vblk * packedColNumPerVBlock + col)
21672180
<< ", " << std::to_string(k + row) << "\n";
21682181
});
2182+
auto ret = b.bitcast(loadVal, unpackedDPASOperandType);
2183+
#if 0
2184+
targetInfo.printf(rewriter, "loadVal: %d", {ret});
2185+
#endif
21692186
loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
21702187
rep * packedColNum +
21712188
vblk * packedColNumPerVBlock + col,
2172-
k + row}] =
2173-
b.bitcast(loadVal, unpackedDPASOperandType);
2189+
k + row}] = ret;
21742190
} break;
21752191
case DpasEncodingAttr::OpIdx::OperandC: {
21762192
llvm_unreachable("unexpected OpIdx::OperandC");

0 commit comments

Comments
 (0)