Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
)>
];

let extraClassDeclaration = [{
// return true if all fields of the BlockTensorDescAttr are set with
// default values.
bool hasDefaultsOnly();
}];

}

def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {
Expand Down
53 changes: 21 additions & 32 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -131,62 +131,51 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
}

BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
}

ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr() const {
return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
template <typename T,
typename = std::enable_if_t<
std::is_same_v<T, BlockTensorDescAttr> ||
std::is_same_v<T, ScatterTensorDescAttr>>>
T getEncodingOfType() const {
return llvm::dyn_cast_if_present<T>(getEncoding());
}

LayoutAttr getLayoutAttr() const {
return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
}

xegpu::MemorySpace getMemorySpace() const {
auto block_attr = getEncodingAsBlockTensorDescAttr();
if (block_attr && block_attr.getMemorySpace())
return block_attr.getMemorySpace().getValue();
if (auto attr = getEncodingOfType<BlockTensorDescAttr>())
return attr.getMemorySpace().getValue();

auto scatter_attr = getEncodingAsScatterTensorDescAttr();
if (scatter_attr && scatter_attr.getMemorySpace())
return scatter_attr.getMemorySpace().getValue();
if (auto attr = getEncodingOfType<ScatterTensorDescAttr>())
return attr.getMemorySpace().getValue();

// return default value
llvm_unreachable("invalid encoding");
return MemorySpace::Global;
}

// get the ArrayLength for blocked TensorDesc
int getArrayLength() {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
if (block_attr && block_attr.getArrayLength())
return block_attr.getArrayLength().getInt();
// return default value
return 1;
auto attr = getEncodingOfType<BlockTensorDescAttr>();
assert(attr && "invalid on non BlockTensorDescAttr.");
return attr.getArrayLength().getInt();
}

bool getBoundaryCheck() {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
if (block_attr && block_attr.getBoundaryCheck())
return block_attr.getBoundaryCheck().getValue();
// return default value
return true;
auto attr = getEncodingOfType<BlockTensorDescAttr>();
assert(attr && "invalid on non BlockTensorDescAttr.");
return attr.getBoundaryCheck().getValue();
}

bool isScattered() {
return bool(getEncodingAsScatterTensorDescAttr());
return bool(getEncodingOfType<ScatterTensorDescAttr>());
}

// get the ChunkSize for scattered TensorDesc
int getChunkSizeAsInt() {
auto attr = getEncoding();
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
return scatter_attr.getChunkSizeAsInt();
auto attr = getEncodingOfType<ScatterTensorDescAttr>();
assert(attr && "invalid on non ScatterTensorDescAttr.");
return attr.getChunkSizeAsInt();
}

/// Helper to drop all layout information from the TensorDesc type.
Expand Down
16 changes: 12 additions & 4 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
}

bool BlockTensorDescAttr::hasDefaultsOnly() {
return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
}

//===----------------------------------------------------------------------===//
// XeGPU_ScatterTensorDescAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -253,10 +258,11 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
if (parser.parseGreater())
return {};

MLIRContext *ctxt = parser.getContext();
return TensorDescType::getChecked(
[&]() { return parser.emitError(parser.getNameLoc()); },
parser.getContext(), shape, elementType,
encoding.value_or(mlir::Attribute()), layout.value_or(mlir::Attribute()));
[&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
layout.value_or(mlir::Attribute()));
}

void TensorDescType::print(::mlir::AsmPrinter &printer) const {
Expand All @@ -273,7 +279,9 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {

printer << getElementType();

if (auto encoding = getEncoding())
auto encoding = getEncoding();
auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
printer << ", " << encoding;

if (auto layout = getLayout())
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
std::multiplies<int64_t>());

// Case 1: regular loads/stores
auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>();
if (scatterAttr) {
auto chunkSize = scatterAttr.getChunkSize().getInt();
// Verify if the first dimension of the tensor descriptor shape is
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

Expand All @@ -55,7 +55,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

Expand All @@ -73,7 +73,7 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

// -----
Expand All @@ -57,7 +57,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

// -----
Expand All @@ -75,7 +75,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

// -----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

// -----
Expand Down
Loading