101 changes: 47 additions & 54 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -183,23 +183,28 @@ static void adjustStridesForPermutation(AffineMap permMap,
// Computes memory strides for vector transfer operations, handling both
// static and dynamic memrefs while applying permutation transformations
// for XeGPU lowering.
static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
PatternRewriter &rewriter) {
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
Contributor Author: The function that calls memref.extract_strided_metadata now also returns the memref's offset together with the strides.
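For reference, for a rank-2 memref the metadata op yields the base buffer, the dynamic offset, and the per-dimension sizes and strides; an illustrative snippet mirroring the subview test added below (value names are chosen for illustration):

    %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
        : memref<256x256xf16, strided<[4096, 1], offset: ?>>
        -> memref<f16>, index, index, index, index, index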

SmallVector<Value> strides;
Value baseMemref = xferOp.getBase();
AffineMap permMap = xferOp.getPermutationMap();
MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

Location loc = xferOp.getLoc();
Value offsetVal = nullptr;
if (memrefType.hasStaticShape()) {
int64_t offset;
SmallVector<int64_t> intStrides;
if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
return {};
return {{}, offsetVal};
// Wrap static strides as MLIR values
for (int64_t s : intStrides)
strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
} else {
if (!ShapedType::isDynamic(offset))
offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
}

if (strides.empty() || !offsetVal) {
// For dynamic shape memref, use memref.extract_strided_metadata to get
// stride values
unsigned rank = memrefType.getRank();
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,

auto meta = memref::ExtractStridedMetadataOp::create(
rewriter, loc, resultTypes, baseMemref);
strides.append(meta.getStrides().begin(), meta.getStrides().end());

if (strides.empty())
strides.append(meta.getStrides().begin(), meta.getStrides().end());

if (!offsetVal)
offsetVal = meta.getOffset();
}
// Adjust strides according to the permutation map (e.g., for transpose)
adjustStridesForPermutation(permMap, strides);
return strides;
return {strides, offsetVal};
}

// This function computes the vectors of localOffsets for scattered loads/stores.
@@ -256,8 +266,8 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
// %offsets = orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
PatternRewriter &rewriter,
ArrayRef<Value> strides) {
PatternRewriter &rewriter, ArrayRef<Value> strides,
Value baseOffset) {
Comment on lines 268 to +270
Contributor Author: computeOffsets now takes a baseOffset obtained from memref.extract_strided_metadata in computeMemrefMeta (formerly computeStrides).
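Roughly, for the subview test added below, the metadata offset seeds the linearized base offset instead of a constant zero. A simplified sketch (value names are illustrative; the second stride contribution is elided):

    %bb, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview
        : memref<256x256xf16, strided<[4096, 1], offset: ?>>
        -> memref<f16>, index, index, index, index, index
    // The static strides [4096, 1] become arith.constant ops; only the dynamic
    // offset comes from the metadata and is used to seed baseOffset.
    %c4096 = arith.constant 4096 : index
    %rowContrib = arith.muli %off2, %c4096 : index
    %base = arith.addi %offset, %rowContrib : index
    // ... column contribution, broadcast to vector<8xindex>, add to per-lane offsets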

Location loc = xferOp.getLoc();
VectorType vectorType = xferOp.getVectorType();
SmallVector<Value> indices(xferOp.getIndices().begin(),
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

// Compute base offset from transfer read indices
Value baseOffset = nullptr;
if (!indices.empty()) {
Comment on lines -318 to -319
Contributor Author: Made the branch unconditional since we always want to account for the baseOffset.

baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
for (size_t i = 0; i < indices.size(); ++i) {
Value strideVal = strides[i];
Value offsetContrib =
arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
baseOffset =
arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
// Broadcast base offset to match vector shape
Value bcastBase = vector::BroadcastOp::create(
rewriter, loc, fullIndexVectorType, baseOffset);
localOffsets =
arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
for (size_t i = 0; i < indices.size(); ++i) {
Value strideVal = strides[i];
Value offsetContrib =
arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
baseOffset =
arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
// Broadcast base offset to match vector shape
Value bcastBase = vector::BroadcastOp::create(
rewriter, loc, fullIndexVectorType, baseOffset);
localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
return localOffsets;
}

// Collapse memref shape to 1D
static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
PatternRewriter &rewriter) {
static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
PatternRewriter &rewriter) {
Comment on lines +343 to +344
Contributor Author: memref.collapse_shape -> memref.extract_aligned_pointer_as_index + arith.index_cast (index -> i64).
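In the generated IR this replaces the flattening collapse with a raw pointer extraction, e.g. (taken from the updated test checks below; result names are illustrative):

    // Before:
    %flat = memref.collapse_shape %src [[0, 1, 2]] : memref<8x16x32xf32> into memref<4096xf32>
    // After:
    %ptr = memref.extract_aligned_pointer_as_index %src : memref<8x16x32xf32> -> index
    %ptr_i64 = arith.index_cast %ptr : index to i64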

Location loc = xferOp.getLoc();

Value baseMemref = xferOp.getBase();
MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
Type elementType = memrefType.getElementType();

// Compute the total number of elements in the memref
MemRefType flatMemrefType;
if (memrefType.hasStaticShape()) {
auto totalElements = memrefType.getNumElements();
flatMemrefType = MemRefType::get({totalElements}, elementType);
} else {
flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
}

SmallVector<ReassociationIndices> reassociation;
ReassociationIndices allDims =
llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
reassociation.push_back(allDims);

auto collapseOp = memref::CollapseShapeOp::create(
rewriter, loc, flatMemrefType, baseMemref, reassociation);
return collapseOp;
auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, xferOp.getBase())
.getResult();
return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
indexPtr)
.getResult();
}

static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
if (!memrefType)
return rewriter.notifyMatchFailure(readOp, "Expected memref source");

SmallVector<Value> strides = computeStrides(readOp, rewriter);
if (strides.empty())
auto meta = computeMemrefMeta(readOp, rewriter);
if (meta.first.empty())
return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

Value localOffsets = computeOffsets(readOp, rewriter, strides);
Value localOffsets =
computeOffsets(readOp, rewriter, meta.first, meta.second);

Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
Value flatMemref = memrefToIndexPtr(readOp, rewriter);

Value mask = vector::ConstantMaskOp::create(
rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
if (!memrefType)
return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

SmallVector<Value> strides = computeStrides(writeOp, rewriter);
auto meta = computeMemrefMeta(writeOp, rewriter);
if (meta.first.empty())
return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

Value localOffsets = computeOffsets(writeOp, rewriter, strides);
Value localOffsets =
computeOffsets(writeOp, rewriter, meta.first, meta.second);

Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

Value mask = vector::ConstantMaskOp::create(
rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
77 changes: 61 additions & 16 deletions mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>

}

@@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}

@@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}

@@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32>
}

@@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
// LOAD-GATHER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}

@@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
// LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
// LOAD-GATHER: return %[[VEC]]
}

@@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
// LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>

}

@@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
}

// -----
@@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
// LOAD-GATHER: vector.transfer_read
}

// -----
gpu.module @xevm_module {
gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
%c0 = arith.constant 0.0 : f16
%subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
%0 = vector.transfer_read %subview[%off2, %off2], %c0
{in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
gpu.return %0 : vector<8xf16>
}

// LOAD-ND-LABEL: @load_from_subview(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// LOAD-ND-SAME: boundary_check = false
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
// LOAD-ND: return %[[VEC]]

// LOAD-GATHER-LABEL: @load_from_subview(
// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
// LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
// LOAD-GATHER: arith.muli {{.*}} : index
// LOAD-GATHER: arith.addi %[[OFFSET]]{{.*}} : index
// LOAD-GATHER: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
}