From 1373ffa1d836cb8401f5d24fca5c9283c2484d0e Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Thu, 17 Jul 2025 23:21:14 +0000
Subject: [PATCH 01/16] add optional offsets to nd load/store/prefetch

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 70 +++++++++++++++++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 32 +++++++--
 mlir/test/Dialect/XeGPU/ops.mlir              | 50 +++++++++++++
 3 files changed, 140 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 81e25f7537cb0..e9f8437d7c102 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,9 +29,22 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
         void printProperties(::mlir::MLIRContext *ctx,
                              ::mlir::OpAsmPrinter &p, const Properties &prop,
                              ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-          Attribute propAttr = getPropertiesAsAttr(ctx, prop);
-          if (propAttr)
-            p << "<" << propAttr << ">";
+
+          DictionaryAttr propAttr =
+              dyn_cast_if_present<DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
+
+          // filter out the elidedProps from propAttr, and get the resultAttr
+          mlir::SmallVector<mlir::NamedAttribute> filteredAttrs;
+          if (propAttr) {
+            for (auto namedAttr : propAttr.getValue()) {
+              if (llvm::is_contained(elidedProps, namedAttr.getName().strref()))
+                continue;
+              filteredAttrs.push_back(namedAttr);
+            }
+          }
+
+          if (!filteredAttrs.empty()) {
+            p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+          }
         }
 
         static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
@@ -288,6 +301,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -298,7 +313,18 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+  let assemblyFormat = [{
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` qualified(type($TensorDesc))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
@@ -343,6 +369,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -361,7 +389,20 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)";
+  let assemblyFormat = [{
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                   "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
+
   let hasVerifier = 1;
 }
@@ -400,6 +441,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
 
   let arguments = (ins XeGPU_ValueType: $value,
                        XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -414,8 +457,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     }
   }];
 
-  let assemblyFormat = [{$value `,` $TensorDesc prop-dict attr-dict
-                        `:` type($value) `,` qualified(type($TensorDesc))}];
+  let assemblyFormat = [{
+    $value `,`
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
+
+  let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 78cbf884a1911..ca3c92cf4b52c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -331,16 +331,24 @@ ParseResult parseOptionalDynamicIndexList(
   return success();
 }
 
 void printOptionalDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
-    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+    DenseI64ArrayAttr integers) {
 
-  return printDynamicIndexList(printer, op, values, integers,
-                               /*scalableFlags=*/{}, valueTypes, delimiter);
-}
+  if (!integers)
+    return;
 
+  return printDynamicIndexList(printer, op, values, integers,
+                               /*scalableFlags=*/{}, {},
+                               AsmParser::Delimiter::Square);
+  }
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
+
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+
+}
+
 LogicalResult PrefetchNdOp::verify() {
   auto tdescTy = getTensorDescType();
   if (tdescTy.isScattered())
@@ -361,6 +369,13 @@ LogicalResult PrefetchNdOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_LoadNdOp
 //===----------------------------------------------------------------------===//
+
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, Value tensorDesc, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+  return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, l3_hint);
+
+}
+
 LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
@@ -448,6 +463,13 @@ LogicalResult LoadNdOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreNdOp
 //===----------------------------------------------------------------------===//
+
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+  return build(builder, state, value, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+
+}
+
 LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 695437354cd7c..a1028a8e8a2f3 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -121,6 +121,15 @@ gpu.func @prefetch_nd_2(%src: memref<8x24x32x48x64xf16>) {
   gpu.return
 }
 
+// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
+gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
+  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+  gpu.return
+}
+
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -260,6 +269,15 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  gpu.return
+}
+
 // CHECK: func @simt_load_nd_8(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
@@ -269,6 +287,16 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+
+// CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  gpu.return
+}
+
 // CHECK: func @subgroup_store_nd(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -293,6 +321,17 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
 
 // CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+  %1 = arith.constant dense<1.0>: vector<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
+// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
   %1 = arith.constant dense<1.0>: vector<32xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -313,6 +352,17 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @simt_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @update_nd_tdesc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

From 30ff640a8d2c59d31effbb9828f1775032564c57 Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Thu, 17 Jul 2025 23:22:18 +0000
Subject: [PATCH 02/16] git-clang-format

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 47 ++++++++++++++++----------
 1 file changed, 30 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ca3c92cf4b52c..7cb105bf4292d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -329,24 +329,28 @@ ParseResult parseOptionalDynamicIndexList(
   return success();
 }
 
-void printOptionalDynamicIndexList(
-    OpAsmPrinter &printer, Operation *op, OperandRange values,
-    DenseI64ArrayAttr integers) {
+void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                   OperandRange values,
+                                   DenseI64ArrayAttr integers) {
 
-  if (!integers)
-    return;
+  if (!integers)
+    return;
 
-  return printDynamicIndexList(printer, op, values, integers,
-                               /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
-  }
+  return printDynamicIndexList(printer, op, values, integers,
+                               /*scalableFlags=*/{}, {},
+                               AsmParser::Delimiter::Square);
+}
 
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
 
-void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
-  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
-
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
+               l1_hint, l2_hint, l3_hint);
 }
 
 LogicalResult PrefetchNdOp::verify() {
@@ -370,10 +374,16 @@ LogicalResult PrefetchNdOp::verify() {
 // XeGPU_LoadNdOp
 //===----------------------------------------------------------------------===//
 
-void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, Value tensorDesc, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
-  return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, l3_hint);
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+                     Value tensorDesc, UnitAttr packed,
+                     DenseI64ArrayAttr transpose,
+                     xegpu::CachePolicyAttr l1_hint,
+                     xegpu::CachePolicyAttr l2_hint,
+                     xegpu::CachePolicyAttr l3_hint) {
 
+  return build(builder, state, retType, tensorDesc, ValueRange(),
+               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
+               l3_hint);
 }
 
 LogicalResult LoadNdOp::verify() {
@@ -464,10 +474,13 @@ LogicalResult LoadNdOp::verify() {
 // XeGPU_StoreNdOp
 //===----------------------------------------------------------------------===//
 
-void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
-  return build(builder, state, value, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
+                      xegpu::CachePolicyAttr l2_hint,
+                      xegpu::CachePolicyAttr l3_hint) {
 
+  return build(builder, state, value, tensorDesc, ValueRange(),
+               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
 }
 
 LogicalResult StoreNdOp::verify() {

From efd1661b4ea93a08776213504219b96871449507 Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Fri, 18 Jul 2025 01:48:18 +0000
Subject: [PATCH 03/16] add optional offsets to load_gather

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td    | 15 ++++++++++-----
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td  |  1 +
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp |  2 +-
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e9f8437d7c102..31c2fb357371a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -655,7 +655,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 }
 
 def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
-    AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
+    AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
   ]> {
   let summary = "load a set of scattered data points from memory.";
 
@@ -698,7 +698,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -706,8 +706,13 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   let results = (outs XeGPU_ValueType: $value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+
+    Type getSourceType() {
+      return getSource().getType();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+      return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
 
     mlir::Type getElementType() {
@@ -725,8 +730,8 @@
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
-      `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+  let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
+      `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 277158ac85409..ac41907655122 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -203,6 +203,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index dc76441b27c02..44f2364d0caec 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -486,7 +486,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
-        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+        op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;

From 3578c1b96fd22a1e013bc70f987ac4bfb6849c10 Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Thu, 17 Jul 2025 23:21:14 +0000
Subject: [PATCH 04/16] add optional offsets to nd load/store/prefetch
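
The nd ops now take their offsets at the op itself instead of folding
them into the tensor descriptor. A minimal sketch of the resulting
forms (shapes and offsets here are illustrative only; the tests added
below are the authoritative examples):

    %td = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
    %v  = xegpu.load_nd %td[8, 16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
    xegpu.store_nd %v, %td[8, 16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
    xegpu.prefetch_nd %td[8, 16] : !xegpu.tensor_desc<8x16xf32>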
---
 mlir/test/Dialect/XeGPU/ops.mlir | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3ebb1b969ac74..3523e3083c168 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,6 +130,15 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
   gpu.return
 }
 
+// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
+gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
+  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+  gpu.return
+}
+
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -330,6 +339,17 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   gpu.return
 }
 
+// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+  %1 = arith.constant dense<1.0>: vector<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+

From 59f7ea9bf601205496f5867ddf1445e1f51641fd Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Fri, 18 Jul 2025 01:48:18 +0000
Subject: [PATCH 05/16] add optional offsets to load_gather

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td    | 15 ++++++++++-----
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td  |  1 +
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp |  2 +-
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 91d6b2a5ead9b..a5a7dab1bf55a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -655,7 +655,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 }
 
 def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
-    AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
+    AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
   ]> {
   let summary = "load a set of scattered data points from memory.";
 
@@ -698,7 +698,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -706,8 +706,13 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   let results = (outs XeGPU_ValueType: $value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+
+    Type getSourceType() {
+      return getSource().getType();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+      return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
 
     mlir::Type getElementType() {
@@ -725,8 +730,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
-      `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+  let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
-      `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 20916ae9ef830..334f749ace745 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a6208b455aa35..c8f332184bd1b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
-        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+        op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;

From abc84c759f0993d8aa699ba1b87fdad7c5760a69 Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Mon, 21 Jul 2025 04:18:01 +0000
Subject: [PATCH 06/16] add offsets to load

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 18 ++++++++++++++++--
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td        |  2 +-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a5a7dab1bf55a..51356e963e778 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -699,6 +699,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   }];
 
   let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -730,8 +732,20 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
-      `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
+  let assemblyFormat = [{
+    $source `,`
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    $mask prop-dict
+    attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
+  }];
+
+//  let builders = [
+//    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+//                   "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+//                   "xegpu::CachePolicyAttr": $l1_hint,
+//                   "xegpu::CachePolicyAttr": $l2_hint,
+//                   "xegpu::CachePolicyAttr": $l3_hint)>
+//  ];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 334f749ace745..8e575e31255a7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
-def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";

From 80b4462f48c164309f70b1a6dbeb5805d869c998 Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Thu, 24 Jul 2025 02:58:15 +0000
Subject: [PATCH 07/16] add chunk_size and use XeGPU_OffsetType
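
The scattered ops gain a memref or raw ui64 base, an optional offset
vector, and an optional chunk_size attribute. A minimal sketch of the
prefetch form introduced here (mirroring the prefetch_offset test in
this patch; the offset values are illustrative and cache hints are
elided):

    %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
    xegpu.prefetch %src, %offsets : ui64, vector<4xindex>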
"mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td" // Base class for dialect operations. This operation inherits from the base // `Op` class in OpBase.td, and provides: @@ -638,18 +639,39 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { }]; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source, + Optional: $offsets, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, OptionalAttr: $l3_hint); let extraClassDeclaration = extraBaseClassDeclaration # [{ + Type getSourceType() { + return getSource().getType(); + } + + Value getTensorDesc() { + return getSource(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast(getSourceType()); } }]; - let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))"; + let assemblyFormat = [{ + $source + (`,` $offsets^)? + prop-dict + attr-dict `:` type($source) (`,` type($offsets)^)? + }]; + + let builders = [ + OpBuilder<(ins "Value": $source, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } @@ -702,6 +724,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ Variadic: $offsets, OptionalAttr: $const_offsets, XeGPU_MaskType: $mask, + OptionalAttr: $chunk_size, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, OptionalAttr: $l3_hint); @@ -713,6 +736,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ return getSource().getType(); } + Value getTensorDesc() { + return getSource(); + } + xegpu::TensorDescType getTensorDescType() { return dyn_cast(getSourceType()); } @@ -733,25 +760,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ }]; let assemblyFormat = [{ - $source `,` - custom($offsets, $const_offsets) + $source `` + custom($offsets, $const_offsets) `,` $mask prop-dict attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value) }]; -// let builders = [ -// OpBuilder<(ins "Type": $value, "Value": $TensorDesc, -// "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose, -// "xegpu::CachePolicyAttr": $l1_hint, -// "xegpu::CachePolicyAttr": $l2_hint, -// "xegpu::CachePolicyAttr": $l3_hint)> -// ]; + let builders = [ + OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ - AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]> + AllElementTypesMatch<["value", "dest"]>, MemoryEffects<[MemWrite]> ]> { let summary = "store data to scattered memory locations."; let description = [{ It (aka. store) stores data to scattered memory locations. 
@@ -791,15 +817,26 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
 
   let arguments = (ins
     XeGPU_ValueType: $value,
-    XeGPU_TensorDesc: $TensorDesc,
+    XeGPU_TensorDesc_or_MemRef: $dest,
+    Variadic<Index>: $offsets,
+    OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
     XeGPU_MaskType: $mask,
+    OptionalAttr<I64Attr>: $chunk_size,
    OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
    OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
    OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+    Type getDestType() {
+      return getDest().getType();
+    }
+
+    Value getTensorDesc() {
+      return getDest();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+      return dyn_cast<xegpu::TensorDescType>(getDestType());
     }
 
     VectorType getValueType() {
@@ -811,8 +848,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
     }
   }];
 
-  let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
-      `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
+  let assemblyFormat = [{
+    $value `,`
+    $dest ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+    $mask
+    prop-dict
+    attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 704deeaa1f26b..4f3b3ed475afc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -644,7 +644,7 @@ LogicalResult CreateDescOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (!tdescTy.isScattered())
+  if (tdescTy && !tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
@@ -659,6 +659,13 @@ LogicalResult PrefetchOp::verify() {
   return success();
 }
 
+void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
+                       xegpu::CachePolicyAttr l1_hint,
+                       xegpu::CachePolicyAttr l2_hint,
+                       xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_LoadGatherOp
 //===----------------------------------------------------------------------===//
@@ -680,6 +687,15 @@ LogicalResult LoadGatherOp::verify() {
                                     [&]() { return emitOpError(); });
 }
 
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source, Value mask,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, valueType, source, ValueRange(), DenseI64ArrayAttr(),
+        mask, IntegerAttr(), l1_hint, l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreScatterOp
 //===----------------------------------------------------------------------===//
@@ -701,6 +717,15 @@ LogicalResult StoreScatterOp::verify() {
                                     [&]() { return emitOpError(); });
 }
 
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest, Value mask,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, value, dest, ValueRange(), DenseI64ArrayAttr(), mask,
+        IntegerAttr(), l1_hint, l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_UpdateOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8f332184bd1b..a6208b455aa35 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
-        op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3523e3083c168..98836adaa57a3 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,15 +130,6 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
   gpu.return
 }
 
-// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
-gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
-  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
-  gpu.return
-}
-
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -339,19 +330,8 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   gpu.return
 }
 
-// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
-  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
-  %1 = arith.constant dense<1.0>: vector<32xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
-  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
-  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
-  gpu.return
-}
-
-// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
+// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_3(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
   %1 = arith.constant dense<1.0>: vector<32xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -658,6 +638,15 @@ gpu.func @prefetch(%src: ui64) {
 }
 
+// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
+gpu.func @prefetch_offset(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  // CHECK: xegpu.prefetch %[[arg0]], %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+  xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+  gpu.return
+}
+
 // CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
 gpu.func @create_update_tdesc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

From 769bf19a3775cc8ceacbe9077371c4c712c9f493 Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Fri, 25 Jul 2025 02:33:06 +0000
Subject: [PATCH 08/16] add tests
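
A sketch of the bracketed load/store forms this patch settles on
(copied in spirit from the subgroup_load_offset_1 and
subgroup_store_offset_1 tests below; shapes are illustrative, and the
negative tests check the mask-vs-chunk-size shape rules):

    %v = xegpu.load %src[%offsets], %mask <{chunk_size = 2}>
         : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
    xegpu.store %v, %src[%offsets], %mask <{chunk_size = 2}>
         : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>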
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 31 +++---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 97 +++++++++++++++++--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 22 +++++
 mlir/test/Dialect/XeGPU/ops.mlir              | 40 ++++----
 4 files changed, 147 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bf036d86d14bb..c6b192a9dda31 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -676,9 +676,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
   let hasVerifier = 1;
 }
 
-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
-    AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
-  ]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   let summary = "load a set of scattered data points from memory.";
 
   let description = [{ It (aka. load) load data per each work-item. The output
@@ -721,8 +719,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   }];
 
   let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
-                       Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+                       Optional<XeGPU_OffsetType>: $offsets,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<I64Attr>: $chunk_size,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -760,11 +757,15 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   }];
 
   let assemblyFormat = [{
-    $source ``
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+    $source
+    (`[` $offsets^ `]`)? `,`
     $mask prop-dict
-    attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
+    attr-dict `:` type(operands) `->` type($value)
   }];
+
+  // functional-type(operands, results)
+  // type($source) (type($offsets)^ )? `,` type($mask) `->` type($value)
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
@@ -776,9 +777,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
-    AllElementTypesMatch<["value", "dest"]>, MemoryEffects<[MemWrite]>
-  ]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   let summary = "store data to scattered memory locations.";
   let description = [{ It (aka. store) stores data to scattered memory locations.
  The value is typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -817,8 +817,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
 
   let arguments = (ins
     XeGPU_ValueType: $value,
     XeGPU_TensorDesc_or_MemRef: $dest,
-    Variadic<Index>: $offsets,
-    OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+    Optional<XeGPU_OffsetType>: $offsets,
     XeGPU_MaskType: $mask,
     OptionalAttr<I64Attr>: $chunk_size,
     OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -848,13 +847,13 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   let assemblyFormat = [{
     $value `,`
-    $dest ``
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+    $dest
+    (`[` $offsets^ `]`)? `,`
     $mask
     prop-dict
-    attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
+    attr-dict `:` type(operands)
   }];
+// type($value) `,` qualified(type($dest)) (type($offsets)^)? `,` type($mask)
 
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4f3b3ed475afc..7a32f1a45c762 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -110,6 +110,66 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   return success();
 }
 
+static LogicalResult
+isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
+                                 MemRefType memTy, int64_t chunkSize,
+                                 function_ref<InFlightDiagnostic()> emitError) {
+
+  if (!valueTy)
+    return emitError() << "Expecting a vector type result.";
+
+  auto maskShape = getShapeOf(maskTy);
+  auto valueShape = getShapeOf(valueTy);
+  auto memShape = getShapeOf(memTy);
+
+  if (valueTy.getElementType() != memTy.getElementType())
+    return emitError() << "Value should have the same element type as MemRef.";
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1) {
+    if (valueTy.getNumElements() != chunkSize)
+      return emitError() << "value elements must match chunk size " << chunkSize
+                         << " for SIMT code.";
+    return success();
+  }
+
+  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
+    return emitError() << "Mask should match value except the chunk size dim.";
+
+  return success();
+}
+
+static LogicalResult
+isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
+                                 int64_t chunkSize,
+                                 function_ref<InFlightDiagnostic()> emitError) {
+
+  if (!valueTy)
+    return emitError() << "Expecting a vector type result.";
+
+  auto maskShape = getShapeOf(maskTy);
+  auto valueShape = getShapeOf(valueTy);
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1) {
+    if (valueTy.getNumElements() != chunkSize)
+      return emitError() << "value elements must match chunk size " << chunkSize
+                         << " for SIMT code.";
+    return success();
+  }
+
+  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
+    return emitError() << "Mask should match value except the chunk size dim.";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -683,8 +743,18 @@ LogicalResult LoadGatherOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    [&]() { return emitOpError(); });
+  if (tdescTy)
+    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+                                      [&]() { return emitOpError(); });
+  auto srcTy = getSourceType();
+  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+  auto memTy = dyn_cast<MemRefType>(srcTy);
+
+  if (memTy)
+    return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
+                                            [&]() { return emitOpError(); });
+  return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+                                          [&]() { return emitOpError(); });
 }
 
 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -692,8 +762,8 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          xegpu::CachePolicyAttr l1_hint,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint) {
-  build(builder, state, valueType, source, ValueRange(), DenseI64ArrayAttr(),
-        mask, IntegerAttr(), l1_hint, l2_hint, l3_hint);
+  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+        l1_hint, l2_hint, l3_hint);
 }
 
 //===----------------------------------------------------------------------===//
@@ -713,8 +783,19 @@ LogicalResult StoreScatterOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    [&]() { return emitOpError(); });
+  if (tdescTy)
+    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+                                      [&]() { return emitOpError(); });
+
+  auto destTy = getDestType();
+  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+  auto memTy = dyn_cast<MemRefType>(destTy);
+
+  if (memTy)
+    return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
+                                            [&]() { return emitOpError(); });
+  return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+                                          [&]() { return emitOpError(); });
 }
 
 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -722,8 +803,8 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                            xegpu::CachePolicyAttr l1_hint,
                            xegpu::CachePolicyAttr l2_hint,
                            xegpu::CachePolicyAttr l3_hint) {
-  build(builder, state, value, dest, ValueRange(), DenseI64ArrayAttr(), mask,
-        IntegerAttr(), l1_hint, l2_hint, l3_hint);
+  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+        l2_hint, l3_hint);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 0160bfee07bf2..af34add37f7ad 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -384,6 +384,28 @@ func.func @load_gather_vc_3(%src: ui64) {
   return
 }
 
+// -----
+func.func @load_offset(%src: ui64) {
+  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<8xi1>
+  // expected-error@+1 {{Mask should match value except the chunk size dim}}
+  %2 = xegpu.load %src[%offsets], %mask
+        : ui64, vector<4xindex>, vector<8xi1>
+        -> vector<4x2xf32>
+  return
+}
+
+// -----
+func.func @store_offset(%src: ui64) {
+  %val = arith.constant dense<2.9>: vector<4x2xf16>
+  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<8xi1>
+  // expected-error@+1 {{Mask should match value except the chunk size dim}}
+  xegpu.store %val, %src[%offsets], %mask
+        : vector<4x2xf16>, ui64, vector<4xindex>, vector<8xi1>
+  return
+}
+
 // -----
 func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
   %0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index f8c558d614ee6..16f5356a69f24 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,15 +130,6 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
   gpu.return
 }
 
-// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
-gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
-  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
-  gpu.return
-}
-
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -339,16 +330,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   gpu.return
 }
 
-// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_3(%dst: memref<24x32xf16>) {
-  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
-  %1 = arith.constant dense<1.0>: vector<32xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
-  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
-  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_through>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
-  gpu.return
-}
 
 // CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
@@ -541,6 +522,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
+  %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+  %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+        : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
 gpu.func @subgroup_store(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -646,6 +637,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
+  %val = arith.constant dense<2.9>: vector<4x2xf16>
+  %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<4xi1>
+  //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<write_back>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+  xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<write_back>}>
+        : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+  gpu.return
+}
+
 // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
 gpu.func @prefetch(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

From 45537469eddccf41eabf689b19f550d8616442de Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Fri, 25 Jul 2025 05:15:34 +0000
Subject: [PATCH 09/16] add invalid tests

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td    |  7 ++--
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp      |  6 +--
 mlir/test/Dialect/XeGPU/invalid.mlir              | 39 ++++++++++++++-----
 mlir/test/Dialect/XeGPU/ops.mlir                  |  4 +-
 4 files changed, 39 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c6b192a9dda31..312db1402f58f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -661,11 +661,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 
   let assemblyFormat = [{
     $source
-    (`,` $offsets^)?
+    (`[` $offsets^ `]`)?
     prop-dict
-    attr-dict `:` type($source) (`,` type($offsets)^)?
+    attr-dict `:` type(operands)
   }];
-
+  // type($source) (type($offsets)^)?
+
   let builders = [
     OpBuilder<(ins "Value": $source,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8f332184bd1b..cafbf8d5ffc5e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -484,7 +484,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -546,7 +546,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -575,7 +575,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index af34add37f7ad..b56de88391803 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -385,27 +385,48 @@ func.func @load_gather_vc_3(%src: ui64) {
 }
 
 // -----
-func.func @load_offset(%src: ui64) {
+func.func @load_gather_offset_sg(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %mask = arith.constant dense<1>: vector<8xi1>
   // expected-error@+1 {{Mask should match value except the chunk size dim}}
   %2 = xegpu.load %src[%offsets], %mask
-        : ui64, vector<4xindex>, vector<8xi1>
-        -> vector<4x2xf32>
+        : memref<?xf16>, vector<4xindex>, vector<8xi1>
+        -> vector<4x2xf16>
   return
 }
 
+// -----
+func.func @load_gather_offset_wi(%src: ui64) {
+  %mask = arith.constant dense<1>: vector<1xi1>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{value elements must match chunk size}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
+  return
+}
+
+// -----
+func.func @store_scatter_offset(%src: memref<?xf16>) {
+  %val = arith.constant dense<2.9>: vector<4xf16>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  %mask = arith.constant dense<1>: vector<1xi1>
+  // expected-error@+1 {{value elements must match chunk size}}
+  xegpu.store %val, %src[%offsets], %mask
+        : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
 }
 
+
 // -----
-func.func @store_offset(%src: ui64) {
+func.func @load_gather_offset_wi(%src: ui64) {
   %val = arith.constant dense<2.9>: vector<4x2xf16>
   %mask = arith.constant dense<1>: vector<1xi1>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   // expected-error@+1 {{value elements must match chunk size}}
   %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
   return
 }
 
+
 // -----
 func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
   %0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 16f5356a69f24..ea80601ef5574 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -664,8 +664,8 @@ gpu.func @prefetch_offset(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // CHECK: xegpu.prefetch %[[arg0]], %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
-  xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+  // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+  xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
   gpu.return
 }

From 1249794952a1e67a4f0ff54b7e4fa39d0704b42e Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Fri, 25 Jul 2025 05:28:50 +0000
Subject: [PATCH 10/16] small fixes

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td    | 6 ------
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 2 +-
 mlir/test/Dialect/XeGPU/invalid.mlir              | 2 +-
 3 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 312db1402f58f..82edd69f63694 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -665,7 +665,6 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     prop-dict
     attr-dict `:` type(operands)
   }];
-  // type($source) (type($offsets)^)?
 
   let builders = [
     OpBuilder<(ins "Value": $source,
@@ -763,10 +762,6 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
     $mask prop-dict
     attr-dict `:` type(operands) `->` type($value)
   }];
-
-  // functional-type(operands, results)
-  // type($source) (type($offsets)^ )? `,` type($mask) `->` type($value)
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
@@ -855,7 +850,6 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     prop-dict
     attr-dict `:` type(operands)
   }];
-// type($value) `,` qualified(type($dest)) (type($offsets)^)? `,` type($mask)
`,` type($mask) let builders = [ OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask, diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index cafbf8d5ffc5e..29ee864b8f34f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern { SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); SmallVector convertedTdescs = pack( - op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter); + op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); SmallVector convertedMaskTypes; SmallVector convertedMasks; diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index b56de88391803..b8e6a31d8d2f7 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -405,7 +405,7 @@ func.func @load_gather_offset_wi(%src: ui64) { } // ----- -func.func @store_scatter_offset(%src: memref) { +func.func @store_scatter_offset_sg(%src: memref) { %val = arith.constant dense<2.9>: vector<4xf16> %offsets = arith.constant dense<[0]> : vector<1xindex> %mask = arith.constant dense<1>: vector<1xi1> From 5cfb24b395ef3e83f4e316c9e35b9d2f8cb72dd1 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Fri, 25 Jul 2025 21:39:07 +0000 Subject: [PATCH 11/16] address comments --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 7 +++---- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 1 - 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 82edd69f63694..9da015b65a6af 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -16,7 +16,6 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td" include "mlir/Interfaces/ShapedOpInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td" // Base class for dialect operations. 
This operation inherits from the base
// `Op` class in OpBase.td, and provides:
@@ -639,7 +638,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
   }];
 
-  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+  let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
                        Optional: $offsets,
                        OptionalAttr: $l1_hint,
                        OptionalAttr: $l2_hint,
@@ -718,7 +717,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   }];
 
-  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+  let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
                        Optional: $offsets,
                        XeGPU_MaskType: $mask,
                        OptionalAttr: $chunk_size,
@@ -812,7 +811,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
 
   let arguments = (ins XeGPU_ValueType: $value,
-    XeGPU_TensorDesc_or_MemRef: $dest,
+    XeGPU_TensorDescOrMemRef: $dest,
     Optional: $offsets,
     XeGPU_MaskType: $mask,
     OptionalAttr: $chunk_size,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 8e575e31255a7..fa59bf2e40c4d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
-def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_TensorDescOrMemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7a32f1a45c762..3a41b298e2aae 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -120,7 +120,6 @@ isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
 
   auto maskShape = getShapeOf(maskTy);
   auto valueShape = getShapeOf(valueTy);
-  auto memShape = getShapeOf(memTy);
 
   if (valueTy.getElementType() != memTy.getElementType())
     return emitError() << "Value should have the same element type as MemRef.";
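Taken together, patch 12 below switches the gather/scatter ops to a plain memref or ui64 base plus per-lane offsets. A minimal round-trip sketch in the syntax the series converges on (function name and shapes are illustrative, not taken from the patches):

```mlir
gpu.func @gather_scatter_roundtrip(%src: memref<1024xf32>) {
  %offsets = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
  %mask = arith.constant dense<1> : vector<4xi1>
  // Gather four f32 values at %offsets, then scatter them back.
  %v = xegpu.load %src[%offsets], %mask
      : memref<1024xf32>, vector<4xindex>, vector<4xi1> -> vector<4xf32>
  xegpu.store %v, %src[%offsets], %mask
      : vector<4xf32>, memref<1024xf32>, vector<4xindex>, vector<4xi1>
  gpu.return
}
```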
From 5940d191e865d06063a8ed97b804b90858bc78ba Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Tue, 29 Jul 2025 17:33:18 +0000
Subject: [PATCH 12/16] address feedback

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 46 ++++++++++-
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 76 ++++++++-----------
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  3 +
 mlir/test/Dialect/XeGPU/invalid.mlir          |  2 +-
 mlir/test/Dialect/XeGPU/ops.mlir              |  2 -
 6 files changed, 80 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 9da015b65a6af..c864ce0c3d9cd 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -628,17 +628,28 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     As compared to prefetch_nd, which works on non-scattered TensorDesc,
     it works on scattered TensorDesc instead.
 
-    Example:
+    Example 1:
     ```mlir
       xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint,
                              l2_hint = #xegpu.cache_hint,
                              l3_hint = #xegpu.cache_hint}
         : !xegpu.tensor_desc<16xf16>
     ```
+
+    Example 2:
+    A variant accepts memref as base pointer and an offset instead of scattered TensorDesc. It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+    ```mlir
+      %a = memref.alloc() : memref<1024xf32>
+      %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+      xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint,
+                             l2_hint = #xegpu.cache_hint,
+                             l3_hint = #xegpu.cache_hint}
+        : memref<1024xf32>, vector<4xindex>
+    ```
   }];
 
-  let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
+  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
                        Optional: $offsets,
                        OptionalAttr: $l1_hint,
                        OptionalAttr: $l2_hint,
@@ -706,7 +717,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>,
        vector<16xi1> -> vector<16x8xf32>
   ```
+
   Example 3 (SIMT mode):
   ```mlir
     %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint,
                              l2_hint = #xegpu.cache_hint,
                              l3_hint = #xegpu.cache_hint, transpose}>
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>
            vector<16xi1> -> vector<8xf32>
   ```
+
+  Example 4:
+  A variant accepts memref as base pointer and an offset instead of scattered TensorDesc. It combines "create scattered TensorDesc" and "load with scattered TensorDesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+  ```mlir
+    %a = memref.alloc() : memref<1024xf32>
+    %offsets = vector.step : vector<16xindex>
+    %mask = vector.constant_mask [16]: vector<16xi1>
+    %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint,
+            l2_hint = #xegpu.cache_hint,
+            l3_hint = #xegpu.cache_hint}
+          : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  ```
   }];
 
-  let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
+  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
                        Optional: $offsets,
                        XeGPU_MaskType: $mask,
                        OptionalAttr: $chunk_size,
@@ -807,11 +831,25 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                                  l3_hint = #xegpu.cache_hint}>
       : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr> vector<16xi1>
   ```
+
+  Example 4:
+  A variant accepts memref as base pointer and an offset instead of scattered TensorDesc. It combines "create scattered TensorDesc" and "store with scattered TensorDesc". The dest operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+  ```mlir
+    %a = memref.alloc() : memref<1024xf32>
+    %val = arith.constant dense<0.0> : vector<16xf32>
+    %offsets = vector.step : vector<16xindex>
+    %mask = vector.constant_mask [16]: vector<16xi1>
+    xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint,
+            l2_hint = #xegpu.cache_hint,
+            l3_hint = #xegpu.cache_hint}
+          : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  ```
+
  }];
 
   let arguments = (ins XeGPU_ValueType: $value,
-    XeGPU_TensorDescOrMemRef: $dest,
+    XeGPU_GatherScatterSourceType: $dest,
     Optional: $offsets,
     XeGPU_MaskType: $mask,
     OptionalAttr: $chunk_size,
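The chunk_size attribute carries over from the tensor_desc path to the offset form. A hedged sketch of the two value/mask shapes the verifier accepts (%src : ui64 plus the offset and mask values are assumed to be in scope; shapes are illustrative):

```mlir
// Subgroup form: 16 lanes, 8 contiguous elements per lane. The mask covers
// only the lane dimension, so vector<16xi1> pairs with vector<16x8xf16>.
%v = xegpu.load %src[%offsets], %mask <{chunk_size = 8}>
    : ui64, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
// SIMT form: a single lane per work item; the 1-D value must carry exactly
// chunk_size elements.
%w = xegpu.load %src[%off1], %m1 <{chunk_size = 8}>
    : ui64, vector<1xindex>, vector<1xi1> -> vector<8xf16>
```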
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index fa59bf2e40c4d..b268cabb5d266 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
-def XeGPU_TensorDescOrMemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
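The verifier changes below enforce the same shape rule for stores. A sketch of the scatter counterpart (%vals, %dst, %offsets, and %mask are assumed in scope; shapes are illustrative):

```mlir
// Subgroup scatter: the trailing dimension of the value is the chunk written
// per lane; the mask again covers only the lane dimension.
xegpu.store %vals, %dst[%offsets], %mask <{chunk_size = 8}>
    : vector<16x8xf16>, ui64, vector<16xindex>, vector<16xi1>
// A mismatched pair, e.g. a vector<4xf16> value with chunk_size = 2, is
// rejected with "value elements must match chunk size".
```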
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 3a41b298e2aae..7c8ee7408dfd1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -111,39 +111,7 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
 }
 
 static LogicalResult
-isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
-                                 MemRefType memTy, int64_t chunkSize,
-                                 function_ref emitError) {
-
-  if (!valueTy)
-    return emitError() << "Expecting a vector type result.";
-
-  auto maskShape = getShapeOf(maskTy);
-  auto valueShape = getShapeOf(valueTy);
-
-  if (valueTy.getElementType() != memTy.getElementType())
-    return emitError() << "Value should have the same element type as MemRef.";
-
-  // a valid shape for SIMT case
-  if (valueTy.getRank() == 1) {
-    if (valueTy.getNumElements() != chunkSize)
-      return emitError() << "value elements must match chunk size " << chunkSize
-                         << " for SIMT code.";
-    return success();
-  }
-
-  llvm::SmallVector expectedMaskShape(valueShape);
-  if (chunkSize > 1)
-    expectedMaskShape.pop_back();
-  if (expectedMaskShape != maskShape)
-    return emitError() << "Mask should match value except the chunk size dim.";
-
-  return success();
-}
-
-static LogicalResult
-isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
-                                 int64_t chunkSize,
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, int64_t chunkSize,
                                  function_ref emitError) {
 
   if (!valueTy)
@@ -703,8 +671,14 @@ LogicalResult CreateDescOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (tdescTy && !tdescTy.isScattered())
-    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (tdescTy) {
+    if (!tdescTy.isScattered())
+      return emitOpError("Expects a scattered TensorDesc.\n");
+  } else {
+    if (getRankOf(getSource()) > 1)
+      return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+  }
 
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -733,6 +707,14 @@ LogicalResult LoadGatherOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
+  if (tdescTy) {
+    if (!tdescTy.isScattered())
+      return emitOpError("Expects a scattered TensorDesc.\n");
+  } else {
+    if (getRankOf(getSource()) > 1)
+      return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+  }
+
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -749,10 +731,10 @@ LogicalResult LoadGatherOp::verify() {
   uint64_t chunkSize = static_cast(getChunkSize().value_or(1));
   auto memTy = dyn_cast(srcTy);
 
-  if (memTy)
-    return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
-                                            [&]() { return emitOpError(); });
-  return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+  if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+    return emitError() << "Value should have the same element type as MemRef.";
+
+  return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
 }
 
@@ -773,6 +755,14 @@ LogicalResult StoreScatterOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
+  if (tdescTy) {
+    if (!tdescTy.isScattered())
+      return emitOpError("Expects a scattered TensorDesc.\n");
+  } else {
+    if (getRankOf(getDest()) > 1)
+      return emitOpError("Expecting the dest is a 1D memref or pointer (uint64_t).");
+  }
+
   if (!isWriteHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -790,10 +780,10 @@ LogicalResult StoreScatterOp::verify() {
   uint64_t chunkSize = static_cast(getChunkSize().value_or(1));
   auto memTy = dyn_cast(destTy);
 
-  if (memTy)
-    return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
-                                            [&]() { return emitOpError(); });
-  return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+  if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+    return emitError() << "Value should have the same element type as MemRef.";
+
+  return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29ee864b8f34f..d52f7f2ac274a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -484,6 +484,7 @@ struct UnrollLoadGatherOp : public UnrollPattern {
     VectorType valueTy = llvm::dyn_cast(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    // TODO: handle the unstructured source case (!tdescTy)
     if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -546,6 +547,7 @@ struct UnrollPrefetchOp : public UnrollPattern {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    // TODO: handle the unstructured source case (!tdescTy)
     if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -575,6 +577,7 @@ struct UnrollStoreScatterOp : public UnrollPattern {
     VectorType valueTy = llvm::dyn_cast(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    // TODO: handle the unstructured source case (!tdescTy)
     if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index b8e6a31d8d2f7..4cece4640634e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -405,7 +405,7 @@ func.func @load_gather_offset_wi(%src: ui64) {
 }
 
 // -----
-func.func @store_scatter_offset_sg(%src: 
memref) { +func.func @store_scatter_offset_wi(%src: memref) { %val = arith.constant dense<2.9>: vector<4xf16> %offsets = arith.constant dense<[0]> : vector<1xindex> %mask = arith.constant dense<1>: vector<1xi1> diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index ea80601ef5574..6be2371d4d7b2 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -330,7 +330,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) { gpu.return } - // CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) { // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16> @@ -659,7 +658,6 @@ gpu.func @prefetch(%src: ui64) { gpu.return } - // CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) { gpu.func @prefetch_offset(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> From da7142aa4eb83340885e459edd71a4c865ba2aca Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 29 Jul 2025 17:36:14 +0000 Subject: [PATCH 13/16] git-clang-format --- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 7c8ee7408dfd1..45a4363bd11ba 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -111,7 +111,8 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, } static LogicalResult -isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, int64_t chunkSize, +isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, + int64_t chunkSize, function_ref emitError) { if (!valueTy) @@ -677,7 +678,8 @@ LogicalResult PrefetchOp::verify() { return emitOpError("Expects a scattered TensorDesc.\n"); } else { if (getRankOf(getSource()) > 1) - return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t)."); + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); } if (!isReadHintOrNone(getL1HintAttr())) @@ -712,7 +714,8 @@ LogicalResult LoadGatherOp::verify() { return emitOpError("Expects a scattered TensorDesc.\n"); } else { if (getRankOf(getSource()) > 1) - return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t)."); + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); } if (!isReadHintOrNone(getL1HintAttr())) @@ -731,7 +734,7 @@ LogicalResult LoadGatherOp::verify() { uint64_t chunkSize = static_cast(getChunkSize().value_or(1)); auto memTy = dyn_cast(srcTy); - if (memTy && (valueTy.getElementType() != memTy.getElementType()) ) + if (memTy && (valueTy.getElementType() != memTy.getElementType())) return emitError() << "Value should have the same element type as MemRef."; return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, @@ -760,9 +763,10 @@ LogicalResult StoreScatterOp::verify() { return emitOpError("Expects a scattered TensorDesc.\n"); } else { if (getRankOf(getDest()) > 1) - return emitOpError("Expecting the dest is a 1D memref or pointer (uint64_t)."); + return emitOpError( + "Expecting the dest is a 1D memref or pointer (uint64_t)."); } - + if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -780,7 +784,7 @@ LogicalResult StoreScatterOp::verify() { uint64_t chunkSize = static_cast(getChunkSize().value_or(1)); auto memTy = 
dyn_cast(destTy); - if (memTy && (valueTy.getElementType() != memTy.getElementType()) ) + if (memTy && (valueTy.getElementType() != memTy.getElementType())) return emitError() << "Value should have the same element type as MemRef."; return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, From 8b99ecc629be3c1dfb96c93ead49594dc30a47ef Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 29 Jul 2025 22:43:25 +0000 Subject: [PATCH 14/16] minor polish --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 3 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 42 ++++++++----------- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index c864ce0c3d9cd..3e075de1651ab 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -863,7 +863,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { } Value getTensorDesc() { - return getDest(); + assert(getTensorDescType() && "Expected dest to be a TensorDescType"); + return getDest(); } xegpu::TensorDescType getTensorDescType() { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 45a4363bd11ba..1b114d41b6ca5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -673,14 +673,12 @@ LogicalResult CreateDescOp::verify() { LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); - if (tdescTy) { - if (!tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); - } else { - if (getRankOf(getSource()) > 1) - return emitOpError( - "Expecting the source is a 1D memref or pointer (uint64_t)."); - } + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -709,14 +707,12 @@ LogicalResult LoadGatherOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); - if (tdescTy) { - if (!tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); - } else { - if (getRankOf(getSource()) > 1) - return emitOpError( - "Expecting the source is a 1D memref or pointer (uint64_t)."); - } + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -758,14 +754,12 @@ LogicalResult StoreScatterOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); - if (tdescTy) { - if (!tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); - } else { - if (getRankOf(getDest()) > 1) - return emitOpError( - "Expecting the dest is a 1D memref or pointer (uint64_t)."); - } + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!tdescTy && getRankOf(getDest()) > 1) + return emitOpError( + "Expecting the dest is a 1D memref or pointer (uint64_t)."); if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); From 04306ca04fa6531f2cefe607555d64ba55fe83ef Mon Sep 
17 00:00:00 2001
From: Jianhui Li
Date: Wed, 30 Jul 2025 17:23:15 +0000
Subject: [PATCH 15/16] address comments

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 37 ++++++++++++++-----
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  2 +-
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  6 +--
 3 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 3e075de1651ab..75b16a87e03c6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -637,7 +637,10 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     ```
 
     Example 2:
-    A variant accepts memref as base pointer and an offset instead of scattered TensorDesc. It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+    A variant accepts memref as base pointer and an offset instead of scattered TensorDesc.
+    It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
+    The source operand could be a raw pointer (uint64_t).
+    Please refer to create_tdesc for the restriction of memref.
     ```mlir
       %a = memref.alloc() : memref<1024xf32>
       %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -660,8 +663,11 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
       return getSource().getType();
     }
 
-    Value getTensorDesc() {
-      return getSource();
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+      }
+      return TypedValue<xegpu::TensorDescType>();
     }
 
     xegpu::TensorDescType getTensorDescType() {
@@ -728,7 +734,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
     ```
 
   Example 4:
-  A variant accepts memref as base pointer and an offset instead of scattered TensorDesc. It combines "create scattered TensorDesc" and "load with scattered TensorDesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+  A variant accepts memref as base pointer and an offset instead of scattered TensorDesc.
+  It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
+  The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
+  for the restriction of memref.
   ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %offsets = vector.step : vector<16xindex>
@@ -756,8 +765,11 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       return getSource().getType();
     }
 
-    Value getTensorDesc() {
-      return getSource();
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+      }
+      return TypedValue<xegpu::TensorDescType>();
     }
 
     xegpu::TensorDescType getTensorDescType() {
@@ -833,7 +845,10 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     ```
 
   Example 4:
-  A variant accepts memref as base pointer and an offset instead of scattered TensorDesc. It combines "create scattered TensorDesc" and "store with scattered TensorDesc". The dest operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+  A variant accepts memref as base pointer and an offset instead of scattered TensorDesc.
+  It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
+  The dest operand could be a raw pointer (uint64_t).
+  Please refer to create_tdesc for the restriction of memref.
  ```mlir
    %a = memref.alloc() : memref<1024xf32>
    %val = arith.constant dense<0.0> : vector<16xf32>
@@ -862,9 +877,11 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       return getDest().getType();
     }
 
-    Value getTensorDesc() {
-      assert(getTensorDescType() && "Expected dest to be a TensorDescType");
-      return getDest();
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
+      }
+      return TypedValue<xegpu::TensorDescType>();
     }
 
     xegpu::TensorDescType getTensorDescType() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 1b114d41b6ca5..33450f3fa229e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -708,7 +708,7 @@ LogicalResult LoadGatherOp::verify() {
   auto valueTy = getValueType();
 
   if (tdescTy && !tdescTy.isScattered())
-    return emitOpError("Expects a scattered TensorDesc.\n");
+    return emitOpError("Expects a scattered TensorDesc.");
 
   if (!tdescTy && getRankOf(getSource()) > 1)
     return emitOpError(
        "Expecting the source is a 1D memref or pointer (uint64_t).");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index d52f7f2ac274a..9f0c074a1489d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -485,7 +485,7 @@ struct UnrollLoadGatherOp : public UnrollPattern {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // TODO: handle the unstructured source case (!tdescTy)
-    if (!tdescTy || !tdescTy.isScattered())
+    if (!tdescTy || op.getOffsets())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -548,7 +548,7 @@ struct UnrollPrefetchOp : public UnrollPattern {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // TODO: handle the unstructured source case (!tdescTy)
-    if (!tdescTy || !tdescTy.isScattered())
+    if (!tdescTy || op.getOffsets())
      return failure();
 
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -578,7 +578,7 @@ struct UnrollStoreScatterOp : public UnrollPattern {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // TODO: handle the unstructured source case (!tdescTy)
-    if (!tdescTy || !tdescTy.isScattered())
+    if (!tdescTy || op.getOffsets())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
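Patch 16 below adds invalid cases around these rules. For contrast, a well-formed SIMT offset-form sequence might look like the following sketch (assuming %src : ui64 is in scope; not taken from the tests):

```mlir
// One lane per work item; chunk_size matches the value length, so the
// verifier accepts the prefetch, load, and store forms alike.
%off = arith.constant dense<[0]> : vector<1xindex>
%m = arith.constant dense<1> : vector<1xi1>
xegpu.prefetch %src[%off] : ui64, vector<1xindex>
%v = xegpu.load %src[%off], %m <{chunk_size = 2}>
    : ui64, vector<1xindex>, vector<1xi1> -> vector<2xf32>
xegpu.store %v, %src[%off], %m <{chunk_size = 2}>
    : vector<2xf32>, ui64, vector<1xindex>, vector<1xi1>
```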
From bbd6530ee6c0789af571578530dabe3e0cce7915 Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Wed, 30 Jul 2025 17:53:36 +0000
Subject: [PATCH 16/16] add more invalid tests

---
 mlir/test/Dialect/XeGPU/invalid.mlir | 33 ++++++++++++++++++++++++----
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 4cece4640634e..dff3ffab39ecf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -384,6 +384,14 @@ func.func @load_gather_vc_3(%src: ui64) {
   return
 }
 
+// -----
+func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
+  %offsets = arith.constant 
dense<[0]> : vector<1xindex> %mask = arith.constant dense<1>: vector<1xi1> @@ -415,17 +423,34 @@ func.func @store_scatter_offset_wi(%src: memref) { return } +// ----- +func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1> + return +} // ----- -func.func @load_gather_offset_wi(%src: ui64) { - %val = arith.constant dense<2.9>: vector<4x2xf16> +func.func @load_gather_offset_wi_2(%src: ui64) { %mask = arith.constant dense<1>: vector<1xi1> %offsets = arith.constant dense<[0]> : vector<1xindex> // expected-error@+1 {{value elements must match chunk size}} - %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32> + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16> return } +// ----- +func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32> + return +} // ----- func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {