Merged
Changes from 6 commits
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -103,7 +103,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
];

let extraClassDeclaration = [{
using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -176,6 +176,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return scatter_attr.getChunkSize().getInt();
return 1;
}

// This returns a vector type that represents the fragment of data owned by
// a work item in SIMT mode if this tensor descriptor is used in a XeGPU
// load/store operation.
FailureOr<VectorType> getDistributedVectorType();
}];

let hasCustomAssemblyFormat = true;
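As a point of reference (not part of this patch), a minimal sketch of how a pass or pattern might consume the new getDistributedVectorType() hook; the helper name and the printing boilerplate are assumptions for illustration. Per the rules implemented in XeGPUDialect.cpp below, a block descriptor such as 8x16xf16 with sg_map wi_layout = [1, 16], wi_data = [1, 1] is expected to yield vector<8x1xf16>.

// Illustrative only: querying the per-work-item SIMT fragment of a
// tensor descriptor from C++ (e.g. inside a rewrite pattern).
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "llvm/Support/raw_ostream.h"

static void dumpDistributedType(mlir::xegpu::TensorDescType tdescTy) {
  mlir::FailureOr<mlir::VectorType> distTy = tdescTy.getDistributedVectorType();
  if (mlir::failed(distTy)) {
    // Either no sg_map is attached or the shape is not distributable.
    llvm::errs() << "descriptor is not used in SIMT mode\n";
    return;
  }
  llvm::errs() << "distributed fragment: " << *distTy << "\n";
}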
90 changes: 86 additions & 4 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,8 +8,12 @@

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"

namespace mlir {
namespace xegpu {
@@ -276,14 +280,13 @@ LogicalResult TensorDescType::verify(
if (scatterAttr) {
// Validate subgroup mapping rules for scattered tensors.
// A work-item's slice of the tensor with shape [sg_size] or
// [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
// the mapping should reflect that.
// [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
// respectively, the mapping should reflect that.
if (wiData[0] != 1)
return emitError()
<< "cannot map over non-contiguous scattered row elements";

unsigned chunkSize = scatterAttr.getChunkSize().getInt();
if (wiData[1] != chunkSize)
if (wiData[1] != (32 / elementType.getIntOrFloatBitWidth()))
return emitError() << "work item data mapping must match the number of "
"contiguous elements";
}
@@ -307,6 +310,85 @@
return success();
}
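One way to read the updated check: each work item must own exactly 32 bits of contiguous data per row, so the required wi_data[1] follows from the element type alone (1 for f32, 2 for f16/bf16, 4 for i8). A minimal standalone sketch of that arithmetic, with a hypothetical helper name:

// Hypothetical helper mirroring the scattered-descriptor rule in verify()
// above: the contiguous per-work-item elements in a row must add up to
// 32 bits, i.e. wi_data[1] == 32 / element_bit_width.
#include <cstdint>

static bool isValidScatteredWiData(int64_t wiDataCol, unsigned elemBitWidth) {
  return wiDataCol == static_cast<int64_t>(32 / elemBitWidth);
}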

// If a tensor descriptor has an sg_map attribute, it is used in SIMT mode.
// In this mode, the distributed vector shape is determined as follows:
// Definitions:
// wi_data_size = wi_data[0] × wi_data[1]
// subgroup_size = wi_layout[0] × wi_layout[1]
// distribution_unit_size = subgroup_size × wi_data_size
// ---------------------------------------------------------------------
// Case 1: Regular loads/stores.
// ---------------------------------------------------------------------
// Distributed vector shape must be:
// [chunk_size / wi_data_size, wi_data_size]
// If the tensor descriptor shape is 1D, first dimension is ignored (set to 1).
// [wi_data_size]
// ---------------------------------------------------------------------
// Case 2: Block loads/stores
// ---------------------------------------------------------------------
// Additional definitions:
// tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
// n_distribution_units = tensor_size / distribution_unit_size
// Given above definitions, the following conditions must be met:
// * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
// * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
// Distributed vector shape must be:
// [n_distribution_units, wi_data_size]
FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
// If no sg_map is provided, tensor desc is not used in SIMT mode.
if (!sgMap)
return failure();

SmallVector<int64_t> wiData(sgMap.getWiData());
SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
auto tdescShape = getShape();

auto wiDataSize = 1, sgSize = 1;
for (auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) {
wiDataSize *= wiDataDim;
sgSize *= wiDim;
}

// Case 1: regular loads/stores
auto scatterAttr =
llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
if (scatterAttr) {
auto chunkSize = scatterAttr.getChunkSize().getInt();
// Check if the first dimension of the tensor descriptor shape is
// distributable.
if (tdescShape[0] % (wiLayout[0]) != 0)
return failure();
if (chunkSize > 1)
return VectorType::get({chunkSize / wiDataSize, wiDataSize},
getElementType());
return VectorType::get({wiDataSize}, getElementType());
}

// Case 2: block loads/stores
// Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
// and wiLayout must be 1.
if (tdescShape.size() == 1) {
if (wiData[0] != 1 || wiLayout[0] != 1)
return failure();
wiData = {wiData[1]};
wiLayout = {wiLayout[1]};
}
// Check if the tensor descriptor shape is distributable.
int64_t tensorSize = 1;
for (auto [tdescDim, wiDim, wiDataDim] :
llvm::zip_equal(tdescShape, wiLayout, wiData)) {
if (tdescDim % (wiDim * wiDataDim) != 0)
return failure();
tensorSize *= tdescDim;
}
// tensorSize must be adjusted for array_length.
tensorSize *= getArrayLength();

return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
getElementType());
}

} // namespace xegpu
} // namespace mlir
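To make the two cases concrete, here is a small compile-time walkthrough of the formulas implemented above; the descriptor shapes, chunk size, and sg_map values are invented purely for illustration.

#include <cstdint>

namespace {
// Case 1 (scattered, chunk_size > 1): tensor_desc<16x8xf32>, chunk_size = 8,
// wi_layout = [16, 1], wi_data = [1, 1]  ->  vector<8x1xf32>.
constexpr int64_t kChunkSize = 8;
constexpr int64_t kWiDataSize = 1 * 1; // wi_data[0] * wi_data[1]
static_assert(kChunkSize / kWiDataSize == 8, "rows of the scattered fragment");

// Case 2 (block): tensor_desc<8x16xf16>, wi_layout = [1, 16],
// wi_data = [1, 1], array_length = 1  ->  vector<8x1xf16>.
constexpr int64_t kTensorSize = 8 * 16 * 1; // tensor_desc dims * array_length
constexpr int64_t kSgSize = 1 * 16;         // wi_layout[0] * wi_layout[1]
static_assert(kTensorSize / (kSgSize * kWiDataSize) == 8,
              "rows of the block fragment");
} // namespace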
