Skip to content

Commit dc975e3

Browse files
committed
Support 2D block IO loads of column-major DotOp operand matrices
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent d1dd736 commit dc975e3

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,6 +1648,15 @@ struct LoadOpConversion
16481648
usePackedType = true;
16491649
}
16501650

1651+
if (isTransposeRequired) {
1652+
if (!usePackedType) {
1653+
// No packed type selected yet: use the d32 transposed 2D block load, packing 32 / elemSizeInBits elements per i32.
1654+
loadResultElemType = i32_ty;
1655+
packedElemsPerLanePerDPASInst = 32 / elemSizeInBits;
1656+
usePackedType = true;
1657+
}
1658+
}
1659+
16511660
Type packedDPASOperandType =
16521661
LLVM::getVectorType(loadResultElemType, packedElemsPerLanePerDPASInst);
16531662

@@ -2082,12 +2091,14 @@ struct LoadOpConversion
20822091
offsetX = b.udiv(offsetX, b.i32_val(32 / originalElemBits));
20832092
}
20842093

2094+
Value base_width = b.mul(baseWidth, elemSizeInBytes);
2095+
Value base_pitch = b.mul(pitch, elemSizeInBytes);
20852096
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
20862097
loc, load2DGenXType,
20872098
/*ptr*/ base,
2088-
/*base_width*/ b.mul(baseWidth, elemSizeInBytes),
2099+
/*base_width*/ base_width,
20892100
/*base_height*/ baseHeight,
2090-
/*base_pitch*/ b.mul(pitch, elemSizeInBytes),
2101+
/*base_pitch*/ base_pitch,
20912102
/*x*/ b.trunc(i32_ty, offsetX),
20922103
/*y*/ b.trunc(i32_ty, offsetY),
20932104
/*elem_size_in_bits*/ elemSizeInBits,
@@ -2105,6 +2116,10 @@ struct LoadOpConversion
21052116
rewriter.eraseOp(load2dOp);
21062117
return failure();
21072118
}
2119+
#if 0
2120+
targetInfo.printf(rewriter, "base: %p, baseWidth: %d, baseHeight:%d, pitch:%d, offset_x:%d, offset_y:%d, loadVal: %d",
2121+
{base, base_width, baseHeight, base_pitch, offsetX, offsetY, load2dOp.getResult()});
2122+
#endif
21082123
LLVM_DEBUG(llvm::dbgs() << "Generated load op: " << load2dOp << "\n");
21092124

21102125
unsigned packedRowNum = opIdx == DpasEncodingAttr::OpIdx::OperandA
@@ -2166,11 +2181,14 @@ struct LoadOpConversion
21662181
vblk * packedColNumPerVBlock + col)
21672182
<< ", " << std::to_string(k + row) << "\n";
21682183
});
2184+
auto ret = b.bitcast(loadVal, unpackedDPASOperandType);
2185+
#if 0
2186+
targetInfo.printf(rewriter, "loadVal: %d", {ret});
2187+
#endif
21692188
loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
21702189
rep * packedColNum +
21712190
vblk * packedColNumPerVBlock + col,
2172-
k + row}] =
2173-
b.bitcast(loadVal, unpackedDPASOperandType);
2191+
k + row}] = ret;
21742192
} break;
21752193
case DpasEncodingAttr::OpIdx::OperandC: {
21762194
llvm_unreachable("unexpected OpIdx::OperandC");

0 commit comments

Comments
 (0)