diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 91d6b2a5ead9b..75b16a87e03c6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     As compared to prefetch_nd, which works on non-scattered TensorDesc,
     it works on scattered TensorDesc instead.
 
-    Example:
+    Example 1:
     ```mlir
       xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                              l2_hint = #xegpu.cache_hint<cached>,
                              l3_hint = #xegpu.cache_hint<cached>}
         : !xegpu.tensor_desc<16xf16>
     ```
+
+    Example 2:
+    A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
+    It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
+    The source operand can also be a raw pointer (uint64_t).
+    Please refer to create_tdesc for the restrictions on the memref.
+    ```mlir
+      %a = memref.alloc() : memref<1024xf32>
+      %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+      xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
+                             l2_hint = #xegpu.cache_hint<cached>,
+                             l3_hint = #xegpu.cache_hint<cached>}
+        : memref<1024xf32>, vector<4xindex>
+    ```
   }];
 
-  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
+                       Optional<XeGPU_OffsetType>: $offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+    Type getSourceType() {
+      return getSource().getType();
+    }
+
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+      }
+      return TypedValue<xegpu::TensorDescType>();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+      return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+  let assemblyFormat = [{
+    $source
+    (`[` $offsets^ `]`)?
+    prop-dict
+    attr-dict `:` type(operands)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $source,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
 
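Note: for reference, the offset form also composes with a raw-pointer source. A minimal sketch (the `%src` value and the hint values are assumed; it mirrors the `prefetch_offset` test added later in this patch):

```mlir
// Prefetch four f32 elements through a raw uint64_t base address.
// %src is assumed to be a ui64 value holding a device address.
%offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
xegpu.prefetch %src[%offsets] <{l1_hint = #xegpu.cache_hint<cached>,
                                l2_hint = #xegpu.cache_hint<uncached>}>
  : ui64, vector<4xindex>
```
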
-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
-    AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
-  ]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   let summary = "load a set of scattered data points from memory.";
 
   let description = [{ It (aka. load) load data per each work-item. The output
@@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
       : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>>,
         vector<16xi1> -> vector<16x8xf32>
   ```
+
   Example 3 (SIMT mode):
   ```mlir
     %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
                              l2_hint = #xegpu.cache_hint<uncached>,
                              l3_hint = #xegpu.cache_hint<uncached>}>
       : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>>
        vector<16xi1> -> vector<8xf32>
   ```
+
+  Example 4:
+  A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
+  It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
+  The source operand can also be a raw pointer (uint64_t). Please refer to create_tdesc
+  for the restrictions on the memref.
+  ```mlir
+    %a = memref.alloc() : memref<1024xf32>
+    %offsets = vector.step : vector<16xindex>
+    %mask = vector.constant_mask [16]: vector<16xi1>
+    %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
+                                           l2_hint = #xegpu.cache_hint<cached>,
+                                           l3_hint = #xegpu.cache_hint<cached>}
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  ```
  }];
 
-  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
+                       Optional<XeGPU_OffsetType>: $offsets,
                        XeGPU_MaskType: $mask,
+                       OptionalAttr<I64Attr>: $chunk_size,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
   let results = (outs XeGPU_ValueType: $value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+
+    Type getSourceType() {
+      return getSource().getType();
+    }
+
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+      }
+      return TypedValue<xegpu::TensorDescType>();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+      return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
 
     mlir::Type getElementType() {
@@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
-      `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+  let assemblyFormat = [{
+    $source
+    (`[` $offsets^ `]`)? `,`
+    $mask prop-dict
+    attr-dict `:` type(operands) `->` type($value)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
 
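Note: with the offset-based form the SIMT shape rule matches the TensorDesc path: a lane-level value is 1D and must have exactly chunk_size elements. A minimal sketch (values, mask, and the ui64 `%src` are assumed, not taken from the patch's tests):

```mlir
// SIMT-style gather through a raw pointer: one lane reads a chunk of two
// contiguous f32 elements, so the per-lane result is vector<2xf32>.
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1> : vector<1xi1>
%val = xegpu.load %src[%offsets], %mask <{chunk_size = 2}>
  : ui64, vector<1xindex>, vector<1xi1> -> vector<2xf32>
```
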
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
-    AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
-  ]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   let summary = "store data to scattered memory locations.";
   let description = [{ It (aka. store) stores data to scattered memory locations. The value is
     typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
                                  l3_hint = #xegpu.cache_hint<write_through>}>
       : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>
        vector<16xi1>
   ```
+
+  Example 4:
+  A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
+  It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
+  The dest operand can also be a raw pointer (uint64_t).
+  Please refer to create_tdesc for the restrictions on the memref.
+  ```mlir
+    %a = memref.alloc() : memref<1024xf32>
+    %val = arith.constant dense<0.0> : vector<16xf32>
+    %offsets = vector.step : vector<16xindex>
+    %mask = vector.constant_mask [16]: vector<16xi1>
+    xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<write_back>,
+                                           l2_hint = #xegpu.cache_hint<write_back>,
+                                           l3_hint = #xegpu.cache_hint<write_back>}
+      : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  ```
+
  }];
 
   let arguments = (ins XeGPU_ValueType: $value,
-                       XeGPU_TensorDesc: $TensorDesc,
+                       XeGPU_GatherScatterSourceType: $dest,
+                       Optional<XeGPU_OffsetType>: $offsets,
                        XeGPU_MaskType: $mask,
+                       OptionalAttr<I64Attr>: $chunk_size,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+    Type getDestType() {
+      return getDest().getType();
+    }
+
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
+      }
+      return TypedValue<xegpu::TensorDescType>();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+      return dyn_cast<xegpu::TensorDescType>(getDestType());
     }
 
     VectorType getValueType() {
@@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
     }
   }];
 
-  let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
-      `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
+  let assemblyFormat = [{
+    $value `,`
+    $dest
+    (`[` $offsets^ `]`)? `,`
+    $mask
+    prop-dict
+    attr-dict `:` type(operands)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 20916ae9ef830..b268cabb5d266 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
+def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
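Note: the new `XeGPU_GatherScatterSourceType` union means the same scattered op can take a scattered TensorDesc, a 1D memref, or a raw ui64 address. A hedged sketch of the three store forms under the new assembly format (operand values, shapes, and the default `#xegpu.scatter_tdesc_attr<>` are assumptions for illustration only):

```mlir
// 1) Destination is a scattered TensorDesc; no inline offsets.
xegpu.store %v0, %tdesc, %m
  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>

// 2) Destination is a 1D memref plus per-work-item offsets and a chunk size.
xegpu.store %v1, %buf[%offsets], %m <{chunk_size = 2}>
  : vector<16x2xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>

// 3) Destination is a raw device address held in a ui64.
xegpu.store %v1, %ptr[%offsets], %m <{chunk_size = 2}>
  : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
```
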
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 704deeaa1f26b..33450f3fa229e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   return success();
 }
 
+static LogicalResult
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
+                                 int64_t chunkSize,
+                                 function_ref<InFlightDiagnostic()> emitError) {
+
+  if (!valueTy)
+    return emitError() << "Expecting a vector type result.";
+
+  auto maskShape = getShapeOf(maskTy);
+  auto valueShape = getShapeOf(valueTy);
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1) {
+    if (valueTy.getNumElements() != chunkSize)
+      return emitError() << "value elements must match chunk size " << chunkSize
+                         << " for SIMT code.";
+    return success();
+  }
+
+  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
+    return emitError() << "Mask should match value except the chunk size dim.";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (!tdescTy.isScattered())
+
+  if (tdescTy && !tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
+  if (!tdescTy && getRankOf(getSource()) > 1)
+    return emitOpError(
+        "Expecting the source is a 1D memref or pointer (uint64_t).");
+
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() {
   return success();
 }
 
+void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
+                       xegpu::CachePolicyAttr l1_hint,
+                       xegpu::CachePolicyAttr l2_hint,
+                       xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_LoadGatherOp
 //===----------------------------------------------------------------------===//
@@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
+  if (tdescTy && !tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.");
+
+  if (!tdescTy && getRankOf(getSource()) > 1)
+    return emitOpError(
+        "Expecting the source is a 1D memref or pointer (uint64_t).");
+
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    [&]() { return emitOpError(); });
+  if (tdescTy)
+    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+                                      [&]() { return emitOpError(); });
+  auto srcTy = getSourceType();
+  uint64_t chunkSize = static_cast<uint64_t>(getChunkSize().value_or(1));
+  auto memTy = dyn_cast<MemRefType>(srcTy);
+
+  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+    return emitError() << "Value should have the same element type as MemRef.";
+
+  return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+                                          [&]() { return emitOpError(); });
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source, Value mask,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+        l1_hint, l2_hint, l3_hint);
 }
 
 //===----------------------------------------------------------------------===//
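Note: to make the buffer-shape rule concrete, for a rank-2 value the verifier expects the mask to match the value shape minus the trailing chunk dimension. A minimal sketch that satisfies isValidGatherScatterBufferParams (operand values assumed):

```mlir
// 16 work-items, each loading a chunk of 8 f32 elements through a ui64
// base address: value is vector<16x8xf32>, so the mask must be vector<16xi1>.
%val = xegpu.load %src[%offsets], %mask <{chunk_size = 8}>
  : ui64, vector<16xindex>, vector<16xi1> -> vector<16x8xf32>
```
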
@@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
+  if (tdescTy && !tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (!tdescTy && getRankOf(getDest()) > 1)
+    return emitOpError(
+        "Expecting the dest is a 1D memref or pointer (uint64_t).");
+
   if (!isWriteHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    [&]() { return emitOpError(); });
+  if (tdescTy)
+    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+                                      [&]() { return emitOpError(); });
+
+  auto destTy = getDestType();
+  uint64_t chunkSize = static_cast<uint64_t>(getChunkSize().value_or(1));
+  auto memTy = dyn_cast<MemRefType>(destTy);
+
+  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+    return emitError() << "Value should have the same element type as MemRef.";
+
+  return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+                                          [&]() { return emitOpError(); });
+}
+
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest, Value mask,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+        l2_hint, l3_hint);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index ec8fad484ed3e..c793b71639e86 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -481,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    // TODO: handle the unstructured source case (!tdescTy)
+    if (!tdescTy || op.getOffsets())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -543,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    // TODO: handle the unstructured source case (!tdescTy)
+    if (!tdescTy || op.getOffsets())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -572,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    // TODO: handle the unstructured source case (!tdescTy)
+    if (!tdescTy || op.getOffsets())
      return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 0160bfee07bf2..dff3ffab39ecf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -384,6 +384,74 @@ func.func @load_gather_vc_3(%src: ui64) {
   return
 }
 
+// -----
+func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+  xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
+  return
+}
+
+// -----
+func.func @load_gather_offset_sg(%src: memref<?xf16>) {
+  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<8xi1>
+  // expected-error@+1 {{Mask should match value except the chunk size dim}}
+  %2 = xegpu.load %src[%offsets], %mask
+        : memref<?xf16>, vector<4xindex>, vector<8xi1>
+          -> vector<4x2xf16>
+  return
+}
+
+// -----
+func.func @load_gather_offset_wi(%src: ui64) {
+  %mask = arith.constant dense<1>: vector<1xi1>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{value elements must match chunk size}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
+  return
+}
+
+// -----
+func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
+  %val = arith.constant dense<2.9>: vector<4xf16>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  %mask = arith.constant dense<1>: vector<1xi1>
+  // expected-error@+1 {{value elements must match chunk size}}
+  xegpu.store %val, %src[%offsets], %mask
+        : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
+  return
+}
+
+// -----
+func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
+  %val = arith.constant dense<2.9>: vector<4xf16>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  %mask = arith.constant dense<1>: vector<1xi1>
+  // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
+  xegpu.store %val, %src[%offsets], %mask
+        : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
+  return
+}
+
+// -----
+func.func @load_gather_offset_wi_2(%src: ui64) {
+  %mask = arith.constant dense<1>: vector<1xi1>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{value elements must match chunk size}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16>
+  return
+}
+
+// -----
+func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
+  %mask = arith.constant dense<1>: vector<1xi1>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
+  return
+}
+
 // -----
 func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
   %0 = arith.constant dense<1>: vector<4xi1>
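Note: for contrast with the negative tests above, a SIMT-style store the new verifier does accept (an illustrative sketch, not taken from the patch's tests): a rank-1 value whose element count equals chunk_size and whose element type matches the memref.

```mlir
func.func @store_scatter_offset_wi_ok(%dst: memref<256xf16>) {
  %val = arith.constant dense<2.9> : vector<2xf16>
  %offsets = arith.constant dense<[0]> : vector<1xindex>
  %mask = arith.constant dense<1> : vector<1xi1>
  // Per-lane chunk of 2 f16 elements; passes isValidGatherScatterBufferParams.
  xegpu.store %val, %dst[%offsets], %mask <{chunk_size = 2}>
        : vector<2xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
  return
}
```
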
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3ebb1b969ac74..6be2371d4d7b2 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
+  %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+  %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+          : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
 gpu.func @subgroup_store(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
+  %val = arith.constant dense<2.9>: vector<4x2xf16>
+  %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<4xi1>
+  //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<write_back>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+  xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<write_back>}>
+          : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+  gpu.return
+}
+
 // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
 gpu.func @prefetch(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
+gpu.func @prefetch_offset(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+  xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+  gpu.return
+}
 
 // CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
 gpu.func @create_update_tdesc(%src: ui64) {