-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][xegpu] Improve XeGPU op verification logic for SIMT flavor and update tests. #127920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
f92c151
061ace2
bd2c8be
981c7d3
ddc8cba
2d79647
13edb33
0d42148
8796537
4b5cffb
ab59c46
6210d1f
be1e728
ff24db0
e1c8963
3772d92
37470a8
b877c2c
7834dc9
0426ea2
74dd97a
2137b2f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,8 +8,13 @@ | |
|
|
||
| #include "mlir/Dialect/XeGPU/IR/XeGPU.h" | ||
| #include "mlir/IR/Builders.h" | ||
| #include "mlir/IR/BuiltinTypes.h" | ||
| #include "mlir/IR/DialectImplementation.h" | ||
| #include "llvm/ADT/STLExtras.h" | ||
| #include "llvm/ADT/TypeSwitch.h" | ||
| #include "llvm/Support/Casting.h" | ||
| #include "llvm/Support/LogicalResult.h" | ||
| #include <cassert> | ||
|
|
||
adam-smnk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| namespace mlir { | ||
| namespace xegpu { | ||
|
|
@@ -239,6 +244,7 @@ LogicalResult TensorDescType::verify( | |
| llvm::ArrayRef<int64_t> shape, mlir::Type elementType, | ||
| mlir::Attribute encoding, mlir::Attribute sg_map) { | ||
| size_t rank = shape.size(); | ||
| unsigned packingFactor = 32 / elementType.getIntOrFloatBitWidth(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It may be worth to make
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point. I have added a comment on this. |
||
| if (rank != 1 && rank != 2) | ||
| return emitError() << "expected 1D or 2D tensor"; | ||
|
|
||
|
|
@@ -252,6 +258,16 @@ LogicalResult TensorDescType::verify( | |
| return emitError() << "expected non-contiguous elements for 1D tensor"; | ||
| if (rank == 2 && chunkSize < 2) | ||
| return emitError() << "expected chunk blocks for 2D tensor"; | ||
| // If chunk size > 1, the the second dimension of the tensor shape must be | ||
charithaintc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // equal to chunk size and it must be a multiple of the packing factor. | ||
| if (chunkSize > 1) { | ||
| if (shape.back() != chunkSize) | ||
| return emitError() << "expected tensor shape[1] to match chunk size"; | ||
| if (shape.back() % packingFactor != 0) | ||
| return emitError() | ||
| << "expected tensor shape[1] to be a multiple of packing factor " | ||
| << packingFactor; | ||
| } | ||
| } | ||
|
|
||
| if (auto blockAttr = | ||
|
|
@@ -276,14 +292,13 @@ LogicalResult TensorDescType::verify( | |
| if (scatterAttr) { | ||
| // Validate subgroup mapping rules for scattered tensors. | ||
| // A work-item's slice of the tensor with shape [sg_size] or | ||
| // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively, | ||
| // the mapping should reflect that. | ||
| // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width] | ||
| // respectively, the mapping should reflect that. This is because each | ||
| // work item access data in 32 bit granularity. | ||
| if (wiData[0] != 1) | ||
| return emitError() | ||
| << "cannot map over non-contiguous scattered row elements"; | ||
|
|
||
| unsigned chunkSize = scatterAttr.getChunkSize().getInt(); | ||
| if (wiData[1] != chunkSize) | ||
| if (wiData[1] != packingFactor) | ||
adam-smnk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return emitError() << "work item data mapping must match the number of " | ||
| "contiguous elements"; | ||
| } | ||
|
|
@@ -307,6 +322,84 @@ LogicalResult TensorDescType::verify( | |
| return success(); | ||
| } | ||
|
|
||
| // If tensor descriptor has a sg_map attribute it is used in SIMT mode. | ||
| // In this mode, the distributed vector shape is determined as follows: | ||
| // Definitions: | ||
| // wi_data_size = wi_data[0] × wi_data[1] | ||
| // subgroup_size = wi_layout[0] × wi_layout[1] | ||
| // distribution_unit_size = subgroup_size × wi_data_size | ||
| // --------------------------------------------------------------------- | ||
| // Case 1: Regular loads/stores. | ||
| // --------------------------------------------------------------------- | ||
| // Distributed vector shape must be: | ||
| // [chunk_size / wi_data_size, wi_data_size] | ||
| // If the tensor descriptor shape is 1D, first dimension is ignored (set to 1). | ||
| // [wi_data_size] | ||
| // --------------------------------------------------------------------- | ||
| // Case 2: Block loads/stores | ||
| // --------------------------------------------------------------------- | ||
| // Additional definitions: | ||
| // tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length | ||
| // n_distribution_units = tensor_size / distribution_unit_size | ||
| // Given above definitions, the following conditions must be met: | ||
| // * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0 | ||
| // * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0 | ||
| // Distributed vector shape must be: | ||
| // [n_distribution_units, wi_data_size] | ||
| FailureOr<VectorType> TensorDescType::getDistributedVectorType() { | ||
| auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap()); | ||
| // If no sg_map is provided, tensor desc is not used in SIMT mode. | ||
| if (!sgMap) | ||
| return failure(); | ||
|
|
||
| SmallVector<int64_t> wiData(sgMap.getWiData()); | ||
| SmallVector<int64_t> wiLayout(sgMap.getWiLayout()); | ||
| auto tdescShape = getShape(); | ||
|
|
||
| auto wiDataSize = 1, sgSize = 1; | ||
| for (auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) { | ||
| wiDataSize *= wiDataDim; | ||
| sgSize *= wiDim; | ||
| } | ||
|
|
||
| // Case 1: regular loads/stores | ||
| auto scatterAttr = getEncodingAsScatterTensorDescAttr(); | ||
| if (scatterAttr) { | ||
| auto chunkSize = scatterAttr.getChunkSize().getInt(); | ||
| // Verify if the first dimension of the tensor descriptor shape is | ||
| // distributable. | ||
| assert(tdescShape[0] % (wiLayout[0]) == 0 && | ||
| "tensor descriptor shape is not distributable"); | ||
| if (chunkSize > 1) | ||
| return VectorType::get({chunkSize / wiDataSize, wiDataSize}, | ||
| getElementType()); | ||
| return VectorType::get({wiDataSize}, getElementType()); | ||
| } | ||
|
|
||
| // Case 2: block loads/stores | ||
| // Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData | ||
| // and wiLayout must be 1. | ||
| if (tdescShape.size() == 1) { | ||
| assert((wiData[0] == 1 && wiLayout[0] == 1) && | ||
| "wi_data[0] and wi_layout[0] must be 1 for 1D tensor descriptor"); | ||
| wiData = {wiData[1]}; | ||
| wiLayout = {wiLayout[1]}; | ||
| } | ||
| // Check if the tensor descriptor shape is distributable. | ||
| int64_t tensorSize = 1; | ||
| for (auto [tdescDim, wiDim, wiDataDim] : | ||
| llvm::zip_equal(tdescShape, wiLayout, wiData)) { | ||
| assert((tdescDim % (wiDim * wiDataDim) == 0) && | ||
| "tensor descriptor shape is not distributable"); | ||
| tensorSize *= tdescDim; | ||
| } | ||
| // tensorSize must be adjusted for array_length. | ||
| tensorSize *= getArrayLength(); | ||
|
|
||
| return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize}, | ||
| getElementType()); | ||
| } | ||
|
|
||
| } // namespace xegpu | ||
| } // namespace mlir | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.