diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 819c2e5973ffd..852c322cc6467 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -180,26 +180,31 @@ static void adjustStridesForPermutation(AffineMap permMap, strides = applyPermutation(strides, perms64); } -// Computes memory strides for vector transfer operations, handling both -// static and dynamic memrefs while applying permutation transformations -// for XeGPU lowering. -static SmallVector computeStrides(VectorTransferOpInterface xferOp, - PatternRewriter &rewriter) { +// Computes memory strides and a memref offset for vector transfer operations, +// handling both static and dynamic memrefs while applying permutation +// transformations for XeGPU lowering. +static std::pair, Value> +computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) { SmallVector strides; Value baseMemref = xferOp.getBase(); AffineMap permMap = xferOp.getPermutationMap(); MemRefType memrefType = dyn_cast(baseMemref.getType()); Location loc = xferOp.getLoc(); + Value offsetVal = nullptr; if (memrefType.hasStaticShape()) { int64_t offset; SmallVector intStrides; if (failed(memrefType.getStridesAndOffset(intStrides, offset))) - return {}; + return {{}, offsetVal}; // Wrap static strides as MLIR values for (int64_t s : intStrides) strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s)); - } else { + if (!ShapedType::isDynamic(offset)) + offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset); + } + + if (strides.empty() || !offsetVal) { // For dynamic shape memref, use memref.extract_strided_metadata to get // stride values unsigned rank = memrefType.getRank(); @@ -220,11 +225,16 @@ static SmallVector computeStrides(VectorTransferOpInterface xferOp, auto meta = memref::ExtractStridedMetadataOp::create( rewriter, loc, resultTypes, baseMemref); - strides.append(meta.getStrides().begin(), meta.getStrides().end()); + + if (strides.empty()) + strides.append(meta.getStrides().begin(), meta.getStrides().end()); + + if (!offsetVal) + offsetVal = meta.getOffset(); } // Adjust strides according to the permutation map (e.g., for transpose) adjustStridesForPermutation(permMap, strides); - return strides; + return {strides, offsetVal}; } // This function compute the vectors of localOffsets for scattered load/stores. @@ -254,10 +264,10 @@ static SmallVector computeStrides(VectorTransferOpInterface xferOp, // %23 = arith.add %20, %21 // %local_offsets = arith.add %22, %23 // %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map -// %offsets = orig_offset + local_offsets +// %offsets = memref_offset + orig_offset + local_offsets static Value computeOffsets(VectorTransferOpInterface xferOp, - PatternRewriter &rewriter, - ArrayRef strides) { + PatternRewriter &rewriter, ArrayRef strides, + Value baseOffset) { Location loc = xferOp.getLoc(); VectorType vectorType = xferOp.getVectorType(); SmallVector indices(xferOp.getIndices().begin(), @@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp, arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]); // Compute base offset from transfer read indices - Value baseOffset = nullptr; - if (!indices.empty()) { - baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); - for (size_t i = 0; i < indices.size(); ++i) { - Value strideVal = strides[i]; - Value offsetContrib = - arith::MulIOp::create(rewriter, loc, indices[i], strideVal); - baseOffset = - arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib); - } - // Broadcast base offset to match vector shape - Value bcastBase = vector::BroadcastOp::create( - rewriter, loc, fullIndexVectorType, baseOffset); - localOffsets = - arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets); + for (size_t i = 0; i < indices.size(); ++i) { + Value strideVal = strides[i]; + Value offsetContrib = + arith::MulIOp::create(rewriter, loc, indices[i], strideVal); + baseOffset = + arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib); } + // Broadcast base offset to match vector shape + Value bcastBase = vector::BroadcastOp::create( + rewriter, loc, fullIndexVectorType, baseOffset); + localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets); return localOffsets; } -// Collapse memref shape to 1D -static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp, - PatternRewriter &rewriter) { +// Convert memref to i64 base pointer +static Value memrefToIndexPtr(VectorTransferOpInterface xferOp, + PatternRewriter &rewriter) { Location loc = xferOp.getLoc(); - - Value baseMemref = xferOp.getBase(); - MemRefType memrefType = dyn_cast(baseMemref.getType()); - Type elementType = memrefType.getElementType(); - - // Compute the total number of elements in the memref - MemRefType flatMemrefType; - if (memrefType.hasStaticShape()) { - auto totalElements = memrefType.getNumElements(); - flatMemrefType = MemRefType::get({totalElements}, elementType); - } else { - flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType); - } - - SmallVector reassociation; - ReassociationIndices allDims = - llvm::to_vector(llvm::seq(0, memrefType.getRank())); - reassociation.push_back(allDims); - - auto collapseOp = memref::CollapseShapeOp::create( - rewriter, loc, flatMemrefType, baseMemref, reassociation); - return collapseOp; + auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, xferOp.getBase()) + .getResult(); + return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), + indexPtr) + .getResult(); } static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp, @@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp, if (!memrefType) return rewriter.notifyMatchFailure(readOp, "Expected memref source"); - SmallVector strides = computeStrides(readOp, rewriter); - if (strides.empty()) + auto meta = computeMemrefMeta(readOp, rewriter); + if (meta.first.empty()) return rewriter.notifyMatchFailure(readOp, "Failed to compute strides"); - Value localOffsets = computeOffsets(readOp, rewriter, strides); + Value localOffsets = + computeOffsets(readOp, rewriter, meta.first, meta.second); - Value flatMemref = collapseMemrefTo1D(readOp, rewriter); + Value flatMemref = memrefToIndexPtr(readOp, rewriter); Value mask = vector::ConstantMaskOp::create( rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), @@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp, if (!memrefType) return rewriter.notifyMatchFailure(writeOp, "Expected memref source"); - SmallVector strides = computeStrides(writeOp, rewriter); + auto meta = computeMemrefMeta(writeOp, rewriter); + if (meta.first.empty()) + return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides"); - Value localOffsets = computeOffsets(writeOp, rewriter, strides); + Value localOffsets = + computeOffsets(writeOp, rewriter, meta.first, meta.second); - Value flatMemref = collapseMemrefTo1D(writeOp, rewriter); + Value flatMemref = memrefToIndexPtr(writeOp, rewriter); Value mask = vector::ConstantMaskOp::create( rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir index b373bdab80567..c4ca79af1bd9a 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir @@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32> -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32> } @@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32> -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> } @@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32> -// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> } @@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref into memref -// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> // LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32> } @@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex> -// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref into memref -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref -> index +// LOAD-GATHER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> } @@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref, // LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex> // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref into memref -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32> // LOAD-GATHER: return %[[VEC]] } @@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex> // LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex> -// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32> -// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32> } @@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>, // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index // LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex> -// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16> -// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16> +// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16> } // ----- @@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>, // LOAD-GATHER: vector.transfer_read } +// ----- +gpu.module @xevm_module { +gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> { + %c0 = arith.constant 0.0 : f16 + %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> + %0 = vector.transfer_read %subview[%off2, %off2], %c0 + {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16> + gpu.return %0 : vector<8xf16> +} + +// LOAD-ND-LABEL: @load_from_subview( +// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc +// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]] +// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16, +// LOAD-ND-SAME: boundary_check = false +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16> +// LOAD-ND: return %[[VEC]] + +// LOAD-GATHER-LABEL: @load_from_subview( +// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-GATHER: %[[CST:.+]] = arith.constant dense : vector<8xi1> +// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref, index, index, index, index, index +// LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex> +// LOAD-GATHER: arith.muli {{.*}} : index +// LOAD-GATHER: arith.addi %[[OFFSET]]{{.*}} : index +// LOAD-GATHER: arith.addi {{.*}} : index +// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex> +// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16> +} diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir index b3f761a545ee1..fcfc9414da4f6 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir @@ -30,8 +30,9 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>, // STORE-SCATTER-COUNT2: arith.addi {{.*}} : index // STORE-SCATTER-DAG: %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex> // STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex> -// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1> +// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1> } // ----- @@ -64,8 +65,9 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>, // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex> // STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex> -// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> +// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1> } // ----- @@ -104,8 +106,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex> // STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex> -// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref into memref -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref, vector<8x16xindex>, vector<8x16xi1> +// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref -> index +// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1> } // ----- @@ -155,8 +158,9 @@ gpu.func @no_store_transposed(%vec: vector<8x16xf32>, // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex> // STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex> // STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex> -// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<32x64xf32> into memref<2048xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> +// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32x64xf32> -> index +// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1> } // ----- @@ -186,8 +190,9 @@ gpu.func @store_high_dim_vector(%vec: vector<8x16x32xf32>, // STORE-SCATTER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex> // STORE-SCATTER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex> // STORE-SCATTER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex> -// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32> -// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> +// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<16x32x64xf32> -> index +// STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, i64, vector<8x16x32xindex>, vector<8x16x32xi1> } // ----- @@ -275,4 +280,49 @@ gpu.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>, // STORE-SCATTER-LABEL: @no_store_out_of_bounds_1D_vector( // STORE-SCATTER: vector.transfer_write -} \ No newline at end of file +} + +// ----- +gpu.module @xevm_module { +gpu.func @store_to_subview(%vec: vector<8xf16>, + %source: memref<4096x4096xf16>, %off1: index, %off2: index) { + %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] + : memref<4096x4096xf16> + to memref<256x256xf16, strided<[4096, 1], offset: ?>> + vector.transfer_write %vec, %subview[%off2, %off2] + {in_bounds = [true]} + : vector<8xf16>, memref<256x256xf16, strided<[4096, 1], offset: ?>> + gpu.return +} +// STORE-ND-LABEL: @store_to_subview( +// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>, +// STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] +// STORE-ND-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc +// STORE-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]] +// STORE-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16, +// STORE-ND-SAME: boundary_check = false +// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16> + +// STORE-SCATTER-LABEL: @store_to_subview( +// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf16>, +// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// STORE-SCATTER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// STORE-SCATTER: %[[CST:.+]] = arith.constant dense : vector<8xi1> +// STORE-SCATTER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] +// STORE-SCATTER-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// STORE-SCATTER: %[[BB:.+]], %[[OFFSET:.+]], {{.*}}, {{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] +// STORE-SCATTER-SAME: : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref, index, index, index, index, index +// STORE-SCATTER: %[[STEP:.+]] = vector.step : vector<8xindex> +// STORE-SCATTER: arith.muli {{.*}} : index +// STORE-SCATTER: arith.addi %[[OFFSET]]{{.*}} : index +// STORE-SCATTER: arith.addi {{.*}} : index +// STORE-SCATTER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex> +// STORE-SCATTER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> +// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] +// STORE-SCATTER-SAME: : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index +// STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1> +}