
Commit 76c8129

add cache hint propagation

Signed-off-by: dchigarev <[email protected]>
1 parent a22b251 commit 76c8129
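
In short: vector.gather and vector.scatter ops may carry optional l1_hint / l2_hint / l3_hint cache-policy attributes, and this commit makes the VectorToXeGPU lowering forward them onto the generated xegpu.load / xegpu.store ops instead of dropping them. A minimal before/after sketch, adapted from the tests added below (the surrounding shapes and value names are illustrative, not part of the patch):

  // Input: a gather annotated with cache hints.
  %0 = vector.gather %src[%off1, %off2, %off3][%indices], %mask, %pass_thru {
    l1_hint = #xegpu.cache_hint<cached>,
    l3_hint = #xegpu.cache_hint<streaming>
  } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>

  // After lowering, the same hints appear on the generated load, and the
  // unset l2_hint stays absent:
  // xegpu.load ... <{l1_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>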

File tree

3 files changed: +110, -6 lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 29 additions & 6 deletions
@@ -97,6 +97,20 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+// Extract cache hints from the op attributes if available.
+static void getOpCacheHints(Operation *op,
+                            SmallVector<xegpu::CachePolicyAttr, 3> &hints) {
+  assert(hints.size() == 3 &&
+         "Expecting a vector of size 3 for l1, l2, l3 hints.");
+  // get l1, l2, l3 hints from attributes if available.
+  if (auto l1Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l1_hint"))
+    hints[0] = l1Attr;
+  if (auto l2Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l2_hint"))
+    hints[1] = l2Attr;
+  if (auto l3Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l3_hint"))
+    hints[2] = l3Attr;
+}
+
 static xegpu::CreateNdDescOp
 createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -631,12 +645,17 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
         gatherOp->getOpOperand(numOffsets + 2));
     auto layoutPassThru = mlir::xegpu::getDistributeLayoutAttr(
         gatherOp->getOpOperand(numOffsets + 3));
+
+    SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
+                                                      xegpu::CachePolicyAttr{},
+                                                      xegpu::CachePolicyAttr{}};
+    getOpCacheHints(gatherOp, cacheHints);
     auto xeGatherOp = xegpu::LoadGatherOp::create(
         rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
         /*chunk_size=*/IntegerAttr{},
-        /*l1_hint=*/xegpu::CachePolicyAttr{},
-        /*l2_hint=*/xegpu::CachePolicyAttr{},
-        /*l3_hint=*/xegpu::CachePolicyAttr{});
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1],
+        /*l3_hint=*/cacheHints[2]);
     mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
     mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(1),
                                          layoutIndices);
@@ -682,13 +701,17 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
         scatterOp->getOpOperand(numOffsets + 2));
     auto layoutVal = mlir::xegpu::getDistributeLayoutAttr(
         scatterOp->getOpOperand(numOffsets + 3));
+    SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
+                                                      xegpu::CachePolicyAttr{},
+                                                      xegpu::CachePolicyAttr{}};
+    getOpCacheHints(scatterOp, cacheHints);
    auto storeOp = xegpu::StoreScatterOp::create(
         rewriter, loc, scatterOp.getValueToStore(), flatMemref, localOffsets,
         scatterOp.getMask(),
         /*chunk_size=*/IntegerAttr{},
-        /*l1_hint=*/xegpu::CachePolicyAttr{},
-        /*l2_hint=*/xegpu::CachePolicyAttr{},
-        /*l3_hint=*/xegpu::CachePolicyAttr{});
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1],
+        /*l3_hint=*/cacheHints[2]);
     mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(0), layoutVal);
     mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(2),
                                          layoutIndices);
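
A note on the defaults above: the three-element hints vector is seeded with null xegpu::CachePolicyAttr values, and a null optional attribute is never attached to the created op. So a gather or scatter with no hint attributes keeps lowering exactly as before; a sketch of the unannotated case, assuming the same shapes as the tests below:

  %0 = vector.gather %src[%off1, %off2, %off3][%indices], %mask, %pass_thru
      : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
  // Lowers to an xegpu.load with no <{...}> hint dictionary at all.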

mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir

Lines changed: 44 additions & 0 deletions
@@ -253,6 +253,7 @@ gpu.func @non_unit_inner_stride_3D(
 // -----
 
 gpu.module @xevm_module {
+// Layouts are only specified for the gather op itself.
 gpu.func @load_dynamic_layout_operands(%source: memref<?x?xf32>,
     %off0: index, %off1: index,
     %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
@@ -270,6 +271,7 @@ gpu.func @load_dynamic_layout_operands(%source: memref<?x?xf32>,
 // CHECK-SAME: %[[SRC:.+]]: memref<?x?xf32>,
 // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
 // CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>, %[[PASS:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// The %indices producer doesn't have a layout, and neither do the 'broadcast/add' ops computing the linear index.
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
 // CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
@@ -307,6 +309,7 @@ gpu.func @load_dynamic_layout_mixed(%source: memref<?x?x?xf32>,
 // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
 // CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
 // CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
+// Verify that the linear-index computation uses the layout from the 'indices' producer op (%2).
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
 // CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
@@ -344,6 +347,7 @@ gpu.func @load_static_layout_mixed(%source: memref<8x16x32xf32>,
 // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
 // CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
 // CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
+// Verify that the linear-index computation uses the layout from the 'indices' producer op (%2).
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
 // CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
@@ -381,6 +385,8 @@ gpu.func @load_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
 // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
 // CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) -> vector<8x16xf32> {
 // CHECK: %[[PASS_THRU:.+]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [0]>} dense<0.000000e+00> : vector<8x16xf32>
+// Verify that the linear-index computation uses the layout from the 'indices' producer op (%2)
+// and not its overridden version from the gather op (sg_layout = [99]).
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
 // CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
@@ -390,3 +396,41 @@ gpu.func @load_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
 // CHECK-SAME: {{{[^}]*}}layout_operand_0 = #xegpu.layout<sg_layout = [7]>,
 // CHECK-SAME: {{[^}]*}}layout_result_0 = #xegpu.layout<sg_layout = [6]>} : vector<8x16xi1>, vector<8x16xf32>
 }
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @load_with_cache_hints(%source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+      %pass_thru {
+        l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>,
+        l3_hint = #xegpu.cache_hint<streaming>
+      } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_with_cache_hints(
+// CHECK: xegpu.load {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<streaming>}>
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @load_with_partial_cache_hints(%source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+      %pass_thru {
+        l1_hint = #xegpu.cache_hint<cached>,
+        l3_hint = #xegpu.cache_hint<streaming>
+      } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_with_partial_cache_hints(
+// CHECK: xegpu.load {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
+}

mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir

Lines changed: 37 additions & 0 deletions
@@ -222,6 +222,7 @@ gpu.func @store_dynamic_layout_operands(%vec: vector<8x16xf32>, %source: memref<
 // CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x?xf32>,
 // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
 // CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// The %indices producer doesn't have a layout, and neither do the 'broadcast/add' ops computing the linear index.
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
 // CHECK: xegpu.store %[[VEC]], %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
@@ -251,6 +252,7 @@ gpu.func @store_dynamic_layout_mixed(%source: memref<?x?x?xf32>,
 // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
 // CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) {
 // CHECK: %[[VEC:.+]] = arith.constant {layout_operand_0 = #xegpu.layout<sg_layout = [0]>} dense<1.000000e+00> : vector<8x16xf32>
+// Verify that the linear-index computation uses the layout from the 'indices' producer op (%2).
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
 // CHECK: xegpu.store %[[VEC]], %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
@@ -280,6 +282,7 @@ gpu.func @store_static_layout_mixed(%source: memref<8x16x32xf32>,
 // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
 // CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) {
 // CHECK: %[[VEC:.+]] = arith.constant {layout_operand_0 = #xegpu.layout<sg_layout = [0]>} dense<1.000000e+00> : vector<8x16xf32>
+// Verify that the linear-index computation uses the layout from the 'indices' producer op (%2).
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
 // CHECK: xegpu.store %[[VEC]], %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
@@ -310,9 +313,43 @@ gpu.func @store_dynamic_layout_mixed_override(%source: memref<?x?x?xf32>,
 // CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
 // CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>) {
 // CHECK: %[[VEC:.+]] = arith.constant {layout_operand_0 = #xegpu.layout<sg_layout = [0]>} dense<1.000000e+00> : vector<8x16xf32>
+// Verify that the linear-index computation uses the layout from the 'indices' producer op (%2)
+// and not its overridden version from the scatter op (sg_layout = [99]).
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [5]>} : vector<8x16xindex>
 // CHECK: xegpu.store %[[VEC]], %[[BASE_I64:.+]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]]
 // CHECK-SAME: {{[^}]*}}layout_operand_2 = #xegpu.layout<sg_layout = [99]>,
 // CHECK-SAME: {{[^}]*}}layout_operand_3 = #xegpu.layout<sg_layout = [6]>}
 }
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @store_with_cache_hints(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec {
+    l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<write_back>
+  } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_with_cache_hints(
+// CHECK: xegpu.store {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<write_back>}>
+}
+
+// -----
+
+gpu.module @xevm_module {
+gpu.func @store_with_partial_cache_hints(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec {
+    l1_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<write_back>
+  } : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_with_partial_cache_hints(
+// CHECK: xegpu.store {{[^<]*}}
+// CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<write_back>}>
+}
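
For reference, each of these .mlir test files is driven by a RUN line at the top of the file, which this diff doesn't touch. Given the // ----- split markers used above, it plausibly resembles the following; the exact flags are an assumption, not shown in the commit:

  // RUN: mlir-opt %s --convert-vector-to-xegpu --split-input-file | FileCheck %s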
