diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..75822e7d6dec4 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,21 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+// Extract cache hints from the op attributes if available.
+static SmallVector<xegpu::CachePolicyAttr, 3> getOpCacheHints(Operation *op) {
+  SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
+                                                    xegpu::CachePolicyAttr{},
+                                                    xegpu::CachePolicyAttr{}};
+  // get l1, l2, l3 hints from attributes if available.
+  if (auto l1Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l1_hint"))
+    cacheHints[0] = l1Attr;
+  if (auto l2Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l2_hint"))
+    cacheHints[1] = l2Attr;
+  if (auto l3Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l3_hint"))
+    cacheHints[2] = l3Attr;
+  return cacheHints;
+}
+
 static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 Location loc,
                                                 xegpu::TensorDescType descType,
                                                 TypedValue<MemRefType> src,
@@ -430,12 +445,16 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
       vectorShape);
-  auto gatherOp = xegpu::LoadGatherOp::create(
-      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
-      /*chunk_size=*/IntegerAttr{},
-      /*l1_hint=*/xegpu::CachePolicyAttr{},
-      /*l2_hint=*/xegpu::CachePolicyAttr{},
-      /*l3_hint=*/xegpu::CachePolicyAttr{});
+  SmallVector<xegpu::CachePolicyAttr, 3> cacheHints = getOpCacheHints(readOp);
+  auto gatherOp = xegpu::LoadGatherOp::create(rewriter, loc, vectorType,
+                                              flatMemref, localOffsets, mask,
+                                              /*chunk_size=*/IntegerAttr{},
+                                              /*l1_hint=*/cacheHints[0],
+                                              /*l2_hint=*/cacheHints[1],
+                                              /*l3_hint=*/cacheHints[2]);
+  auto resLayout = xegpu::getDistributeLayoutAttr(readOp.getResult());
+  xegpu::setDistributeLayoutAttrs(gatherOp,
+                                  [&](Value val) { return resLayout; });
 
   rewriter.replaceOp(readOp, gatherOp.getResult());
   return success();
@@ -464,12 +483,16 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
       vectorShape);
-  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
-                                localOffsets, mask,
-                                /*chunk_size=*/IntegerAttr{},
-                                /*l1_hint=*/xegpu::CachePolicyAttr{},
-                                /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+  auto cacheHints = getOpCacheHints(writeOp);
+  auto storeOp = xegpu::StoreScatterOp::create(
+      rewriter, loc, writeOp.getVector(), flatMemref, localOffsets, mask,
+      /*chunk_size=*/IntegerAttr{},
+      /*l1_hint=*/cacheHints[0],
+      /*l2_hint=*/cacheHints[1],
+      /*l3_hint=*/cacheHints[2]);
+  auto valueLayout = xegpu::getDistributeLayoutAttr(writeOp->getOpOperand(0));
+  xegpu::setDistributeLayoutAttrs(storeOp,
+                                  [&](Value val) { return valueLayout; });
   rewriter.eraseOp(writeOp);
   return success();
 }
@@ -519,9 +542,11 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     SmallVector<int64_t> descShape(vecTy.getShape());
     if (isTransposeLoad)
       std::reverse(descShape.begin(), descShape.end());
-    auto descType = xegpu::TensorDescType::get(
-        descShape, elementType, /*array_length=*/1,
-        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
+    auto resLayout = xegpu::getDistributeLayoutAttr(readOp.getResult());
+    auto descType =
+        xegpu::TensorDescType::get(descShape, elementType, /*array_length=*/1,
+                                   /*boundary_check=*/isOutOfBounds,
+                                   xegpu::MemorySpace::Global, resLayout);
 
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType,
@@ -532,12 +557,12 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(readOp);
     auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                           /*packed=*/nullptr, transposeAttr,
-                                          /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                          /*l1_hint=*/cacheHints[0],
+                                          /*l2_hint=*/cacheHints[1],
+                                          /*l3_hint=*/cacheHints[2]);
 
     rewriter.replaceOp(readOp, loadOp);
     return success();
@@ -575,10 +600,11 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto valLayout = xegpu::getDistributeLayoutAttr(writeOp->getOpOperand(0));
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
-        xegpu::MemorySpace::Global);
+        xegpu::MemorySpace::Global, valLayout);
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType,
                            dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
@@ -586,10 +612,12 @@ struct TransferWriteLowering
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(writeOp);
     auto storeOp =
         xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 /*l1_hint=*/cacheHints[0],
+                                 /*l2_hint=*/cacheHints[1],
+                                 /*l3_hint=*/cacheHints[2]);
 
     rewriter.replaceOp(writeOp, storeOp);
     return success();
