diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 9c234c1e866b9..0457f8128b908 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getType(); + TypedValue<::mlir::VectorType> indiceVec = op.getOffsets(); + VectorType indiceVecTy = indiceVec.getType(); - // check if the tensor descriptor type is a 1d vector type - if (tdescTy.getRank() > 1) + if (!tdescTy.isScattered()) return failure(); std::optional> targetShape = getTargetShape(op); if (!targetShape) return failure(); - auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0]; - - TypedValue<::mlir::VectorType> indiceVec = op.getOffsets(); - VectorType indiceVecTy = indiceVec.getType(); + SmallVector targetIndiceShape(*targetShape); + int64_t originalChunkSize = tdescTy.getChunkSize(); + // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1. + if (originalChunkSize > 1) + targetIndiceShape.pop_back(); + auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0]; SmallVector convertedIndiceTypes = - getUnrolledTypes(indiceVecTy, *targetShape); + getUnrolledTypes(indiceVecTy, targetIndiceShape); SmallVector convertedIndiceVec = - pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter); + pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter); SmallVector newOps; - for (auto indice : convertedIndiceVec) { - auto newOp = rewriter.create(loc, newTdescTy, - op.getSource(), indice); - newOps.push_back(newOp); + + // More indices is need when chunkSize > 1. Since a big load from one + // address could be break into multiple small loads. 
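+      // Illustration (sizes are hypothetical): with originalChunkSize = 8 and a
+      // blocked chunk of 2 (targetShape->back()), each original index vector
+      // yields numNewChunks = 4 descriptors, built from indice + 0, + 2, + 4, + 6.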
+ if (originalChunkSize > 1) { + int64_t blockedChunkSize = targetShape->back(); + int64_t numNewChunks = originalChunkSize / blockedChunkSize; + + for (auto [indice, indiceType] : + llvm::zip(convertedIndiceVec, convertedIndiceTypes)) { + for (int64_t i = 0; i < numNewChunks; ++i) { + // Compute the offset + Value inc = rewriter.create( + loc, i * blockedChunkSize); + Value incVec = rewriter.create(loc, indiceType, inc); + Value offsetIndice = + rewriter.create(loc, indice, incVec); + + auto newOp = rewriter.create( + loc, newTdescTy, op.getSource(), offsetIndice); + + newOps.push_back(newOp); + } + } + } else { + for (auto indice : convertedIndiceVec) { + auto newOp = rewriter.create( + loc, newTdescTy, op.getSource(), indice); + newOps.push_back(newOp); + } } Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter); @@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern { VectorType valueTy = llvm::dyn_cast(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - // check if the tensor descriptor type is a 1d vector type - if (tdescTy.getRank() > 1) + if (!tdescTy.isScattered()) return failure(); - VectorType maskTy = llvm::dyn_cast(op.getMask().getType()); - std::optional> targetShape = getTargetShape(op); if (!targetShape) return failure(); + SmallVector targetMaskShape(*targetShape); + int64_t originalChunkSize = tdescTy.getChunkSize(); + + VectorType maskTy = llvm::dyn_cast(op.getMask().getType()); + Type elemTy = tdescTy.getElementType(); VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy); @@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern { SmallVector convertedTdescs = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); - SmallVector convertedMaskTypes = - getUnrolledTypes(maskTy, *targetShape); - SmallVector convertedMasks = - pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter); + SmallVector convertedMaskTypes; + SmallVector convertedMasks; + + if (originalChunkSize > 1) { + targetMaskShape.pop_back(); + convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape); + SmallVector convertedMasks1D = pack( + op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter); + int64_t blockedChunkSize = targetShape->back(); + int64_t numNewChunks = originalChunkSize / blockedChunkSize; + + for (auto mask : convertedMasks1D) { + for (int64_t i = 0; i < numNewChunks; ++i) + convertedMasks.push_back(mask); + } + // This is to handle the transpose effect when chunkSize > 1. 
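+      // For example (illustrative shapes): a gather from a blocked 16x2 chunked
+      // descriptor returns a vector<2x16xf32> (see the load_chunk test), so the
+      // unrolled value type is rebuilt from the swapped target shape.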
+ std::swap((*targetShape)[0], (*targetShape)[1]); + newValueTy = valueTy.cloneWith(*targetShape, elemTy); + } else { + convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape); + convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, + loc, rewriter); + } SmallVector newOps; for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) { @@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern { } Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter); - rewriter.replaceOp(op, castOp); return success(); } @@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - // check if the tensor descriptor type is a 1d vector type - if (tdescTy.getRank() > 1) + if (!tdescTy.isScattered()) return failure(); std::optional> targetShape = getTargetShape(op); @@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern { VectorType valueTy = llvm::dyn_cast(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - // check if the tensor descriptor type is a 1d vector type - if (tdescTy.getRank() > 1) + if (!tdescTy.isScattered()) return failure(); - VectorType maskTy = llvm::dyn_cast(op.getMask().getType()); - std::optional> targetShape = getTargetShape(op); if (!targetShape) return failure(); - SmallVector convertedValTypes = - getUnrolledTypes(valueTy, *targetShape); + SmallVector targetIndiceShape(*targetShape); + int64_t originalChunkSize = tdescTy.getChunkSize(); + + VectorType maskTy = llvm::dyn_cast(op.getMask().getType()); + SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); - - SmallVector convertedValues = - pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); SmallVector convertedTdescs = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); - SmallVector convertedMaskTypes = - getUnrolledTypes(maskTy, *targetShape); - SmallVector convertedMasks = - pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter); + SmallVector convertedMaskTypes; + SmallVector convertedMasks; + + if (originalChunkSize > 1) { + int64_t blockedChunkSize = targetShape->back(); + int64_t numNewChunks = originalChunkSize / blockedChunkSize; + convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]); + SmallVector convertedMasks1D = pack( + op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter); + + for (auto mask : convertedMasks1D) { + for (int64_t i = 0; i < numNewChunks; ++i) { + convertedMasks.push_back(mask); + } + } + // This is to handle the transpose effect when chunkSize > 1. 
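+      // Likewise (illustrative shapes): the value stored through a blocked 16x2
+      // chunked descriptor is a vector<2x16xf32> (see the store_chunk test), so
+      // the value pieces below are packed using the swapped target shape.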
+ std::swap((*targetShape)[0], (*targetShape)[1]); + + } else { + convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape); + convertedMasks = + pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter); + } + + SmallVector convertedValTypes = + getUnrolledTypes(valueTy, *targetShape); + SmallVector convertedValues = + pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); for (size_t i = 0; i < convertedValues.size(); ++i) { Value v = convertedValues[i]; @@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - // check if the tensor descriptor type is a 1d vector type - if (tdescTy.getRank() > 1) + if (tdescTy.getRank() > 2) + return failure(); + + if (!tdescTy.isScattered()) return failure(); std::optional> targetShape = getTargetShape(op); @@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern { TypedValue<::mlir::VectorType> offsetVec = op.getOffsets(); VectorType offsetVecTy = offsetVec.getType(); - SmallVector convertedOffsetTypes = - getUnrolledTypes(offsetVecTy, *targetShape); - SmallVector convertedOffsetVec = - pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter); - + SmallVector convertedOffsetTypes; + SmallVector convertedOffsetVec; SmallVector newOps; + int64_t originalChunkSize = tdescTy.getChunkSize(); + if (originalChunkSize > 1) { + SmallVector shape1D(targetShape->begin(), + targetShape->end() - 1); + convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D); + SmallVector convertedOffsetVec1D = + pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter); + + int64_t blockedChunkSize = targetShape->back(); + int64_t numNewChunks = originalChunkSize / blockedChunkSize; + + for (auto offset : convertedOffsetVec1D) { + for (int64_t i = 0; i < numNewChunks; ++i) { + convertedOffsetVec.push_back(offset); + } + } + + } else { + convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape); + convertedOffsetVec = + pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter); + } + for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) { auto newOp = rewriter.create(loc, t.getType(), t, o); diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir index 52ec3b856da49..41414d802f212 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir @@ -2,7 +2,7 @@ gpu.module @test { - // CHECK-LABEL: test_create_nd_tdesc + // CHECK-LABEL: create_nd_tdesc // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast @@ -10,31 +10,31 @@ gpu.module @test { // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>, // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32> // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout> {__xegpu_blocking_tile_shape__ = array, __xegpu_blocking_unpack__} - gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> { + gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> { %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> } 
//----- - // CHECK-LABEL: test_create_nd_tdesc_1d + // CHECK-LABEL: create_nd_tdesc_1d // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32> // CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout> {__xegpu_blocking_tile_shape__ = array, __xegpu_blocking_unpack__} - gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout> { + gpu.func @create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout> { %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout> gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout> } //----- - // CHECK-LABEL: test_update_nd_tdesc + // CHECK-LABEL: update_nd_tdesc // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32> - gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> { + gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> { %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> @@ -42,11 +42,11 @@ gpu.module @test { //----- - // CHECK-LABEL: test_update_nd_tdesc_1d + // CHECK-LABEL: update_nd_tdesc_1d // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> // CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32> - gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout> { + gpu.func @update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout> { %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout> %update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout> gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout> @@ -54,11 +54,11 @@ gpu.module @test { //----- - // CHECK-LABEL: test_prefetch_nd_tdesc + // CHECK-LABEL: prefetch_nd_tdesc // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> - gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) { + gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> gpu.return @@ -66,23 +66,23 @@ gpu.module @test { //----- - // CHECK-LABEL: test_prefetch_nd_tdesc_1d + // CHECK-LABEL: prefetch_nd_tdesc_1d // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> // CHECK-COUNT-4: 
xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32> - gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) { + gpu.func @prefetch_nd_tdesc_1d(%src: memref<64xf32>) { %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout> xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout> gpu.return } //----- - // CHECK-LABEL: test_load_nd + // CHECK-LABEL: load_nd // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32> - gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> { + gpu.func @load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> { %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> gpu.return %ld : vector<24x32xf32> @@ -90,12 +90,12 @@ gpu.module @test { //----- - // CHECK-LABEL: test_load_nd_1d + // CHECK-LABEL: load_nd_1d // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> // CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16xf32> -> vector<16xf32> // CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32> - gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> { + gpu.func @load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> { %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout> %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout> -> vector<64xf32> gpu.return %data : vector<64xf32> @@ -103,11 +103,11 @@ gpu.module @test { //----- - // CHECK-LABEL: test_store_nd + // CHECK-LABEL: store_nd // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - gpu.func @test_store_nd(%src: memref<24x32xf32>) { + gpu.func @store_nd(%src: memref<24x32xf32>) { %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> %data = arith.constant dense<9.0> : vector<24x32xf32> xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout> @@ -116,11 +116,11 @@ gpu.module @test { //----- - // CHECK-LABEL: test_store_nd_1d + // CHECK-LABEL: store_nd_1d // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> // CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32> - gpu.func @test_store_nd_1d(%src: memref<64xf32>) { + gpu.func @store_nd_1d(%src: memref<64xf32>) { %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout> %data = arith.constant dense<9.0> : vector<64xf32> xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout> @@ -129,7 +129,7 @@ gpu.module @test { //----- - // CHECK-LABEL: test_createNd_loadNd_storeNd 
+ // CHECK-LABEL: createNd_loadNd_storeNd // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> @@ -137,7 +137,7 @@ gpu.module @test { //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32> //CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32> //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) { + gpu.func @createNd_loadNd_storeNd(%src: memref<24x32xf32>) { %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> %data = arith.constant dense<9.0> : vector<24x32xf32> %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> @@ -148,23 +148,23 @@ gpu.module @test { //----- - // CHECK-LABEL: test_dpas + // CHECK-LABEL: dpas // CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16> //CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16> //CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16> //CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32> //CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32> - gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> { + gpu.func @dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> { %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32> gpu.return %c : vector<32x32xf32> } //----- - // CHECK-LABEL: test_create_tdesc_vec + // CHECK-LABEL: create_tdesc_vec // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> { + gpu.func @create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> { %cst = arith.constant dense<[ 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, @@ -177,10 +177,10 @@ gpu.module @test { //----- - // CHECK-LABEL: test_create_tdesc_step + // CHECK-LABEL: create_tdesc_step // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> { + gpu.func @create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> { %step = arith.constant dense<8> : vector<32xindex> %seq = vector.step : vector<32xindex> %cst = arith.muli %seq, %step : vector<32xindex> @@ -190,11 +190,11 @@ gpu.module @test { //----- - // CHECK-LABEL: test_load + // CHECK-LABEL: load // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> 
-> vector<16xf32> - gpu.func @test_load(%src: ui64) -> vector<32xf32> { + gpu.func @load(%src: ui64) -> vector<32xf32> { %cst = arith.constant dense<[ 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, @@ -212,11 +212,11 @@ gpu.module @test { //----- - // CHECK-LABEL: test_prefetch + // CHECK-LABEL: prefetch // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - gpu.func @test_prefetch(%src: ui64) { + gpu.func @prefetch(%src: ui64) { %cst = arith.constant dense<[ 0, 8, 16, 24, 32, 40, 48, 56, @@ -233,11 +233,11 @@ gpu.module @test { //----- - // CHECK-LABEL: test_store + // CHECK-LABEL: store // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> - gpu.func @test_store(%src: ui64) { + gpu.func @store(%src: ui64) { %cst = arith.constant dense<[ 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, @@ -256,47 +256,129 @@ gpu.module @test { } //----- + // CHECK-LABEL: create_tdesc_step_chunk + // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + gpu.func @create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> { + %step = arith.constant dense<8> : vector<32xindex> + %seq = vector.step : vector<32xindex> + %cst = arith.muli %seq, %step : vector<32xindex> + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + } - // CHECK-LABEL: test_prefetch_load_store_update +//----- + // CHECK-LABEL: create_tdesc_step_chunk2 // CHECK-SAME: [[arg0:%.+]]: ui64 - // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> - // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> - // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> + // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + gpu.func @create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> { + %step = arith.constant dense<8> : vector<32xindex> + %seq = vector.step : vector<32xindex> + %cst = arith.muli %seq, %step : vector<32xindex> + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + } - gpu.func @test_prefetch_load_store_update(%src: ui64) { +// CHECK-LABEL: 
create_tdesc_step_chunk3 + // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex> + // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex> + // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex> + // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + gpu.func @create_tdesc_step_chunk3(%src: ui64) -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> { + %step = arith.constant dense<8> : vector<16xindex> + %seq = vector.step : vector<16xindex> + %cst = arith.muli %seq, %step : vector<16xindex> + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + gpu.return %tdesc : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + } + +//----- + // CHECK-LABEL: load_chunk + // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK-COUNT-4: xegpu.load {{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<2x16xf32> + gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> { %cst = arith.constant dense<[ - 0, 8, 16, 24, 32, 40, 48, 56, - 64, 72, 80, 88, 96, 104, 112, 120, - 128, 136, 144, 152, 160, 168, 176, 184, - 192, 200, 208, 216, 224, 232, 240, 248 + 0, 8, 16, 24, 32, 40, 48, 56, + 64, 72, 80, 88, 96, 104, 112, 120, + 128, 136, 144, 152, 160, 168, 176, 184, + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> + + %c17 = arith.constant 17: index + %mask = vector.create_mask %c17: vector<32xi1> - %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> - xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> - - %delta = arith.constant dense<[ - 32, 32, 32, 32, 32, 32, 32, 32, - 32, 32, 32, 32, 32, 32, 32, 64, - 128, 128, 128, 128, 128, 128, 128, 128, - 128, 128, 128, 128, 128, 128, 128, 256 + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> -> vector<4x32xf32> + + gpu.return %ld : vector<4x32xf32> + } + +//----- + // CHECK-LABEL: store_chunk + // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK-COUNT-4: xegpu.store {{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + gpu.func @store_chunk(%src: ui64) { + %cst = arith.constant dense<[ + 0, 8, 16, 24, 32, 40, 48, 56, + 64, 72, 80, 88, 96, 104, 
112, 120, + 128, 136, 144, 152, 160, 168, 176, 184, + 192, 200, 208, 216, 224, 232, 240, 248 ]> : vector<32xindex> - %new_tdesc = xegpu.update_offset %tdesc, %delta - : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<32xindex> - + %c17 = arith.constant 17: index %mask = vector.create_mask %c17: vector<32xi1> - %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, vector<32xi1> -> vector<32xf32> + %st_vec = arith.constant dense<1023.>: vector<4x32xf32> + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xi1> + + gpu.return + } - %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32> - xegpu.store %st_vec, %tdesc, %mask: - vector<32xf32>, - !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout>, - vector<32xi1> - +//----- + // CHECK-LABEL: prefetch_chunk + // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + gpu.func @prefetch_chunk(%src: ui64) { + %cst = arith.constant dense<[ + 0, 8, 16, 24, 32, 40, 48, 56, + 64, 72, 80, 88, 96, 104, 112, 120, + 128, 136, 144, 152, 160, 168, 176, 184, + 192, 200, 208, 216, 224, 232, 240, 248 + ]> : vector<32xindex> + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + gpu.return } + +//----- + // CHECK-LABEL: update_chunk + // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xindex> + gpu.func @update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> { + %cst = arith.constant dense<[ + 0, 8, 16, 24, 32, 40, 48, 56, + 64, 72, 80, 88, 96, 104, 112, 120, + 128, 136, 144, 152, 160, 168, 176, 184, + 192, 200, 208, 216, 224, 232, 240, 248 + ]> : vector<32xindex> + %delta = arith.constant dense<32>: vector<32xindex> + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + + %new_tdesc = xegpu.update_offset %tdesc, %delta + : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<32xindex> + + gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + } } + diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 57aaecbd7962f..4400d6d9625f7 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -19,6 +19,10 @@ using namespace mlir::xegpu; namespace { +#define DEBUG_TYPE "test-xegpu-unroll" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + struct 
TestXeGPUUnrollingPatterns : public PassWrapper> { @@ -48,7 +52,9 @@ struct TestXeGPUUnrollingPatterns options.setNativeShapeFn( [&](Operation *op) -> std::optional> { if (isa(op)) { + xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp, + xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp, + xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) { xegpu::TensorDescType tdescTy; if (auto createNdOp = dyn_cast(op)) { tdescTy = createNdOp.getType(); @@ -61,20 +67,7 @@ struct TestXeGPUUnrollingPatterns tdescTy = loadNdOp.getTensorDescType(); } else if (auto storeNdOp = dyn_cast(op)) { tdescTy = storeNdOp.getTensorDescType(); - } - - if (auto layout = tdescTy.getLayoutAttr()) { - auto inst_data = layout.getInstData(); - if (inst_data && layout.isSgLayout()) - return SmallVector(inst_data.asArrayRef().begin(), - inst_data.asArrayRef().end()); - } - } - - if (isa(op)) { - xegpu::TensorDescType tdescTy; - if (auto createOp = dyn_cast(op)) { + } else if (auto createOp = dyn_cast(op)) { tdescTy = createOp.getType(); } else if (auto updateOp = dyn_cast(op)) { tdescTy = updateOp.getTensorDescType(); @@ -111,14 +104,40 @@ struct TestXeGPUUnrollingPatterns Attribute encoding = tdescTy.getEncoding(); auto layout = llvm::dyn_cast_if_present( tdescTy.getLayout()); + + // If the encoding is a ScatterTensorDescAttr, we need to + // potentially adjust the chunk size based on the inst_data. + if (encoding && mlir::isa(encoding)) { + auto scatterAttr = + mlir::dyn_cast(encoding); + int64_t chunkSize = scatterAttr.getChunkSize().getInt(); + + if (chunkSize > 1) { + int64_t blockedChunkSize = chunkSize; + auto instData = layout.getInstData(); + if (!instData.empty()) + blockedChunkSize = instData.asArrayRef().back(); + + auto chunkSizeAttr = mlir::IntegerAttr::get( + mlir::IntegerType::get(ctx, 64), blockedChunkSize); + + // To create a new attribute with a different chunk_size: + auto newEncoding = xegpu::ScatterTensorDescAttr::get( + ctx, scatterAttr.getMemorySpace(), chunkSizeAttr); + + encoding = newEncoding; + } + } if (layout) { if (layout.getLaneLayout() == nullptr) layout = xegpu::LayoutAttr(); else layout = layout.dropInstData(); } + newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout); + } else { newTy = type.clone(tileShape, elemTy); }
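A minimal sketch of what the chunk-size adjustment above produces, assuming inst_data = [16, 2] on a chunked scattered descriptor (attribute parameters abbreviated; exact printed syntax may differ):

  // original type, chunk_size = 4:
  //   !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>, #xegpu.layout<inst_data = [16, 2]>>
  // unrolled tile type, chunk_size rewritten to the blocked chunk, inst_data dropped:
  //   !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>

This matches the create_tdesc_step_chunk2 and load_chunk tests above, where one 32x4 descriptor is unrolled into four 16x2 descriptors.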