@@ -2456,11 +2456,6 @@ struct StoreOpToBlockIOConversion
24562456 auto b = TritonLLVMOpBuilder (loc, rewriter);
24572457 Type resultType = op.getValue ().getType ();
24582458 auto tensorType = cast<RankedTensorType>(resultType);
2459-
2460- // Only lower StoreOp with dpas layout encoding.
2461- if (!hasDpasEncoding (tensorType))
2462- return failure ();
2463-
24642459 auto dpasLayout = cast<DpasEncodingAttr>(tensorType.getEncoding ());
24652460 LLVMTypeConverter *typeConverter = getTypeConverter ();
24662461 MLIRContext *ctx = rewriter.getContext ();
@@ -2471,14 +2466,21 @@ struct StoreOpToBlockIOConversion
24712466 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
24722467 size_t rank = tensorShape.size ();
24732468 unsigned numElems = getTotalElemsPerThread (tensorType);
2469+
24742470 SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
2471+ // 2D block store supports 8 rows at most.
2472+ unsigned tileHeight = std::min (8u , elemsPerInstr[0 ]);
2473+ // 2D block store supports 64 bytes per row at most.
2474+ unsigned tileWidth = elemsPerInstr[1 ];
2475+ unsigned totalBytesPerRowPerMatrix = tileWidth * elemSizeInBits / 8 ;
2476+ if (totalBytesPerRowPerMatrix > 64 )
2477+ return failure ();
2478+
24752479 auto warpsPerCTA = dpasLayout.getWarpsPerCTA ();
24762480 SmallVector<int64_t > numReps =
24772481 dpasLayout.getDPASRepetitions (tensorShape, 2 );
24782482 SmallVector<unsigned > dpasWarpsOrder =
24792483 getMatrixOrder (warpsPerCTA.size (), /* rowMajor*/ true );
2480- unsigned threadsPerWarp =
2481- product<unsigned >(getThreadsPerWarp (dpasLayout, tensorShape));
24822484
24832485 Value warpId = rewriter.create <arith::IndexCastOp>(
24842486 loc, i32_ty,
@@ -2487,25 +2489,34 @@ struct StoreOpToBlockIOConversion
24872489 SmallVector<Value> multiDimWarpId = mlir::LLVM::delinearize (
24882490 rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);
24892491
2490- int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
2491- Type store2DGenXType =
2492- LLVM::getVectorType (IntegerType::get (ctx, elemSizeInBits),
2493- elemsPerLane); // make it opaque type.
2494-
24952492 Value blockPtr = adaptor.getPtr ();
24962493 auto [base, width, height, rowStride, colStride, offsetBaseX, offsetBaseY] =
24972494 getValuesFromBlockPointerStruct (blockPtr, rewriter);
24982495
2499- auto vals = unpackLLElements (loc, adaptor.getValue (), rewriter);
2500- assert (vals.size () == numElems);
2501-
25022496 width = b.trunc (i32_ty, width);
2503- height = b.trunc (i32_ty, height);
25042497 rowStride = b.trunc (i32_ty, rowStride);
25052498 // encoded as bytes.
25062499 Value baseWidth = b.mul (width, elemSizeInBytes);
2500+ Value baseHeight = b.trunc (i32_ty, height);
25072501 // encoded as bytes.
2508- Value basePitch = b.mul (rowStride, elemSizeInBytes);
2502+ Value pitch = b.mul (rowStride, elemSizeInBytes);
2503+ // 2D block store only supports vBlocks = 1.
2504+ unsigned vBlocks = 1 ;
2505+
2506+ // Get the LLVM values for store values
2507+ SmallVector<Value> valElems =
2508+ unpackLLElements (loc, adaptor.getValue (), rewriter);
2509+ assert (valElems.size () == numElems &&
2510+ " the number of store values does not match the number of elements" );
2511+
2512+ unsigned threadsPerWarp =
2513+ TritonGPUDialect::getThreadsPerWarp (op->getParentOfType <ModuleOp>());
2514+
2515+ int64_t elemsPerLane = tileHeight * tileWidth / threadsPerWarp;
2516+ Type opaqueType = IntegerType::get (ctx, elemSizeInBits);
2517+ Type store2DGenXType =
2518+ LLVM::getVectorType (opaqueType,
2519+ elemsPerLane); // make it opaque type.
25092520
25102521 // A warp stride for the replicates.
25112522 SmallVector<unsigned > repClusterShape = dpasLayout.getShapeC ();
@@ -2538,34 +2549,34 @@ struct StoreOpToBlockIOConversion
25382549 for (int m = 0 ; m < numRepOuter; ++m) {
25392550 for (int n = 0 ; n < numRepInner; ++n) {
25402551 for (int repM = 0 ; repM < repCluster[0 ]; ++repM) {
2541- Value offsetY =
2542- b.add (warpId0Offset,
2543- b.i32_val (m * replicaStride[0 ] + repM * elemsPerInstr[0 ]));
2552+ Value offsetY = b.add (warpId0Offset, b.i32_val (m * replicaStride[0 ] +
2553+ repM * tileHeight));
25442554 for (int repN = 0 ; repN < repCluster[1 ]; ++repN) {
25452555 Value offsetX =
2546- b.add (warpId1Offset, b.i32_val (n * replicaStride[1 ] +
2547- repN * elemsPerInstr[1 ]));
2556+ b.add (warpId1Offset,
2557+ b.i32_val (n * replicaStride[1 ] + repN * tileWidth));
2558+
25482559 Value storeVal = rewriter.create <LLVM::UndefOp>(
25492560 loc, LLVM::getVectorType (typeConverter->convertType (eltTy),
25502561 elemsPerLane));
25512562 for (size_t i = 0 ; i < elemsPerLane; ++i) {
25522563 storeVal =
2553- b.insert_element (storeVal, vals [valOffset], b.i32_val (i));
2564+ b.insert_element (storeVal, valElems [valOffset], b.i32_val (i));
25542565 ++valOffset;
25552566 }
25562567
25572568 auto newOp = rewriter.create <TritonGEN::Matrix2DBlockStoreOp>(
25582569 loc,
25592570 /* ptr*/ base,
25602571 /* base_width*/ baseWidth,
2561- /* base_height*/ height ,
2562- /* base_pitch*/ basePitch ,
2572+ /* base_height*/ baseHeight ,
2573+ /* base_pitch*/ pitch ,
25632574 /* x*/ offsetX,
25642575 /* y*/ offsetY,
25652576 /* elem_size_in_bits*/ elemSizeInBits,
2566- /* tile_width*/ elemsPerInstr[ 1 ] ,
2567- /* tile_height*/ elemsPerInstr[ 0 ] ,
2568- /* v_blocks*/ 1 ,
2577+ /* tile_width*/ tileWidth ,
2578+ /* tile_height*/ tileHeight ,
2579+ /* v_blocks*/ vBlocks ,
25692580 /* stored_val*/ b.bitcast (storeVal, store2DGenXType));
25702581
25712582 if (failed (newOp.verify ())) {
0 commit comments