@@ -616,16 +644,33 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
         computeOffsets(rewriter, gatherOp, meta.first, meta.second);
     Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
 
+    auto layoutRes = xegpu::getDistributeLayoutAttr(gatherOp.getResult());
+    auto layoutIndices =
+        xegpu::getDistributeLayoutAttr(gatherOp.getIndicesMutable());
+    auto layoutMask = xegpu::getDistributeLayoutAttr(gatherOp.getMaskMutable());
+    auto layoutPassThru =
+        xegpu::getDistributeLayoutAttr(gatherOp.getPassThruMutable());
+    SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
+        getOpCacheHints(gatherOp);
     auto xeGatherOp = xegpu::LoadGatherOp::create(
         rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
         /*chunk_size=*/IntegerAttr{},
-        /*l1_hint=*/xegpu::CachePolicyAttr{},
-        /*l2_hint=*/xegpu::CachePolicyAttr{},
-        /*l3_hint=*/xegpu::CachePolicyAttr{});
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1],
+        /*l3_hint=*/cacheHints[2]);
+    xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
+    xegpu::setDistributeLayoutAttr(xeGatherOp.getOffsetsMutable()[0],
+                                   layoutIndices);
+    xegpu::setDistributeLayoutAttr(xeGatherOp.getMaskMutable(), layoutMask);
 
     auto selectOp =
         arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                 xeGatherOp.getResult(), gatherOp.getPassThru());
+    xegpu::setDistributeLayoutAttr(selectOp.getConditionMutable(), layoutMask);
+    xegpu::setDistributeLayoutAttr(selectOp.getFalseValueMutable(),
+                                   layoutPassThru);
+    xegpu::setDistributeLayoutAttr(selectOp->getOpResult(0), layoutRes);
+
     rewriter.replaceOp(gatherOp, selectOp.getResult());
     return success();
   }
@@ -650,12 +695,25 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
         computeOffsets(rewriter, scatterOp, meta.first, meta.second);
     Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
 
-    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
-                                  flatMemref, localOffsets, scatterOp.getMask(),
-                                  /*chunk_size=*/IntegerAttr{},
-                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
-                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+    auto layoutIndices =
+        xegpu::getDistributeLayoutAttr(scatterOp.getIndicesMutable());
+    auto layoutMask =
+        xegpu::getDistributeLayoutAttr(scatterOp.getMaskMutable());
+    auto layoutVal =
+        xegpu::getDistributeLayoutAttr(scatterOp.getValueToStoreMutable());
+    SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
+        getOpCacheHints(scatterOp);
+    auto storeOp = xegpu::StoreScatterOp::create(
+        rewriter, loc, scatterOp.getValueToStore(), flatMemref, localOffsets,
+        scatterOp.getMask(),
+        /*chunk_size=*/IntegerAttr{},
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1],
+        /*l3_hint=*/cacheHints[2]);
+    xegpu::setDistributeLayoutAttr(storeOp.getValueMutable(), layoutVal);
+    xegpu::setDistributeLayoutAttr(storeOp.getOffsetsMutable()[0],
+                                   layoutIndices);
+    xegpu::setDistributeLayoutAttr(storeOp.getMaskMutable(), layoutMask);
     rewriter.eraseOp(scatterOp);
     return success();
   }
@@ -675,18 +733,20 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto resLayout = xegpu::getDistributeLayoutAttr(loadOp.getResult());
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
-        boundaryCheck, xegpu::MemorySpace::Global);
+        boundaryCheck, xegpu::MemorySpace::Global, resLayout);
     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
         rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(loadOp);
     auto loadNdOp = xegpu::LoadNdOp::create(
         rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr,
         /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1], /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(loadOp, loadNdOp);
     return success();
@@ -708,18 +768,21 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
-    auto descType = xegpu::TensorDescType::get(
-        vecTy.getShape(), vecTy.getElementType(),
-        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
+    auto valLayout = xegpu::getDistributeLayoutAttr(storeOp->getOpOperand(0));
+    auto descType =
+        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
+                                   /*array_length=*/1, boundaryCheck,
+                                   xegpu::MemorySpace::Global, valLayout);
     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
         rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto cacheHints = getOpCacheHints(storeOp);
+    auto storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                              /*l1_hint=*/cacheHints[0],
+                                              /*l2_hint=*/cacheHints[1],
+                                              /*l3_hint=*/cacheHints[2]);
 
     rewriter.replaceOp(storeOp, storeNdOp);
     return success();
