diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 42b5b7a0d4e3f..d022361d1e376 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -64,6 +64,12 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td )> ]; + let extraClassDeclaration = [{ + // return true if all fields of the BlockTensorDescAttr are set with + // default values. + bool hasDefaultsOnly(); + }]; + } def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> { diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 277158ac85409..1f4e817dc549c 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -131,12 +131,12 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", return llvm::cast(cloneWith(getShape(), elementType)); } - BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const { - return llvm::dyn_cast_if_present(getEncoding()); - } - - ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr() const { - return llvm::dyn_cast_if_present(getEncoding()); + template || + std::is_same_v>> + T getEncodingOfType() const { + return llvm::dyn_cast_if_present(getEncoding()); } LayoutAttr getLayoutAttr() const { @@ -144,49 +144,35 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", } xegpu::MemorySpace getMemorySpace() const { - auto block_attr = getEncodingAsBlockTensorDescAttr(); - if (block_attr && block_attr.getMemorySpace()) - return block_attr.getMemorySpace().getValue(); - - auto scatter_attr = getEncodingAsScatterTensorDescAttr(); - if (scatter_attr && scatter_attr.getMemorySpace()) - return scatter_attr.getMemorySpace().getValue(); + if (auto attr = getEncodingOfType()) + return attr.getMemorySpace().getValue(); - // return default value - return MemorySpace::Global; + auto attr = getEncodingOfType(); + return attr.getMemorySpace().getValue(); } // get the ArrayLength for blocked TensorDesc int getArrayLength() { - auto attr = getEncoding(); - auto block_attr = mlir::dyn_cast_if_present(attr); - assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr."); - if (block_attr && block_attr.getArrayLength()) - return block_attr.getArrayLength().getInt(); - // return default value - return 1; + auto attr = getEncodingOfType(); + assert(attr && "invalid on non BlockTensorDescAttr."); + return attr.getArrayLength().getInt(); } bool getBoundaryCheck() { - auto attr = getEncoding(); - auto block_attr = mlir::dyn_cast_if_present(attr); - assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr."); - if (block_attr && block_attr.getBoundaryCheck()) - return block_attr.getBoundaryCheck().getValue(); - // return default value - return true; + auto attr = getEncodingOfType(); + assert(attr && "invalid on non BlockTensorDescAttr."); + return attr.getBoundaryCheck().getValue(); } bool isScattered() { - return bool(getEncodingAsScatterTensorDescAttr()); + return bool(getEncodingOfType()); } // get the ChunkSize for scattered TensorDesc int getChunkSizeAsInt() { - auto attr = getEncoding(); - auto scatter_attr = mlir::dyn_cast_if_present(attr); - assert(scatter_attr && "invalid on non ScatterTensorDescAttr."); - return scatter_attr.getChunkSizeAsInt(); + auto attr = getEncodingOfType(); + assert(attr && "invalid on non ScatterTensorDescAttr."); + return attr.getChunkSizeAsInt(); } /// Helper to drop all layout information from the TensorDesc type. diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 642c393cbc2c8..8ab404d52eab4 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -112,6 +112,11 @@ BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context, return Base::get(context, scopeAttr, lengthAttr, boundaryAttr); } +bool BlockTensorDescAttr::hasDefaultsOnly() { + return getMemorySpace().getValue() == xegpu::MemorySpace::Global && + getArrayLength().getInt() == 1 && getBoundaryCheck().getValue(); +} + //===----------------------------------------------------------------------===// // XeGPU_ScatterTensorDescAttr //===----------------------------------------------------------------------===// @@ -253,10 +258,11 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { if (parser.parseGreater()) return {}; + MLIRContext *ctxt = parser.getContext(); return TensorDescType::getChecked( - [&]() { return parser.emitError(parser.getNameLoc()); }, - parser.getContext(), shape, elementType, - encoding.value_or(mlir::Attribute()), layout.value_or(mlir::Attribute())); + [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape, + elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)), + layout.value_or(mlir::Attribute())); } void TensorDescType::print(::mlir::AsmPrinter &printer) const { @@ -273,7 +279,9 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const { printer << getElementType(); - if (auto encoding = getEncoding()) + auto encoding = getEncoding(); + auto blockAttr = llvm::dyn_cast_if_present(encoding); + if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly())) printer << ", " << encoding; if (auto layout = getLayout()) diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 370d149ee55af..1f0a663fe676c 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -54,7 +54,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) { std::multiplies()); // Case 1: regular loads/stores - auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr(); + auto scatterAttr = tdescTy.getEncodingOfType(); if (scatterAttr) { auto chunkSize = scatterAttr.getChunkSize().getInt(); // Verify if the first dimension of the tensor descriptor shape is diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir index 58719e75b1bde..9908205f07c92 100644 --- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir @@ -30,7 +30,7 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, +// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] @@ -55,7 +55,7 @@ func.func @load_dynamic_source(%source: memref, // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]] // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1] -// CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32, +// CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32> // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] @@ -73,7 +73,7 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32, +// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir index 0d3da815529e3..2c498dcc2a071 100644 --- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir @@ -32,7 +32,7 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, +// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> // ----- @@ -57,7 +57,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>, // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]] // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] // CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1] -// CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32, +// CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32> // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> // ----- @@ -75,7 +75,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32, +// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> // ----- diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir index 05b41a8233e8c..d1e5a62ad3e9b 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir @@ -51,7 +51,7 @@ func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>, // CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32, +// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32> // CHECK: return %[[VEC]] diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir index 2bfee03892d10..d5f1221aebed5 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir @@ -80,7 +80,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>, // CHECK-SAME: %[[OFFSET:.+]]: index // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]] -// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32, +// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32> // -----