@@ -1240,35 +1240,42 @@ struct LoadOpToBlockIOConversion
12401240 loc, load2DGenXType, rewriter.getZeroAttr (load2DGenXType));
12411241 }
12421242
1243+ auto createLoadInstruction = [&]() -> SmallVector<Value, 1 > {
1244+ // Use the top-left address of the block to load the data.
1245+ Value addrElem = b.bitcast (ptrs[{offsetM, offsetN}],
1246+ ptr_ty (ctx, 1 /* global*/ ));
1247+ addrElem = targetInfo.shuffleIdx (rewriter, loc, addrElem, 0 );
1248+
1249+ auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
1250+ loc, load2DGenXType,
1251+ /* ptr*/ addrElem,
1252+ /* base_width*/ baseWidth,
1253+ /* base_height*/ baseHeight,
1254+ /* base_pitch*/ pitch,
1255+ /* x*/ b.i32_val (0 ),
1256+ /* y*/ b.i32_val (0 ),
1257+ /* elem_size_in_bits*/ elemSizeInBits,
1258+ /* tile_width*/ tileWidth,
1259+ /* tile_height*/ tileHeight,
1260+ /* v_blocks*/ vBlocks,
1261+ /* transpose*/ false ,
1262+ /* vnni_transform*/
1263+ (usePackedType &&
1264+ opIdx == DpasEncodingAttr::OpIdx::OperandB &&
1265+ !isTransposeRequired && originalElemBits != 32 ));
1266+ return {load2dOp};
1267+ };
1268+
1269+ Value ret;
12431270 // Create the load operation; it is predicated only when llMask is set.
1244- Block &endBlock = LLVM::intel::createPredicatedBlock (
1245- rewriter, loc, pred, SmallVector<Value, 1 >{other_}, [&]() {
1246- // Use the top-left address of the block to load the data.
1247- Value addrElem = b.bitcast (ptrs[{offsetM, offsetN}],
1248- ptr_ty (ctx, 1 /* global*/ ));
1249- addrElem = targetInfo.shuffleIdx (rewriter, loc, addrElem, 0 );
1250-
1251- auto load2dOp =
1252- rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
1253- loc, load2DGenXType,
1254- /* ptr*/ addrElem,
1255- /* base_width*/ baseWidth,
1256- /* base_height*/ baseHeight,
1257- /* base_pitch*/ pitch,
1258- /* x*/ b.i32_val (0 ),
1259- /* y*/ b.i32_val (0 ),
1260- /* elem_size_in_bits*/ elemSizeInBits,
1261- /* tile_width*/ tileWidth,
1262- /* tile_height*/ tileHeight,
1263- /* v_blocks*/ vBlocks,
1264- /* transpose*/ false ,
1265- /* vnni_transform*/
1266- (usePackedType &&
1267- opIdx == DpasEncodingAttr::OpIdx::OperandB &&
1268- !isTransposeRequired && originalElemBits != 32 ));
1269- return SmallVector<Value, 1 >{load2dOp};
1270- });
1271- Value ret = *endBlock.args_begin ();
1271+ if (llMask) {
1272+ Block &endBlock = LLVM::intel::createPredicatedBlock (
1273+ rewriter, loc, pred, SmallVector<Value, 1 >{other_},
1274+ createLoadInstruction);
1275+ ret = *endBlock.args_begin ();
1276+ } else {
1277+ ret = createLoadInstruction ()[0 ];
1278+ }
12721279
12731280 unsigned numOperandsM = opIdx != DpasEncodingAttr::OpIdx::OperandB
12741281 ? numOperandsOuterDimPerLoad
0 commit comments