Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
let arguments = (ins
XeGPU_DpasOpType : $lhs,
XeGPU_DpasOpType : $rhs,
Optional<XeGPU_Vector2DType>: $acc);
Optional<XeGPU_Vector2DType>: $acc,
OptionalAttr<XeGPU_SGMapAttr>:$sg_map_a,
OptionalAttr<XeGPU_SGMapAttr>:$sg_map_b,
OptionalAttr<XeGPU_SGMapAttr>:$sg_map_c);
let results = (outs XeGPU_Vector2DType: $result);

let extraClassDeclaration = [{
Expand Down
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
];

let extraClassDeclaration = [{
using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
Expand Down Expand Up @@ -176,6 +176,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return scatter_attr.getChunkSize().getInt();
return 1;
}

// This returns a vector type that represents the fragment of data owned by
// a work item in SIMT mode if this tensor descriptor is used in a XeGPU
// load/store operation.
FailureOr<VectorType> getDistributedVectorType();
}];

let hasCustomAssemblyFormat = true;
Expand Down
92 changes: 87 additions & 5 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include <cassert>

namespace mlir {
namespace xegpu {
Expand Down Expand Up @@ -276,14 +281,13 @@ LogicalResult TensorDescType::verify(
if (scatterAttr) {
// Validate subgroup mapping rules for scattered tensors.
// A work-item's slice of the tensor with shape [sg_size] or
// [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
// the mapping should reflect that.
// [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
// respectively, the mapping should reflect that. This is because each
// work item accesses data at 32-bit granularity.
if (wiData[0] != 1)
return emitError()
<< "cannot map over non-contiguous scattered row elements";

unsigned chunkSize = scatterAttr.getChunkSize().getInt();
if (wiData[1] != chunkSize)
if (wiData[1] != (32 / elementType.getIntOrFloatBitWidth()))
return emitError() << "work item data mapping must match the number of "
"contiguous elements";
}
Expand All @@ -307,6 +311,84 @@ LogicalResult TensorDescType::verify(
return success();
}

// If tensor descriptor has a sg_map attribute it is used in SIMT mode.
// In this mode, the distributed vector shape is determined as follows:
// Definitions:
// wi_data_size = wi_data[0] × wi_data[1]
// subgroup_size = wi_layout[0] × wi_layout[1]
// distribution_unit_size = subgroup_size × wi_data_size
// ---------------------------------------------------------------------
// Case 1: Regular loads/stores.
// ---------------------------------------------------------------------
// Distributed vector shape must be:
// [chunk_size / wi_data_size, wi_data_size]
// If the tensor descriptor shape is 1D, first dimension is ignored (set to 1).
// [wi_data_size]
// ---------------------------------------------------------------------
// Case 2: Block loads/stores
// ---------------------------------------------------------------------
// Additional definitions:
// tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
// n_distribution_units = tensor_size / distribution_unit_size
// Given above definitions, the following conditions must be met:
// * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
// * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
// Distributed vector shape must be:
// [n_distribution_units, wi_data_size]
FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
  auto map = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
  // Without an sg_map attribute the descriptor is not used in SIMT mode, so
  // there is no per-work-item fragment to report.
  if (!map)
    return failure();

  SmallVector<int64_t> data(map.getWiData());
  SmallVector<int64_t> layout(map.getWiLayout());
  auto shape = getShape();

  // wi_data_size is the product of wi_data dims; the subgroup size is the
  // product of wi_layout dims.
  int dataSize = 1, laneCount = 1;
  for (auto [lanes, perLane] : llvm::zip_equal(layout, data)) {
    dataSize *= perLane;
    laneCount *= lanes;
  }

  // Case 1: regular (scattered) loads/stores.
  if (auto scatter = getEncodingAsScatterTensorDescAttr()) {
    auto chunk = scatter.getChunkSize().getInt();
    // The verifier guarantees the leading dim divides evenly across lanes;
    // check that invariant here as well.
    assert(shape[0] % (layout[0]) == 0 &&
           "tensor descriptor shape is not distributable");
    if (chunk > 1)
      return VectorType::get({chunk / dataSize, dataSize}, getElementType());
    return VectorType::get({dataSize}, getElementType());
  }

  // Case 2: block loads/stores. A 1D descriptor requires the outer mapping
  // dims to be unit-sized; drop them so the shapes line up below.
  if (shape.size() == 1) {
    assert((data[0] == 1 && layout[0] == 1) &&
           "wi_data[0] and wi_layout[0] must be 1 for 1D tensor descriptor");
    data = {data[1]};
    layout = {layout[1]};
  }

  // Every descriptor dim must be evenly covered by lanes times the per-lane
  // data; accumulate the total element count along the way.
  int64_t numElems = 1;
  for (auto [dim, lanes, perLane] : llvm::zip_equal(shape, layout, data)) {
    assert((dim % (lanes * perLane) == 0) &&
           "tensor descriptor shape is not distributable");
    numElems *= dim;
  }
  // Each array_length block is distributed as well.
  numElems *= getArrayLength();

  return VectorType::get({numElems / (laneCount * dataSize), dataSize},
                         getElementType());
}

} // namespace xegpu
} // namespace mlir

Expand Down
Loading
Loading