Merged
Changes from 19 commits
13 changes: 8 additions & 5 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -608,8 +608,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
      return getElementTypeOrSelf(type);
    }

-    Type getValueType() {
-      return getValue().getType();
+    VectorType getValueType() {
+      return llvm::dyn_cast<VectorType>(getValue().getType());
    }

    Type getMaskType() {
@@ -668,8 +668,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
      return getTensorDesc().getType();
    }

-    Type getValueType() {
-      return getValue().getType();
+    VectorType getValueType() {
+      return llvm::dyn_cast<VectorType>(getValue().getType());
    }

    Type getMaskType() {
@@ -757,7 +757,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
  let arguments = (ins
    XeGPU_DpasOpType : $lhs,
    XeGPU_DpasOpType : $rhs,
-    Optional<XeGPU_Vector2DType>: $acc);
+    Optional<XeGPU_Vector2DType>: $acc,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_a,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_b,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_c);
  let results = (outs XeGPU_Vector2DType: $result);

  let extraClassDeclaration = [{
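For orientation, here is a hedged sketch of how the generated C++ accessors implied by the TableGen changes above would typically be used. The accessor names follow MLIR's convention of camel-casing declared argument names (`getSgMapAAttr` for `$sg_map_a`); the `inspect` function itself is hypothetical, not part of the patch.

```cpp
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;

// Hypothetical caller; not part of the patch.
static void inspect(xegpu::DpasOp dpasOp, xegpu::LoadGatherOp loadOp) {
  // OptionalAttr arguments yield a null attribute when absent.
  if (xegpu::SGMapAttr sgMapA = dpasOp.getSgMapAAttr())
    llvm::errs() << "dpas A operand has a SIMT mapping\n";
  // getValueType() now returns VectorType directly; the dyn_cast in the new
  // body yields a null type for non-vector values, so callers should check.
  if (VectorType valTy = loadOp.getValueType())
    llvm::errs() << "loaded vector rank: " << valTy.getRank() << "\n";
}
```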
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -103,7 +103,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
      CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
  ];

  let extraClassDeclaration = [{
    using TensorType::clone;
    using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -176,6 +176,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
        return scatter_attr.getChunkSize().getInt();
      return 1;
    }
+
+    // This returns a vector type that represents the fragment of data owned by
+    // a work item in SIMT mode if this tensor descriptor is used in a XeGPU
+    // load/store operation.
+    FailureOr<VectorType> getDistributedVectorType();
  }];

let hasCustomAssemblyFormat = true;
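A minimal usage sketch for the new `getDistributedVectorType` declaration, assuming the `FailureOr`-based error handling declared above; the helper and its name are illustrative only.

```cpp
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
using namespace mlir;

// Illustrative helper: the per-work-item fragment type of a SIMT-mode
// descriptor, or a null VectorType when no sg_map attribute is attached.
static VectorType getSIMTFragmentType(xegpu::TensorDescType tdescTy) {
  FailureOr<VectorType> distributed = tdescTy.getDistributedVectorType();
  return succeeded(distributed) ? *distributed : VectorType();
}
```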
104 changes: 99 additions & 5 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,8 +8,13 @@

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cassert>

namespace mlir {
namespace xegpu {
@@ -239,6 +244,8 @@ LogicalResult TensorDescType::verify(
    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
    mlir::Attribute encoding, mlir::Attribute sg_map) {
  size_t rank = shape.size();
+  // Low-precision types are packed in 32-bit units.
+  unsigned packingFactor = 32 / elementType.getIntOrFloatBitWidth();
Contributor:
It may be worth making 32 a named variable just to clarify its meaning right away; this could also spare a comment below:

> This is because each work item accesses data at 32-bit granularity.

Contributor (Author):
Good point. I have added a comment on this.
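To make the review point concrete, the packing factor is pure arithmetic on the element bit width. A standalone sketch follows; the constant name is an assumption in the spirit of the reviewer's suggestion, not the merged code.

```cpp
#include <cstdio>

int main() {
  // Each work item accesses data at 32-bit granularity, so narrower
  // element types are packed into 32-bit units.
  const unsigned kWiAccessBits = 32;
  for (unsigned elemBits : {32u, 16u, 8u})
    std::printf("%2u-bit elements -> packing factor %u\n", elemBits,
                kWiAccessBits / elemBits); // 32 -> 1, 16 -> 2, 8 -> 4
  return 0;
}
```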

if (rank != 1 && rank != 2)
return emitError() << "expected 1D or 2D tensor";

@@ -252,6 +259,16 @@
      return emitError() << "expected non-contiguous elements for 1D tensor";
    if (rank == 2 && chunkSize < 2)
      return emitError() << "expected chunk blocks for 2D tensor";
+    // If chunk size > 1, the second dimension of the tensor shape must be
+    // equal to chunk size and it must be a multiple of the packing factor.
+    if (chunkSize > 1) {
+      if (shape.back() != chunkSize)
+        return emitError() << "expected tensor shape[1] to match chunk size";
+      if (shape.back() % packingFactor != 0)
+        return emitError()
+               << "expected tensor shape[1] to be a multiple of packing factor "
+               << packingFactor;
+    }
  }

if (auto blockAttr =
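As a worked example of the checks just added: with f16 elements the packing factor is 32 / 16 = 2, so a scattered descriptor of shape [16, 8] with chunk_size = 8 verifies (shape[1] matches the chunk size and 8 % 2 == 0), while chunk_size = 3 with shape [16, 3] would fail the multiple-of-packing-factor check.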
@@ -276,14 +293,13 @@
  if (scatterAttr) {
    // Validate subgroup mapping rules for scattered tensors.
    // A work-item's slice of the tensor with shape [sg_size] or
-    // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
-    // the mapping should reflect that.
+    // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
+    // respectively, the mapping should reflect that. This is because each
+    // work item accesses data at 32-bit granularity.
    if (wiData[0] != 1)
      return emitError()
             << "cannot map over non-contiguous scattered row elements";
-
-    unsigned chunkSize = scatterAttr.getChunkSize().getInt();
-    if (wiData[1] != chunkSize)
+    if (wiData[1] != packingFactor)
      return emitError() << "work item data mapping must match the number of "
                            "contiguous elements";
  }
@@ -307,6 +323,84 @@
return success();
}

+// If tensor descriptor has a sg_map attribute it is used in SIMT mode.
+// In this mode, the distributed vector shape is determined as follows:
+// Definitions:
+//        wi_data_size = wi_data[0] × wi_data[1]
+//        subgroup_size = wi_layout[0] × wi_layout[1]
+//        distribution_unit_size = subgroup_size × wi_data_size
+// ---------------------------------------------------------------------
+// Case 1: Regular loads/stores.
+// ---------------------------------------------------------------------
+// Distributed vector shape must be:
+//        [chunk_size / wi_data_size, wi_data_size]
+// If the tensor descriptor shape is 1D, first dimension is ignored (set to 1).
+//        [wi_data_size]
+// ---------------------------------------------------------------------
+// Case 2: Block loads/stores
+// ---------------------------------------------------------------------
+// Additional definitions:
+//        tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
+//        n_distribution_units = tensor_size / distribution_unit_size
+// Given above definitions, the following conditions must be met:
+//        * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
+//        * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
+// Distributed vector shape must be:
+//        [n_distribution_units, wi_data_size]
+FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
+  auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
+  // If no sg_map is provided, tensor desc is not used in SIMT mode.
+  if (!sgMap)
+    return failure();
+
+  SmallVector<int64_t> wiData(sgMap.getWiData());
+  SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
+  auto tdescShape = getShape();
+
+  auto wiDataSize = 1, sgSize = 1;
+  for (auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) {
+    wiDataSize *= wiDataDim;
+    sgSize *= wiDim;
+  }
+
+  // Case 1: regular loads/stores
+  auto scatterAttr = getEncodingAsScatterTensorDescAttr();
+  if (scatterAttr) {
+    auto chunkSize = scatterAttr.getChunkSize().getInt();
+    // Verify if the first dimension of the tensor descriptor shape is
+    // distributable.
+    assert(tdescShape[0] % (wiLayout[0]) == 0 &&
+           "tensor descriptor shape is not distributable");
+    if (chunkSize > 1)
+      return VectorType::get({chunkSize / wiDataSize, wiDataSize},
+                             getElementType());
+    return VectorType::get({wiDataSize}, getElementType());
+  }
+
+  // Case 2: block loads/stores
+  // Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
+  // and wiLayout must be 1.
+  if (tdescShape.size() == 1) {
+    assert((wiData[0] == 1 && wiLayout[0] == 1) &&
+           "wi_data[0] and wi_layout[0] must be 1 for 1D tensor descriptor");
+    wiData = {wiData[1]};
+    wiLayout = {wiLayout[1]};
+  }
+  // Check if the tensor descriptor shape is distributable.
+  int64_t tensorSize = 1;
+  for (auto [tdescDim, wiDim, wiDataDim] :
+       llvm::zip_equal(tdescShape, wiLayout, wiData)) {
+    assert((tdescDim % (wiDim * wiDataDim) == 0) &&
+           "tensor descriptor shape is not distributable");
+    tensorSize *= tdescDim;
+  }
+  // tensorSize must be adjusted for array_length.
+  tensorSize *= getArrayLength();
+
+  return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
+                         getElementType());
+}

} // namespace xegpu
} // namespace mlir
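The distribution rules documented above are easiest to check with concrete numbers. The following self-contained sketch mirrors the block-case (Case 2) arithmetic for an assumed tensor_desc<8x16xf16> with sg_map<wi_layout = [1, 16], wi_data = [1, 1]>; the shapes are illustrative, not taken from the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example: tensor_desc<8x16xf16>,
  // sg_map<wi_layout = [1, 16], wi_data = [1, 1]>, array_length = 1.
  const int64_t tdescShape[2] = {8, 16};
  const int64_t wiLayout[2] = {1, 16};
  const int64_t wiData[2] = {1, 1};
  const int64_t arrayLength = 1;

  const int64_t wiDataSize = wiData[0] * wiData[1]; // 1
  const int64_t sgSize = wiLayout[0] * wiLayout[1]; // 16

  // Distributability conditions from the comment block in the patch.
  assert(tdescShape[0] % (wiLayout[0] * wiData[0]) == 0); // 8 % 1 == 0
  assert(tdescShape[1] % (wiLayout[1] * wiData[1]) == 0); // 16 % 16 == 0

  const int64_t tensorSize =
      tdescShape[0] * tdescShape[1] * arrayLength;           // 128
  const int64_t nUnits = tensorSize / (sgSize * wiDataSize); // 8

  // Each work item therefore owns a vector<8x1xf16> fragment.
  std::printf("distributed vector: %lldx%lld\n", (long long)nUnits,
              (long long)wiDataSize);
  return 0;
}
```

For the scattered case (Case 1), the same definitions give chunk_size / wi_data_size rows: with chunk_size = 8, f16 elements, and wi_data = [1, 2], each work item owns a vector<4x2xf16> fragment.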
