@@ -1648,6 +1648,15 @@ struct LoadOpConversion
1648
1648
usePackedType = true ;
1649
1649
}
1650
1650
1651
+ if (isTransposeRequired) {
1652
+ if (!usePackedType) {
1653
+ // use the d32 transpose 2d load.
1654
+ loadResultElemType = i32_ty;
1655
+ packedElemsPerLanePerDPASInst = 32 / elemSizeInBits;
1656
+ usePackedType = true ;
1657
+ }
1658
+ }
1659
+
1651
1660
Type packedDPASOperandType =
1652
1661
LLVM::getVectorType (loadResultElemType, packedElemsPerLanePerDPASInst);
1653
1662
@@ -2082,12 +2091,14 @@ struct LoadOpConversion
2082
2091
offsetX = b.udiv (offsetX, b.i32_val (32 / originalElemBits));
2083
2092
}
2084
2093
2094
+ Value base_width = b.mul (baseWidth, elemSizeInBytes);
2095
+ Value base_pitch = b.mul (pitch, elemSizeInBytes);
2085
2096
auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
2086
2097
loc, load2DGenXType,
2087
2098
/* ptr*/ base,
2088
- /* base_width*/ b. mul (baseWidth, elemSizeInBytes) ,
2099
+ /* base_width*/ base_width ,
2089
2100
/* base_height*/ baseHeight,
2090
- /* base_pitch*/ b. mul (pitch, elemSizeInBytes) ,
2101
+ /* base_pitch*/ base_pitch ,
2091
2102
/* x*/ b.trunc (i32_ty, offsetX),
2092
2103
/* y*/ b.trunc (i32_ty, offsetY),
2093
2104
/* elem_size_in_bits*/ elemSizeInBits,
@@ -2105,6 +2116,10 @@ struct LoadOpConversion
2105
2116
rewriter.eraseOp (load2dOp);
2106
2117
return failure ();
2107
2118
}
2119
+ #if 0
2120
+ targetInfo.printf(rewriter, "base: %p, baseWidth: %d, baseHeight:%d, pitch:%d, offset_x:%d, offset_y:%d, loadVal: %d",
2121
+ {base, base_width, baseHeight, base_pitch, offsetX, offsetY, load2dOp.getResult()});
2122
+ #endif
2108
2123
LLVM_DEBUG (llvm::dbgs () << " Generated load op: " << load2dOp << " \n " );
2109
2124
2110
2125
unsigned packedRowNum = opIdx == DpasEncodingAttr::OpIdx::OperandA
@@ -2166,11 +2181,14 @@ struct LoadOpConversion
2166
2181
vblk * packedColNumPerVBlock + col)
2167
2182
<< " , " << std::to_string (k + row) << " \n " ;
2168
2183
});
2184
+ auto ret = b.bitcast (loadVal, unpackedDPASOperandType);
2185
+ #if 0
2186
+ targetInfo.printf(rewriter, "loadVal: %d", {ret});
2187
+ #endif
2169
2188
loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
2170
2189
rep * packedColNum +
2171
2190
vblk * packedColNumPerVBlock + col,
2172
- k + row}] =
2173
- b.bitcast (loadVal, unpackedDPASOperandType);
2191
+ k + row}] = ret;
2174
2192
} break ;
2175
2193
case DpasEncodingAttr::OpIdx::OperandC: {
2176
2194
llvm_unreachable (" unexpected OpIdx::OperandC" );
0 commit comments