
Commit 2a38c2c

collapse memref shape to 2d

Signed-off-by: dchigarev <[email protected]>
1 parent babf57e
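
In effect (a sketch assembled from the updated tests below, not part of the commit itself; value names are illustrative): instead of building an nD tensor descriptor directly over the source memref, the lowering now collapses the leading dimensions with a rank-reducing memref.subview and passes only the trailing offsets to the XeGPU ops:

  // Before: 3-D descriptor indexed with all three offsets.
  %desc = xegpu.create_nd_tdesc %src : memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  %v = xegpu.load_nd %desc[%i, %j, %k] {...} -> vector<8x16xf32>

  // After: leading dim folded into a subview, 2-D descriptor, trailing offsets only.
  %collapsed = memref.subview %src[%i, 0, 0] [1, 16, 32] [1, 1, 1]
      : memref<8x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
  %desc = xegpu.create_nd_tdesc %collapsed
      : memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
  %v = xegpu.load_nd %desc[%j, %k] {...} -> vector<8x16xf32>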

File tree: 6 files changed (+179 / -98 lines)


mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 98 additions & 30 deletions
@@ -358,6 +358,63 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
       .getResult();
 }
 
+// Collapses shapes of a nD memref to the target rank while applying offsets for
+// the collapsed dimensions. Returns the new memref value and the remaining
+// offsets for the last targetRank dimensions. For example:
+// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
+// targetRank=2 output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, returned
+// offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+                                    Value memref,
+                                    SmallVector<OpFoldResult> offsets,
+                                    int64_t targetRank) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  unsigned rank = memrefType.getRank();
+
+  if (rank <= targetRank)
+    return {memref, offsets};
+
+  int64_t numCombinedDims = rank - targetRank;
+  SmallVector<OpFoldResult> subviewOffsets;
+  SmallVector<OpFoldResult> subviewSizes;
+  SmallVector<OpFoldResult> subviewStrides;
+
+  // For the combined dimensions: use the provided offsets, size=1, stride=1
+  for (unsigned i = 0; i < numCombinedDims; ++i) {
+    subviewOffsets.push_back(offsets[i]);
+    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  // For the last targetRank dimensions: offset=0, use full size, stride=1
+  SmallVector<int64_t> resultShape;
+  auto originalShape = memrefType.getShape();
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+  for (unsigned i = numCombinedDims; i < rank; ++i) {
+    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+    if (ShapedType::isDynamic(originalShape[i])) {
+      subviewSizes.push_back(meta.getSizes()[i]);
+      resultShape.push_back(ShapedType::kDynamic);
+    } else {
+      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+      resultShape.push_back(originalShape[i]);
+    }
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  auto resultType = memref::SubViewOp::inferRankReducedResultType(
+      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+  auto subviewOp =
+      memref::SubViewOp::create(rewriter, loc, resultType, memref,
+                                subviewOffsets, subviewSizes, subviewStrides);
+
+  // Return the remaining offsets for the last targetRank dimensions
+  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+                                       offsets.end());
+  return {subviewOp.getResult(), newOffsets};
+}
+
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
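
For the example in the new function's comment, the helper would emit roughly the following IR (a sketch; the offset in the result type is dynamic because the subview offsets are SSA values):

  %subview = memref.subview %memref[%i0, %i1, 0, 0] [1, 1, 8, 32] [1, 1, 1, 1]
      : memref<2x4x8x32xf32> to memref<8x32xf32, strided<[32, 1], offset: ?>>
  // The returned offsets [%i2, %i3] are applied later by the caller,
  // e.g. xegpu.load_nd %desc[%i2, %i3].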
@@ -493,17 +550,18 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+        vecTy.getRank());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
-    auto loadOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
-        /*packed=*/nullptr, transposeAttr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                          /*packed=*/nullptr, transposeAttr,
+                                          /*l1_hint=*/hint,
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -541,21 +599,23 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, writeOp.getBase(),
+        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
 
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 getAsOpFoldResult(writeOp.getIndices()),
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+                                            ndDesc, indices,
+                                            /*l1_hint=*/hint,
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -643,17 +703,21 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+        vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
-        /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+    auto loadNdOp =
+        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                /*packed=*/nullptr, /*transpose=*/nullptr,
+                                /*l1_hint=*/hint,
+                                /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -675,19 +739,23 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, storeOp.getBase(),
+        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
 
-    auto storeNdOp = xegpu::StoreNdOp::create(
-        rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto storeNdOp =
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
+                                 /*l1_hint=*/hint,
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(storeOp, storeNdOp);
 
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 8 additions & 4 deletions
@@ -258,8 +258,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
   // if shape and strides are from Memref, we don't need attributes for them
-  // to keep the IR print clean.
-  if (staticShape == memrefShape && staticStrides == memrefStrides) {
+  // to keep the IR print clean (only do so for full-static case, otherwise
+  // printer would fail trying to print empty array-attr).
+  if (staticShape == memrefShape && staticStrides == memrefStrides &&
+      dynamicShape.empty() && dynamicStrides.empty()) {
    staticShapeAttr = DenseI64ArrayAttr();
    staticStridesAttr = DenseI64ArrayAttr();
   }
@@ -320,8 +322,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
   // if shape and strides are from Memref, we don't need attributes for them
-  // to keep the IR print clean.
-  if (staticShape == memrefShape && staticStrides == memrefStrides) {
+  // to keep the IR print clean (only do so for full-static case, otherwise
+  // printer would fail trying to print empty array-attr).
+  if (staticShape == memrefShape && staticStrides == memrefStrides &&
+      dynamicShape.empty() && dynamicStrides.empty()) {
    staticShapeAttr = DenseI64ArrayAttr();
    staticStridesAttr = DenseI64ArrayAttr();
   }
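
Both builder overloads get the same guard. The distinction matters for IR like the following (a sketch based on the test expectations in this commit; value names are illustrative):

  // Fully static source: shape/strides match the memref type, so the
  // attributes are dropped and nothing extra is printed.
  %0 = xegpu.create_nd_tdesc %src : memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

  // Dynamic source: shape/strides are SSA values, so the mixed
  // static/dynamic arrays must be kept; dropping only the static attrs
  // would leave the printer with an empty array-attr to print.
  %1 = xegpu.create_nd_tdesc %dyn, shape : [%s0, %s1, %s2], strides : [%st0, %st1, 1]
      : memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>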

mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 15 additions & 14 deletions
@@ -9,11 +9,12 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-LABEL: @load_1D_vector(
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:  %[[SRC]]
-// CHECK-SAME:  memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:  %[[COLLAPSED]]
+// CHECK-SAME:  memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:  boundary_check = false
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -28,29 +29,29 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-LABEL: @load_2D_vector(
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:  %[[SRC]]
-// CHECK-SAME:  memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME:  %[[COLLAPSED]]
+// CHECK-SAME:  memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
 
 func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
-    %offset: index) -> vector<8x16xf32> {
-  %0 = vector.load %source[%offset, %offset, %offset]
+    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%i, %j, %k]
     : memref<?x?x?xf32>, vector<8x16xf32>
   return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @load_dynamic_source(
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
-// CHECK-SAME:  memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----

mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

Lines changed: 15 additions & 14 deletions
@@ -11,11 +11,12 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:  %[[SRC]]
-// CHECK-SAME:  memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:  %[[COLLAPSED]]
+// CHECK-SAME:  memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:  boundary_check = false
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
 // -----
 
@@ -30,29 +31,29 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:  %[[SRC]]
-// CHECK-SAME:  memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// CHECK-SAME:  %[[COLLAPSED]]
+// CHECK-SAME:  memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
 func.func @store_dynamic_source(%vec: vector<8x16xf32>,
-    %source: memref<?x?x?xf32>, %offset: index) {
-  vector.store %vec, %source[%offset, %offset, %offset]
+    %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+  vector.store %vec, %source[%i, %j, %k]
     : memref<?x?x?xf32>, vector<8x16xf32>
   return
 }
 
 // CHECK-LABEL: @store_dynamic_source(
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
-// CHECK-SAME:  memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // -----
 