diff --git a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
index 2a319869a7b06..373955ac26c12 100644
--- a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -249,3 +249,160 @@ gpu.func @non_unit_inner_stride_3D(
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
 // CHECK: gpu.return %[[RES]] : vector<8xf32>
 }
+
+// -----
+
+gpu.module @xevm_module {
+// Layouts are only specified for the gather op itself.
+gpu.func @load_dynamic_layout_operands(%source: memref<?x?xf32>,
+    %off0: index, %off1: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %res = vector.gather %source[%off0, %off1][%indices], %mask,
+      %pass_thru {
+        layout_result_0 = #xegpu.layout,
+        layout_operand_3 = #xegpu.layout,
+        layout_operand_4 = #xegpu.layout,
+        layout_operand_5 = #xegpu.layout
+      } : memref<?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %res : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_dynamic_layout_operands(
+// CHECK: %[[VEC:.+]] = xegpu.load {{[^{]*}}
+// CHECK-SAME: {layout_operand_1 = #xegpu.layout, layout_operand_2 = #xegpu.layout,
+// CHECK-SAME: layout_result_0 = #xegpu.layout}
+// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
+// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout,
+// CHECK-SAME: {{[^}]*}}layout_operand_2 = #xegpu.layout,
+// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout} : vector<8x16xi1>, vector<8x16xf32>
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @load_dynamic_layout_mixed(%source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %mask: vector<8x16xi1>) -> vector<8x16xf32> {
+  %pass_thru = arith.constant {layout_result_0 = #xegpu.layout} dense<0.000000e+00> : vector<8x16xf32>
+  %cst_1 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
+  %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+  %0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout} : vector<8x1xindex> to vector<8x16xindex>
+  %1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout} : vector<1x16xindex> to vector<8x16xindex>
+  %2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout} : vector<8x16xindex>
+
+  %res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
+      %pass_thru {
+        layout_result_0 = #xegpu.layout,
+        layout_operand_5 = #xegpu.layout
+      } : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  %res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
+  gpu.return %res2 : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_dynamic_layout_mixed(
+// CHECK: %[[VEC:.+]] = xegpu.load {{[^{]*}}
+// CHECK-SAME: {{{[^}]*}}layout_operand_2 = #xegpu.layout
+// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout}
+// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
+// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout,
+// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout} : vector<8x16xi1>, vector<8x16xf32>
+}
+
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @load_static_layout_mixed(%source: memref<8x16x32xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %mask: vector<8x16xi1>) -> vector<8x16xf32> {
+  %pass_thru = arith.constant {layout_result_0 = #xegpu.layout} dense<0.000000e+00> : vector<8x16xf32>
+  %cst_1 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
+  %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+  %0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout} : vector<8x1xindex> to vector<8x16xindex>
+  %1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout} : vector<1x16xindex> to vector<8x16xindex>
+  %2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout} : vector<8x16xindex>
+
+  %res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
+      %pass_thru {
+        layout_result_0 = #xegpu.layout,
+        layout_operand_5 = #xegpu.layout
+      } : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  %res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
+  gpu.return %res2 : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_static_layout_mixed(
+// CHECK: %[[VEC:.+]] = xegpu.load {{[^{]*}}
+// CHECK-SAME: {{{[^}]*}}layout_operand_2 = #xegpu.layout
+// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout}
+// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
+// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout,
+// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout} : vector<8x16xi1>, vector<8x16xf32>
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @load_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %mask: vector<8x16xi1>) -> vector<8x16xf32> {
+  %pass_thru = arith.constant {layout_result_0 = #xegpu.layout} dense<0.000000e+00> : vector<8x16xf32>
+  %cst_1 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
+  %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+  %0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout} : vector<8x1xindex> to vector<8x16xindex>
+  %1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout} : vector<1x16xindex> to vector<8x16xindex>
+  %2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout} : vector<8x16xindex>
+
+  %res = vector.gather %source[%off0, %off1, %off2][%2], %mask,
+      %pass_thru {
+        layout_result_0 = #xegpu.layout,
+        layout_operand_4 = #xegpu.layout, // overriding %2's layout
+        layout_operand_5 = #xegpu.layout
+      } : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  %res2 = arith.addf %res, %pass_thru : vector<8x16xf32>
+  gpu.return %res2 : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_dynamic_layout_mixed_override(
+// CHECK: %[[VEC:.+]] = xegpu.load {{[^{]*}}
+// CHECK-SAME: {layout_operand_1 = #xegpu.layout, layout_operand_2 = #xegpu.layout
+// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout}
+// CHECK: %[[RES:.+]] = arith.select {{[^{]*}}
+// CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout,
+// CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout} : vector<8x16xi1>, vector<8x16xf32>
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @load_with_cache_hints(%source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+      %pass_thru {
+        l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint,
+        l3_hint = #xegpu.cache_hint
+      } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_with_cache_hints(
+// CHECK: xegpu.load {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @load_with_partial_cache_hints(%source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+      %pass_thru {
+        l1_hint = #xegpu.cache_hint,
+        l3_hint = #xegpu.cache_hint
+      } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_with_partial_cache_hints(
+// CHECK: xegpu.load {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..7053f53874d47 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -79,6 +79,35 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 
 // -----
 
+func.func @load_2D_layout(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%offset, %offset, %offset] {layout_result_0 = #xegpu.layout}
+    : memref<8x16x32xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_layout(
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout>
+
+// -----
+
+func.func @load_2D_cache_hints(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%offset, %offset, %offset] {
+    l1_hint = #xegpu.cache_hint,
+    l2_hint = #xegpu.cache_hint,
+    l3_hint = #xegpu.cache_hint
+  }: memref<8x16x32xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_cache_hints(
+// CHECK: xegpu.load_nd {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+
+// -----
+
 func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
     %offset: index) -> vector<8x16x32xf32> {
   %0 = vector.load %source[%offset, %offset, %offset]
diff --git a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
index ffd3f170c0fad..a5356bdc839cd 100644
--- a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
@@ -204,3 +204,124 @@ gpu.func @scatter_into_subview(%vals: vector<8xf16>,
 // CHECK: xegpu.store %[[VALS]], %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
 // CHECK: gpu.return
 }
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @store_dynamic_layout_operands(%vec: vector<8x16xf32>, %source: memref<?x?xf32>,
+    %off0: index, %off1: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1][%indices], %mask, %vec {
+    layout_operand_3 = #xegpu.layout,
+    layout_operand_4 = #xegpu.layout,
+    layout_operand_5 = #xegpu.layout
+  } : memref<?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_dynamic_layout_operands(
+// CHECK: xegpu.store {{[^{]*}}
+// CHECK-SAME: {layout_operand_0 = #xegpu.layout, layout_operand_2 = #xegpu.layout, layout_operand_3 = #xegpu.layout}
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @store_dynamic_layout_mixed(%source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %mask: vector<8x16xi1>) {
+  %vec = arith.constant {layout_operand_0 = #xegpu.layout} dense<1.000000e+00> : vector<8x16xf32>
+  %cst_1 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
+  %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+  %0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout} : vector<8x1xindex> to vector<8x16xindex>
+  %1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout} : vector<1x16xindex> to vector<8x16xindex>
+  %2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout} : vector<8x16xindex>
+
+  vector.scatter %source[%off0, %off1, %off2][%2], %mask, %vec {
+    layout_operand_5 = #xegpu.layout
+  } : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_dynamic_layout_mixed(
+// CHECK: xegpu.store {{[^{]*}}
+// CHECK-SAME: {{[^}]*}}layout_operand_3 = #xegpu.layout}
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @store_static_layout_mixed(%source: memref<8x16x32xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %mask: vector<8x16xi1>) {
+  %vec = arith.constant {layout_operand_0 = #xegpu.layout} dense<1.000000e+00> : vector<8x16xf32>
+  %cst_1 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
+  %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+  %0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout} : vector<8x1xindex> to vector<8x16xindex>
+  %1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout} : vector<1x16xindex> to vector<8x16xindex>
+  %2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout} : vector<8x16xindex>
+
+  vector.scatter %source[%off0, %off1, %off2][%2], %mask, %vec {
+    layout_operand_5 = #xegpu.layout
+  } : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_static_layout_mixed(
+// CHECK: xegpu.store {{[^{]*}}
+// CHECK-SAME: {{[^}]*}}layout_operand_3 = #xegpu.layout}
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @store_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %mask: vector<8x16xi1>) {
+  %vec = arith.constant {layout_operand_0 = #xegpu.layout} dense<1.000000e+00> : vector<8x16xf32>
+  %cst_1 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0], [32], [64], [96], [128], [160], [192], [224]]> : vector<8x1xindex>
+  %cst_2 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+  %0 = vector.broadcast %cst_1 {layout_result_0 = #xegpu.layout} : vector<8x1xindex> to vector<8x16xindex>
+  %1 = vector.broadcast %cst_2 {layout_result_0 = #xegpu.layout} : vector<1x16xindex> to vector<8x16xindex>
+  %2 = arith.addi %0, %1 {layout_result_0 = #xegpu.layout} : vector<8x16xindex>
+
+  vector.scatter %source[%off0, %off1, %off2][%2], %mask, %vec {
+    layout_operand_4 = #xegpu.layout,
+    layout_operand_5 = #xegpu.layout
+  } : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_dynamic_layout_mixed_override(
+// CHECK: xegpu.store {{[^{]*}}
+// CHECK-SAME: {{[^}]*}}layout_operand_2 = #xegpu.layout,
+// CHECK-SAME: {{[^}]*}}layout_operand_3 = #xegpu.layout}
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @store_with_cache_hints(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec {
+    l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint
+  } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_with_cache_hints(
+// CHECK: xegpu.store {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @store_with_partial_cache_hints(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec {
+    l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint
+  } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_with_partial_cache_hints(
+// CHECK: xegpu.store {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..52a90799ea762 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -80,6 +80,36 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 
 // -----
 
+func.func @store_2D_layouts(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset, %offset] {layout_operand_0 = #xegpu.layout}
+    : memref<8x16x32xf32>, vector<8x16xf32>
+  return
+}
+
+// CHECK-LABEL: @store_2D_layouts(
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout>
+
+// -----
+
+func.func @store_2D_cache_hints(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset, %offset] {
+    l1_hint = #xegpu.cache_hint,
+    l2_hint = #xegpu.cache_hint,
+    l3_hint = #xegpu.cache_hint
+  }
+    : memref<8x16x32xf32>, vector<8x16xf32>
+  return
+}
+
+// CHECK-LABEL: @store_2D_cache_hints(
+// CHECK: xegpu.store_nd {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+
+// -----
+
 func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
     %source: memref<16x32x64xf32>, %offset: index) {
   vector.store %vec, %source[%offset, %offset, %offset]
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c4ca79af1bd9a..b84efabf58f38 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -441,3 +441,52 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
 // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
 // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_layouts(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {
+      in_bounds = [true, true],
+      layout_result_0 = #xegpu.layout
+    } : memref<8x16x32xf32>, vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+
+// LOAD-ND-LABEL: @load_2D_layouts(
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME: #xegpu.block_tdesc_attr, #xegpu.layout>
+
+// LOAD-GATHER-LABEL: @load_2D_layouts(
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load {{[^{]*}}
+// LOAD-GATHER-SAME {layout_operand_0 = #xegpu.layout_attr,
+// LOAD-GATHER-SAME layout_operand_1 = #xegpu.layout_attr,
+// LOAD-GATHER-SAME layout_result_0 = #xegpu.layout_attr} : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_cache_hints(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {
+      in_bounds = [true, true],
+      l1_hint = #xegpu.cache_hint,
+      l2_hint = #xegpu.cache_hint,
+      l3_hint = #xegpu.cache_hint
+    } : memref<8x16x32xf32>, vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+
+// LOAD-ND-LABEL: @load_2D_cache_hints(
+// LOAD-ND: xegpu.load_nd {{[^<]*}}
+// LOAD-ND-SAME: <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+
+// LOAD-GATHER-LABEL: @load_2D_cache_hints(
+// LOAD-GATHER: xegpu.load {{[^<]*}}
+// LOAD-GATHER-SAME: <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index fcfc9414da4f6..23048626740a3 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -326,3 +326,53 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
 // STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_layout(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {
+      in_bounds = [true, true],
+      layout_operand_0 = #xegpu.layout
+    }
+    : vector<8x16xf32>, memref<8x16x32xf32>
+  gpu.return
+}
+
+// STORE-ND-LABEL: @store_2D_layout(
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME: #xegpu.block_tdesc_attr, #xegpu.layout>
+
+// STORE-SCATTER-LABEL: @store_2D_layout(
+// STORE-SCATTER: xegpu.store {{[^{]*}}
+// STORE-SCATTER-SAME {layout_operand_0 = #xegpu.layout_attr,
+// STORE-SCATTER-SAME layout_operand_1 = #xegpu.layout_attr,
+// STORE-SCATTER-SAME layout_result_0 = #xegpu.layout_attr} : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_cache_hints(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {
+      in_bounds = [true, true],
+      l1_hint = #xegpu.cache_hint,
+      l2_hint = #xegpu.cache_hint,
+      l3_hint = #xegpu.cache_hint
+    }
+    : vector<8x16xf32>, memref<8x16x32xf32>
+  gpu.return
+}
+
+// STORE-ND-LABEL: @store_2D_cache_hints(
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// STORE-ND: xegpu.store_nd {{[^<]*}}
+// STORE-ND-SAME: <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+
+// STORE-SCATTER-LABEL: @store_2D_cache_hints(
+// STORE-SCATTER: xegpu.store {{[^<]*}}
+// STORE-SCATTER-SAME: <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}>
+}