@@ -1240,35 +1240,42 @@ struct LoadOpToBlockIOConversion
1240
1240
loc, load2DGenXType, rewriter.getZeroAttr (load2DGenXType));
1241
1241
}
1242
1242
1243
+ auto createLoadInstruction = [&]() -> SmallVector<Value, 1 > {
1244
+ // Use the top-left address of the block to load the data.
1245
+ Value addrElem = b.bitcast (ptrs[{offsetM, offsetN}],
1246
+ ptr_ty (ctx, 1 /* global*/ ));
1247
+ addrElem = targetInfo.shuffleIdx (rewriter, loc, addrElem, 0 );
1248
+
1249
+ auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
1250
+ loc, load2DGenXType,
1251
+ /* ptr*/ addrElem,
1252
+ /* base_width*/ baseWidth,
1253
+ /* base_height*/ baseHeight,
1254
+ /* base_pitch*/ pitch,
1255
+ /* x*/ b.i32_val (0 ),
1256
+ /* y*/ b.i32_val (0 ),
1257
+ /* elem_size_in_bits*/ elemSizeInBits,
1258
+ /* tile_width*/ tileWidth,
1259
+ /* tile_height*/ tileHeight,
1260
+ /* v_blocks*/ vBlocks,
1261
+ /* transpose*/ false ,
1262
+ /* vnni_transform*/
1263
+ (usePackedType &&
1264
+ opIdx == DpasEncodingAttr::OpIdx::OperandB &&
1265
+ !isTransposeRequired && originalElemBits != 32 ));
1266
+ return {load2dOp};
1267
+ };
1268
+
1269
+ Value ret;
1243
1270
// Create a predicated load operation.
1244
- Block &endBlock = LLVM::intel::createPredicatedBlock (
1245
- rewriter, loc, pred, SmallVector<Value, 1 >{other_}, [&]() {
1246
- // Use the top-left address of the block to load the data.
1247
- Value addrElem = b.bitcast (ptrs[{offsetM, offsetN}],
1248
- ptr_ty (ctx, 1 /* global*/ ));
1249
- addrElem = targetInfo.shuffleIdx (rewriter, loc, addrElem, 0 );
1250
-
1251
- auto load2dOp =
1252
- rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
1253
- loc, load2DGenXType,
1254
- /* ptr*/ addrElem,
1255
- /* base_width*/ baseWidth,
1256
- /* base_height*/ baseHeight,
1257
- /* base_pitch*/ pitch,
1258
- /* x*/ b.i32_val (0 ),
1259
- /* y*/ b.i32_val (0 ),
1260
- /* elem_size_in_bits*/ elemSizeInBits,
1261
- /* tile_width*/ tileWidth,
1262
- /* tile_height*/ tileHeight,
1263
- /* v_blocks*/ vBlocks,
1264
- /* transpose*/ false ,
1265
- /* vnni_transform*/
1266
- (usePackedType &&
1267
- opIdx == DpasEncodingAttr::OpIdx::OperandB &&
1268
- !isTransposeRequired && originalElemBits != 32 ));
1269
- return SmallVector<Value, 1 >{load2dOp};
1270
- });
1271
- Value ret = *endBlock.args_begin ();
1271
+ if (llMask) {
1272
+ Block &endBlock = LLVM::intel::createPredicatedBlock (
1273
+ rewriter, loc, pred, SmallVector<Value, 1 >{other_},
1274
+ createLoadInstruction);
1275
+ ret = *endBlock.args_begin ();
1276
+ } else {
1277
+ ret = createLoadInstruction ()[0 ];
1278
+ }
1272
1279
1273
1280
unsigned numOperandsM = opIdx != DpasEncodingAttr::OpIdx::OperandB
1274
1281
? numOperandsOuterDimPerLoad
0 commit comments