Commit e0cf57b

Propagate layouts/cache hints for transfer_read/write/load/store
Signed-off-by: dchigarev <[email protected]>
1 parent 2a43ee6 commit e0cf57b
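
The VectorToXeGPU lowering patterns for vector.transfer_read, vector.transfer_write, vector.load, and vector.store now forward the l1/l2/l3 cache-hint attributes found on the source vector op to the XeGPU ops they create (previously the hints were always left unset), and propagate distribute layout attributes onto the generated tensor descriptors and gather/scatter ops.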

5 files changed (+208, -32 lines)

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 50 additions & 32 deletions
@@ -445,12 +445,16 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
       vectorShape);
-  auto gatherOp = xegpu::LoadGatherOp::create(
-      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
-      /*chunk_size=*/IntegerAttr{},
-      /*l1_hint=*/xegpu::CachePolicyAttr{},
-      /*l2_hint=*/xegpu::CachePolicyAttr{},
-      /*l3_hint=*/xegpu::CachePolicyAttr{});
+  SmallVector<xegpu::CachePolicyAttr, 3> cacheHints = getOpCacheHints(readOp);
+  auto gatherOp = xegpu::LoadGatherOp::create(rewriter, loc, vectorType,
+                                              flatMemref, localOffsets, mask,
+                                              /*chunk_size=*/IntegerAttr{},
+                                              /*l1_hint=*/cacheHints[0],
+                                              /*l2_hint=*/cacheHints[1],
+                                              /*l3_hint=*/cacheHints[2]);
+  auto resLayout = xegpu::getDistributeLayoutAttr(readOp.getResult());
+  xegpu::setDistributeLayoutAttrs(gatherOp,
+                                  [&](Value val) { return resLayout; });
 
   rewriter.replaceOp(readOp, gatherOp.getResult());
   return success();
@@ -479,12 +483,16 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
-  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
-                                localOffsets, mask,
-                                /*chunk_size=*/IntegerAttr{},
-                                /*l1_hint=*/xegpu::CachePolicyAttr{},
-                                /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+  auto cacheHints = getOpCacheHints(writeOp);
+  auto storeOp = xegpu::StoreScatterOp::create(
+      rewriter, loc, writeOp.getVector(), flatMemref, localOffsets, mask,
+      /*chunk_size=*/IntegerAttr{},
+      /*l1_hint=*/cacheHints[0],
+      /*l2_hint=*/cacheHints[1],
+      /*l3_hint=*/cacheHints[2]);
+  auto valueLayout = xegpu::getDistributeLayoutAttr(writeOp->getOpOperand(0));
+  xegpu::setDistributeLayoutAttrs(storeOp,
+                                  [&](Value val) { return valueLayout; });
   rewriter.eraseOp(writeOp);
   return success();
 }
@@ -534,9 +542,11 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     SmallVector<int64_t> descShape(vecTy.getShape());
     if (isTransposeLoad)
       std::reverse(descShape.begin(), descShape.end());
-    auto descType = xegpu::TensorDescType::get(
-        descShape, elementType, /*array_length=*/1,
-        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
+    auto resLayout = xegpu::getDistributeLayoutAttr(readOp.getResult());
+    auto descType =
+        xegpu::TensorDescType::get(descShape, elementType, /*array_length=*/1,
+                                   /*boundary_check=*/isOutOfBounds,
+                                   xegpu::MemorySpace::Global, resLayout);
 
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType,
@@ -547,12 +557,12 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(readOp);
     auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                           /*packed=*/nullptr, transposeAttr,
-                                          /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                          /*l1_hint=*/cacheHints[0],
+                                          /*l2_hint=*/cacheHints[1],
+                                          /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -590,21 +600,24 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto valLayout = xegpu::getDistributeLayoutAttr(writeOp->getOpOperand(0));
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
-        xegpu::MemorySpace::Global);
+        xegpu::MemorySpace::Global, valLayout);
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType,
                            dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                            writeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(writeOp);
     auto storeOp =
         xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 /*l1_hint=*/cacheHints[0],
+                                 /*l2_hint=*/cacheHints[1],
+                                 /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -720,18 +733,20 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto resLayout = xegpu::getDistributeLayoutAttr(loadOp.getResult());
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
-        boundaryCheck, xegpu::MemorySpace::Global);
+        boundaryCheck, xegpu::MemorySpace::Global, resLayout);
     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
         rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(loadOp);
     auto loadNdOp = xegpu::LoadNdOp::create(
         rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1], /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -753,18 +768,21 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
-    auto descType = xegpu::TensorDescType::get(
-        vecTy.getShape(), vecTy.getElementType(),
-        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
+    auto valLayout = xegpu::getDistributeLayoutAttr(storeOp->getOpOperand(0));
+    auto descType =
+        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
+                                   /*array_length=*/1, boundaryCheck,
+                                   xegpu::MemorySpace::Global, valLayout);
     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
         rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto cacheHints = getOpCacheHints(storeOp);
+    auto storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                              /*l1_hint=*/cacheHints[0],
+                                              /*l2_hint=*/cacheHints[1],
+                                              /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(storeOp, storeNdOp);
 
     return success();
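
Note: all of the hunks above call a getOpCacheHints helper whose definition is not part of the displayed diff. For orientation only, here is a minimal sketch of what such a helper could look like, assuming the hints travel as discardable l1_hint/l2_hint/l3_hint attributes on the source vector op (the attribute names exercised by the tests below); this is not the commit's actual implementation:

// Hypothetical sketch, not code from this commit: collect the three
// cache-level hints from discardable attributes on the source op. An entry
// is a null CachePolicyAttr when the hint is absent, which matches the old
// behavior of passing xegpu::CachePolicyAttr{} explicitly.
static SmallVector<xegpu::CachePolicyAttr, 3> getOpCacheHints(Operation *op) {
  SmallVector<xegpu::CachePolicyAttr, 3> hints;
  for (StringRef name : {"l1_hint", "l2_hint", "l3_hint"})
    hints.push_back(op->getAttrOfType<xegpu::CachePolicyAttr>(name));
  return hints;
}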

mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 29 additions & 0 deletions
@@ -79,6 +79,35 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 
 // -----
 
+func.func @load_2D_layout(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%offset, %offset, %offset] {layout_result_0 = #xegpu.layout<sg_layout = [8, 16]>}
+    : memref<8x16x32xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_layout(
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// CHECK-SAME:  memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<sg_layout = [8, 16]>>
+
+// -----
+
+func.func @load_2D_cache_hints(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%offset, %offset, %offset] {
+    l1_hint = #xegpu.cache_hint<cached>,
+    l2_hint = #xegpu.cache_hint<uncached>,
+    l3_hint = #xegpu.cache_hint<streaming>
+  }: memref<8x16x32xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_cache_hints(
+// CHECK: xegpu.load_nd {{[^<]*}}
+// CHECK-SAME:  <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<streaming>}>
+
+// -----
+
 func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
     %offset: index) -> vector<8x16x32xf32> {
   %0 = vector.load %source[%offset, %offset, %offset]

mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

Lines changed: 30 additions & 0 deletions
@@ -80,6 +80,36 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 
 // -----
 
+func.func @store_2D_layouts(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset, %offset] {layout_operand_0 = #xegpu.layout<sg_layout = [8, 16]>}
+    : memref<8x16x32xf32>, vector<8x16xf32>
+  return
+}
+
+// CHECK-LABEL: @store_2D_layouts(
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// CHECK-SAME:  memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<sg_layout = [8, 16]>>
+
+// -----
+
+func.func @store_2D_cache_hints(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.store %vec, %source[%offset, %offset, %offset] {
+    l1_hint = #xegpu.cache_hint<cached>,
+    l2_hint = #xegpu.cache_hint<uncached>,
+    l3_hint = #xegpu.cache_hint<write_back>
+  }
+    : memref<8x16x32xf32>, vector<8x16xf32>
+  return
+}
+
+// CHECK-LABEL: @store_2D_cache_hints(
+// CHECK: xegpu.store_nd {{[^<]*}}
+// CHECK-SAME:  <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<write_back>}>
+
+// -----
+
 func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
     %source: memref<16x32x64xf32>, %offset: index) {
   vector.store %vec, %source[%offset, %offset, %offset]

mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir

Lines changed: 49 additions & 0 deletions
@@ -441,3 +441,52 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
 // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
 // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_layouts(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {
+      in_bounds = [true, true],
+      layout_result_0 = #xegpu.layout<sg_layout = [8, 16]>
+    } : memref<8x16x32xf32>, vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+
+// LOAD-ND-LABEL: @load_2D_layouts(
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// LOAD-ND-SAME:  memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME:  #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 16]>>
+
+// LOAD-GATHER-LABEL: @load_2D_layouts(
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load {{[^{]*}}
+// LOAD-GATHER-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 16]>,
+// LOAD-GATHER-SAME: layout_operand_1 = #xegpu.layout<sg_layout = [8, 16]>,
+// LOAD-GATHER-SAME: layout_result_0 = #xegpu.layout<sg_layout = [8, 16]>} : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_cache_hints(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {
+      in_bounds = [true, true],
+      l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>,
+      l3_hint = #xegpu.cache_hint<streaming>
+    } : memref<8x16x32xf32>, vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+
+// LOAD-ND-LABEL: @load_2D_cache_hints(
+// LOAD-ND: xegpu.load_nd {{[^<]*}}
+// LOAD-ND-SAME:  <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<streaming>}>
+
+// LOAD-GATHER-LABEL: @load_2D_cache_hints(
+// LOAD-GATHER: xegpu.load {{[^<]*}}
+// LOAD-GATHER-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<streaming>}>
+}

mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Lines changed: 50 additions & 0 deletions
@@ -326,3 +326,53 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
 // STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_layout(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {
+      in_bounds = [true, true],
+      layout_operand_0 = #xegpu.layout<sg_layout = [8, 16]>
+    }
+    : vector<8x16xf32>, memref<8x16x32xf32>
+  gpu.return
+}
+
+// STORE-ND-LABEL: @store_2D_layout(
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// STORE-ND-SAME:  memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME:  #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 16]>>
+
+// STORE-SCATTER-LABEL: @store_2D_layout(
+// STORE-SCATTER: xegpu.store {{[^{]*}}
+// STORE-SCATTER-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 16]>,
+// STORE-SCATTER-SAME: layout_operand_1 = #xegpu.layout<sg_layout = [8, 16]>,
+// STORE-SCATTER-SAME: layout_result_0 = #xegpu.layout<sg_layout = [8, 16]>} : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_cache_hints(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {
+      in_bounds = [true, true],
+      l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>,
+      l3_hint = #xegpu.cache_hint<write_back>
+    }
+    : vector<8x16xf32>, memref<8x16x32xf32>
+  gpu.return
+}
+
+// STORE-ND-LABEL: @store_2D_cache_hints(
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc {{[^:]*}} :
+// STORE-ND: xegpu.store_nd {{[^<]*}}
+// STORE-ND-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<write_back>}>
+
+// STORE-SCATTER-LABEL: @store_2D_cache_hints(
+// STORE-SCATTER: xegpu.store {{[^<]*}}
+// STORE-SCATTER-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<write_back>}>
+}
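
For orientation, the attributes these tests attach in the IR can equally be set from C++ when building or post-processing the vector ops. A small hypothetical snippet (not part of this commit) that annotates an existing vector.load so the LoadLowering pattern above picks the hints up:

// Hypothetical usage sketch, not code from this commit: attach XeGPU cache
// hints to a vector.load as discardable attributes so the lowering forwards
// them to the generated xegpu.load_nd.
void annotateLoadWithCacheHints(vector::LoadOp loadOp) {
  MLIRContext *ctx = loadOp.getContext();
  loadOp->setAttr("l1_hint", xegpu::CachePolicyAttr::get(
                                 ctx, xegpu::CachePolicy::CACHED));
  loadOp->setAttr("l2_hint", xegpu::CachePolicyAttr::get(
                                 ctx, xegpu::CachePolicy::UNCACHED));
  loadOp->setAttr("l3_hint", xegpu::CachePolicyAttr::get(
                                 ctx, xegpu::CachePolicy::STREAMING));
}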
