Commit 747050b

[MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax (#162095)
Changes the `VectorToXeGPU` pass to generate `xegpu.load_nd/store_nd` ops using the new syntax, where offsets are specified at the load/store op level.

```mlir
// from this
%desc = xegpu.create_nd_tdesc %src[%off1, %off2]: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

// to this
%desc = xegpu.create_nd_tdesc %src: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc[%off1, %off2] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
```

In order to support cases with dimension reduction at the `create_nd_tdesc` level (e.g. `memref<8x8x16xf16> -> tensor_desc<8x16xf16>`), it was decided to insert a `memref.subview` that collapses the source shape to 2D, for example:

```mlir
// input:
%0 = vector.load %source[%off0, %off1, %off2] : memref<8x16x32xf32>, vector<8x16xf32>

// --vector-to-xegpu (old)
%tdesc = xegpu.create_nd_tdesc %source[%off0, %off1, %off2] : memref<8x16x32xf32> -> tdesc<8x32xf32>
%vec = xegpu.load_nd %tdesc

// --vector-to-xegpu (new)
%collapsed = memref.subview %source[%off0, 0, 0] [1, 16, 32] [1, 1, 1]
  : memref<8x16x32xf32> -> memref<16x32xf32, strided<[32, 1], offset: ?>>
%tdesc = xegpu.create_nd_tdesc %collapsed : memref<16x32xf32, ...> -> tdesc<8x32xf32>
%vec = xegpu.load_nd %tdesc[%off1, %off2]
```

<details><summary>Why do we need to change that?</summary>

```mlir
// reduce dim and apply all 3 offsets at load_nd
%desc = xegpu.create_nd_tdesc %source : memref<8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>

// error: xegpu.load_nd len(offsets) != desc.rank
%res = xegpu.load_nd %desc[%off, %off, %off] : !xegpu.tensor_desc<16x32xf32> -> vector<8x16xf32>
```
</details>

---------

Signed-off-by: dchigarev <[email protected]>
1 parent 0307147 commit 747050b

6 files changed (+196, -172 lines)

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 105 additions & 69 deletions
@@ -97,57 +97,23 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
-                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
-                   Operation::operand_range offsets) {
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+                                                Location loc,
+                                                xegpu::TensorDescType descType,
+                                                TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
   auto [strides, offset] = srcTy.getStridesAndOffset();
 
   xegpu::CreateNdDescOp ndDesc;
   if (srcTy.hasStaticShape()) {
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           getAsOpFoldResult(offsets));
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   } else {
     // In case of any dynamic shapes, source's shape and strides have to be
     // explicitly provided.
-    SmallVector<Value> sourceDims;
-    unsigned srcRank = srcTy.getRank();
-    for (unsigned i = 0; i < srcRank; ++i)
-      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-    SmallVector<int64_t> constOffsets;
-    SmallVector<Value> dynOffsets;
-    for (Value offset : offsets) {
-      std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (!staticVal)
-        dynOffsets.push_back(offset);
-      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
-    }
-
-    SmallVector<Value> dynShapes;
-    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-      if (shape == ShapedType::kDynamic)
-        dynShapes.push_back(sourceDims[idx]);
-    }
-
-    // Compute strides in reverse order.
-    SmallVector<Value> dynStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        dynStrides.push_back(accStride);
-    }
-    std::reverse(dynStrides.begin(), dynStrides.end());
-
-    ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
-        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
-        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
-        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           meta.getConstifiedMixedSizes(),
+                                           meta.getConstifiedMixedStrides());
   }
 
   return ndDesc;
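For dynamically shaped sources, the rewritten `createNdDescriptor` no longer builds `memref.dim`/`arith.muli` chains; it takes sizes and strides from `memref.extract_strided_metadata`. A rough sketch of the IR this path is expected to produce for a rank-2 dynamic source (names and the exact `create_nd_tdesc` print form are illustrative assumptions, not taken from the tests):

```mlir
// Hypothetical input: dynamically shaped 2-D source.
%v = vector.load %source[%i, %j] : memref<?x?xf32>, vector<8x16xf32>

// Expected lowering sketch: sizes/strides come from the metadata op,
// while the offsets move to the load_nd op.
%base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %source
    : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
%desc = xegpu.create_nd_tdesc %source, shape : [%sizes#0, %sizes#1], strides : [%strides#0, 1]
    : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
%v2 = xegpu.load_nd %desc[%i, %j] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
```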
@@ -392,6 +358,62 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
       .getResult();
 }
 
+// Collapses shapes of a nD memref to the target rank while applying offsets for
+// the collapsed dimensions. Returns the new memref value and the remaining
+// offsets for the last targetRank dimensions. For example:
+// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
+// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+                                    Value memref,
+                                    SmallVector<OpFoldResult> offsets,
+                                    int64_t targetRank) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  unsigned rank = memrefType.getRank();
+
+  if (rank <= targetRank)
+    return {memref, offsets};
+
+  int64_t numCombinedDims = rank - targetRank;
+  SmallVector<OpFoldResult> subviewOffsets;
+  SmallVector<OpFoldResult> subviewSizes;
+  SmallVector<OpFoldResult> subviewStrides;
+
+  // For the combined dimensions: use the provided offsets, size=1, stride=1
+  for (unsigned i = 0; i < numCombinedDims; ++i) {
+    subviewOffsets.push_back(offsets[i]);
+    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  // For the last targetRank dimensions: offset=0, use full size, stride=1
+  SmallVector<int64_t> resultShape;
+  auto originalShape = memrefType.getShape();
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+  for (unsigned i = numCombinedDims; i < rank; ++i) {
+    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+    if (ShapedType::isDynamic(originalShape[i])) {
+      subviewSizes.push_back(meta.getSizes()[i]);
+      resultShape.push_back(ShapedType::kDynamic);
+    } else {
+      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+      resultShape.push_back(originalShape[i]);
+    }
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  auto resultType = memref::SubViewOp::inferRankReducedResultType(
+      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+  auto subviewOp =
+      memref::SubViewOp::create(rewriter, loc, resultType, memref,
+                                subviewOffsets, subviewSizes, subviewStrides);
+
+  // Return the remaining offsets for the last targetRank dimensions
+  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+                                       offsets.end());
+  return {subviewOp.getResult(), newOffsets};
+}
+
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
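As a rough IR-level illustration of the helper's doc-comment example (assuming the offsets are SSA `index` values; the exact result type follows `inferRankReducedResultType`):

```mlir
// Collapse memref<2x4x8x32xf32> to rank 2, consuming the two leading offsets.
%collapsed = memref.subview %memref[%i0, %i1, 0, 0] [1, 1, 8, 32] [1, 1, 1, 1]
    : memref<2x4x8x32xf32> to memref<8x32xf32, strided<[32, 1], offset: ?>>
// %i2 and %i3 are returned and later passed as offsets to xegpu.load_nd/store_nd.
```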
@@ -523,18 +545,19 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                           readOp.getIndices());
-
     DenseI64ArrayAttr transposeAttr =
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+        vecTy.getRank());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                           /*packed=*/nullptr, transposeAttr,
                                           /*l1_hint=*/hint,
                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
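With this change, the `vector.transfer_read` pattern is expected to produce IR along the following lines for a 2-D read whose rank already matches the source (a hedged sketch with illustrative values, not taken from the tests):

```mlir
// input:
%v = vector.transfer_read %src[%off0, %off1], %pad {in_bounds = [true, true]}
    : memref<64x64xf32>, vector<8x16xf32>

// --vector-to-xegpu (new): ranks match, so no subview is inserted;
// the offsets simply move to the load_nd op.
%desc = xegpu.create_nd_tdesc %src : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32>
%v2 = xegpu.load_nd %desc[%off0, %off1] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
```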
@@ -575,21 +598,23 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, writeOp.getBase(),
+        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-                           writeOp.getIndices());
-
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+                                            ndDesc, indices,
+                                            /*l1_hint=*/hint,
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -674,19 +699,24 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
 
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+        vecTy.getRank());
 
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+    auto loadNdOp =
+        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                /*packed=*/nullptr, /*transpose=*/nullptr,
+                                /*l1_hint=*/hint,
+                                /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -708,18 +738,24 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, storeOp.getBase(),
+        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
     auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                  /*l1_hint=*/hint,
                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+
     rewriter.replaceOp(storeOp, storeNdOp);
 
     return success();

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 8 additions & 4 deletions
@@ -280,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
   // if shape and strides are from Memref, we don't need attributes for them
-  // to keep the IR print clean.
-  if (staticShape == memrefShape && staticStrides == memrefStrides) {
+  // to keep the IR print clean (only do so for full-static case, otherwise
+  // printer would fail trying to print empty array-attr).
+  if (staticShape == memrefShape && staticStrides == memrefStrides &&
+      dynamicShape.empty() && dynamicStrides.empty()) {
     staticShapeAttr = DenseI64ArrayAttr();
     staticStridesAttr = DenseI64ArrayAttr();
   }
@@ -342,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
   // if shape and strides are from Memref, we don't need attributes for them
-  // to keep the IR print clean.
-  if (staticShape == memrefShape && staticStrides == memrefStrides) {
+  // to keep the IR print clean (only do so for full-static case, otherwise
+  // printer would fail trying to print empty array-attr).
+  if (staticShape == memrefShape && staticStrides == memrefStrides &&
+      dynamicShape.empty() && dynamicStrides.empty()) {
     staticShapeAttr = DenseI64ArrayAttr();
     staticStridesAttr = DenseI64ArrayAttr();
   }
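To illustrate the builder change: the shape/strides operands are now elided from the print only in the fully static case; with any dynamic size or stride they stay explicit. A hedged sketch (assuming the op's `shape :`/`strides :` assembly form that the old tests exercised; values are illustrative):

```mlir
// Fully static source: shape and strides are implied by the memref type and elided.
%d0 = xegpu.create_nd_tdesc %static : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// Source with a dynamic size/stride (e.g. a collapsed subview): kept explicit.
%d1 = xegpu.create_nd_tdesc %dynamic, shape : [%h, 32], strides : [%s, 1]
    : memref<?x32xf32, strided<[?, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
```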

mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 17 additions & 22 deletions
@@ -9,11 +9,12 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-LABEL:  @load_1D_vector(
 // CHECK-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:   %[[OFFSET:.+]]: index
+// CHECK:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // CHECK:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:     %[[COLLAPSED]]
+// CHECK-SAME:     memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:     boundary_check = false
-// CHECK:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:        return %[[VEC]]
 
 // -----
@@ -28,35 +29,29 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-LABEL:  @load_2D_vector(
 // CHECK-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:   %[[OFFSET:.+]]: index
+// CHECK:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // CHECK:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME:     %[[COLLAPSED]]
+// CHECK-SAME:     memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:        return %[[VEC]]
 
 // -----
 
 func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
-    %offset: index) -> vector<8x16xf32> {
-  %0 = vector.load %source[%offset, %offset, %offset]
+    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%i, %j, %k]
     : memref<?x?x?xf32>, vector<8x16xf32>
   return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL:  @load_dynamic_source(
 // CHECK-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:   %[[OFFSET:.+]]: index
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:    %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:    %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:     , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME:     memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK:        {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // CHECK:        return %[[VEC]]
 
 // -----
@@ -72,9 +67,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 // CHECK-SAME:   %[[SRC:.+]]: memref<7x15xf32>,
 // CHECK-SAME:   %[[OFFSET:.+]]: index
 // CHECK:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:     %[[SRC]]
 // CHECK-SAME:     memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:        return %[[VEC]]
 
 // -----
