@@ -2456,11 +2456,6 @@ struct StoreOpToBlockIOConversion
2456
2456
auto b = TritonLLVMOpBuilder (loc, rewriter);
2457
2457
Type resultType = op.getValue ().getType ();
2458
2458
auto tensorType = cast<RankedTensorType>(resultType);
2459
-
2460
- // Only lower StoreOp with dpas layout encoding.
2461
- if (!hasDpasEncoding (tensorType))
2462
- return failure ();
2463
-
2464
2459
auto dpasLayout = cast<DpasEncodingAttr>(tensorType.getEncoding ());
2465
2460
LLVMTypeConverter *typeConverter = getTypeConverter ();
2466
2461
MLIRContext *ctx = rewriter.getContext ();
@@ -2471,14 +2466,21 @@ struct StoreOpToBlockIOConversion
2471
2466
const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
2472
2467
size_t rank = tensorShape.size ();
2473
2468
unsigned numElems = getTotalElemsPerThread (tensorType);
2469
+
2474
2470
SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
2471
+ // 2D block store supports 8 rows at most.
2472
+ unsigned tileHeight = std::min (8u , elemsPerInstr[0 ]);
2473
+ // 2D block store supports 64 bytes per row at most.
2474
+ unsigned tileWidth = elemsPerInstr[1 ];
2475
+ unsigned totalBytesPerRowPerMatrix = tileWidth * elemSizeInBits / 8 ;
2476
+ if (totalBytesPerRowPerMatrix > 64 )
2477
+ return failure ();
2478
+
2475
2479
auto warpsPerCTA = dpasLayout.getWarpsPerCTA ();
2476
2480
SmallVector<int64_t > numReps =
2477
2481
dpasLayout.getDPASRepetitions (tensorShape, 2 );
2478
2482
SmallVector<unsigned > dpasWarpsOrder =
2479
2483
getMatrixOrder (warpsPerCTA.size (), /* rowMajor*/ true );
2480
- unsigned threadsPerWarp =
2481
- product<unsigned >(getThreadsPerWarp (dpasLayout, tensorShape));
2482
2484
2483
2485
Value warpId = rewriter.create <arith::IndexCastOp>(
2484
2486
loc, i32_ty,
@@ -2487,25 +2489,34 @@ struct StoreOpToBlockIOConversion
2487
2489
SmallVector<Value> multiDimWarpId = mlir::LLVM::delinearize (
2488
2490
rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);
2489
2491
2490
- int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
2491
- Type store2DGenXType =
2492
- LLVM::getVectorType (IntegerType::get (ctx, elemSizeInBits),
2493
- elemsPerLane); // make it opaque type.
2494
-
2495
2492
Value blockPtr = adaptor.getPtr ();
2496
2493
auto [base, width, height, rowStride, colStride, offsetBaseX, offsetBaseY] =
2497
2494
getValuesFromBlockPointerStruct (blockPtr, rewriter);
2498
2495
2499
- auto vals = unpackLLElements (loc, adaptor.getValue (), rewriter);
2500
- assert (vals.size () == numElems);
2501
-
2502
2496
width = b.trunc (i32_ty, width);
2503
- height = b.trunc (i32_ty, height);
2504
2497
rowStride = b.trunc (i32_ty, rowStride);
2505
2498
// encoded as bytes.
2506
2499
Value baseWidth = b.mul (width, elemSizeInBytes);
2500
+ Value baseHeight = b.trunc (i32_ty, height);
2507
2501
// encoded as bytes.
2508
- Value basePitch = b.mul (rowStride, elemSizeInBytes);
2502
+ Value pitch = b.mul (rowStride, elemSizeInBytes);
2503
+ // 2D block store only supports vBlocks = 1.
2504
+ unsigned vBlocks = 1 ;
2505
+
2506
+ // Unpack the LLVM values to be stored.
2507
+ SmallVector<Value> valElems =
2508
+ unpackLLElements (loc, adaptor.getValue (), rewriter);
2509
+ assert (valElems.size () == numElems &&
2510
+ " the number of store values does not match the number of elements" );
2511
+
2512
+ unsigned threadsPerWarp =
2513
+ TritonGPUDialect::getThreadsPerWarp (op->getParentOfType <ModuleOp>());
2514
+
2515
+ int64_t elemsPerLane = tileHeight * tileWidth / threadsPerWarp;
2516
+ Type opaqueType = IntegerType::get (ctx, elemSizeInBits);
2517
+ Type store2DGenXType =
2518
+ LLVM::getVectorType (opaqueType,
2519
+ elemsPerLane); // make it opaque type.
2509
2520
2510
2521
// A warp stride for the replicas.
2511
2522
SmallVector<unsigned > repClusterShape = dpasLayout.getShapeC ();
@@ -2538,34 +2549,34 @@ struct StoreOpToBlockIOConversion
2538
2549
for (int m = 0 ; m < numRepOuter; ++m) {
2539
2550
for (int n = 0 ; n < numRepInner; ++n) {
2540
2551
for (int repM = 0 ; repM < repCluster[0 ]; ++repM) {
2541
- Value offsetY =
2542
- b.add (warpId0Offset,
2543
- b.i32_val (m * replicaStride[0 ] + repM * elemsPerInstr[0 ]));
2552
+ Value offsetY = b.add (warpId0Offset, b.i32_val (m * replicaStride[0 ] +
2553
+ repM * tileHeight));
2544
2554
for (int repN = 0 ; repN < repCluster[1 ]; ++repN) {
2545
2555
Value offsetX =
2546
- b.add (warpId1Offset, b.i32_val (n * replicaStride[1 ] +
2547
- repN * elemsPerInstr[1 ]));
2556
+ b.add (warpId1Offset,
2557
+ b.i32_val (n * replicaStride[1 ] + repN * tileWidth));
2558
+
2548
2559
Value storeVal = rewriter.create <LLVM::UndefOp>(
2549
2560
loc, LLVM::getVectorType (typeConverter->convertType (eltTy),
2550
2561
elemsPerLane));
2551
2562
for (size_t i = 0 ; i < elemsPerLane; ++i) {
2552
2563
storeVal =
2553
- b.insert_element (storeVal, vals [valOffset], b.i32_val (i));
2564
+ b.insert_element (storeVal, valElems [valOffset], b.i32_val (i));
2554
2565
++valOffset;
2555
2566
}
2556
2567
2557
2568
auto newOp = rewriter.create <TritonGEN::Matrix2DBlockStoreOp>(
2558
2569
loc,
2559
2570
/* ptr*/ base,
2560
2571
/* base_width*/ baseWidth,
2561
- /* base_height*/ height ,
2562
- /* base_pitch*/ basePitch ,
2572
+ /* base_height*/ baseHeight ,
2573
+ /* base_pitch*/ pitch ,
2563
2574
/* x*/ offsetX,
2564
2575
/* y*/ offsetY,
2565
2576
/* elem_size_in_bits*/ elemSizeInBits,
2566
- /* tile_width*/ elemsPerInstr[ 1 ] ,
2567
- /* tile_height*/ elemsPerInstr[ 0 ] ,
2568
- /* v_blocks*/ 1 ,
2577
+ /* tile_width*/ tileWidth ,
2578
+ /* tile_height*/ tileHeight ,
2579
+ /* v_blocks*/ vBlocks ,
2569
2580
/* stored_val*/ b.bitcast (storeVal, store2DGenXType));
2570
2581
2571
2582
if (failed (newOp.verify ())) {
0 commit comments