[MLIR][XeGPU] Switch to 1D representation for SIMT code #135116
Changes from 2 commits
@@ -10,6 +10,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <numeric>

 namespace mlir {
 namespace xegpu {

@@ -336,32 +337,30 @@ LogicalResult TensorDescType::verify(
 // [n_distribution_units, lane_data_size]
 FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
-  // If no layout is provided, tensor desc is not used in SIMT mode.
-  if (!layout)
+  // It only works for subgroup level layout, which only has lane_layout
+  // and lane_data, and is to distribute a SIMD code into SIMT code.
+  if (!layout || !layout.isSgLayout())
     return failure();

   SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
   SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
   auto tdescShape = getShape();

-  auto laneDataSize = 1, sgSize = 1;
-  for (auto [laneDim, laneDataDim] : llvm::zip_equal(laneLayout, laneData)) {
-    laneDataSize *= laneDataDim;
-    sgSize *= laneDim;
-  }
+  // compute sgSize by multiply elements of laneLayout
+  // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
+  // e.g. for 1D layout, sgSize = laneLayout[0]
+  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
+                                std::multiplies<int64_t>());

   // Case 1: regular loads/stores
   auto scatterAttr = getEncodingAsScatterTensorDescAttr();
   if (scatterAttr) {
     auto chunkSize = scatterAttr.getChunkSize().getInt();
     // Verify if the first dimension of the tensor descriptor shape is
     // distributable.
-    assert(tdescShape[0] % (laneLayout[0]) == 0 &&
|
Review comment: Not very clear why this change.

Author (Contributor): I think it is a small issue; I confirmed it with Charitha. tdescShape[0] has to equal laneLayout[0] so that one SIMD instruction is distributed into exactly one SIMT instruction. If the requirement were only tdescShape[0] % laneLayout[0] == 0, that would imply a SIMD instruction could be distributed into multiple SIMT instructions, which is really part of the blocking logic.
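To make that distinction concrete, here is a small standalone sketch (illustrative only; `distributedScatterSize` is a hypothetical helper and the shapes are assumed example values, not part of the patch) of the rule the new assert enforces for the scattered case: the first tensor-descriptor dimension must match `laneLayout[0]` exactly, and each lane then owns a 1D vector of `chunkSize` elements.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the scattered-descriptor branch after this
// change: one SIMD access maps to exactly one SIMT access per lane.
int64_t distributedScatterSize(int64_t tdescDim0, int64_t laneLayout0,
                               int64_t chunkSize) {
  // New rule: the shape must equal the lane layout, not merely be divisible
  // by it. A multiple (e.g. 32 rows on a 16-lane layout) would require
  // splitting one SIMD op into several SIMT ops, which is blocking's job.
  assert(tdescDim0 == laneLayout0 &&
         "tensor descriptor shape is not distributable");
  return chunkSize; // each lane gets a 1D vector of chunkSize elements
}

int main() {
  // e.g. tensor_desc shape [16], lane_layout [16], chunk_size 2
  // -> each lane receives a vector of 2 elements.
  assert(distributedScatterSize(16, 16, 2) == 2);
  return 0;
}
```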
+    assert(tdescShape[0] == laneLayout[0] &&
            "tensor descriptor shape is not distributable");
-    if (chunkSize > 1)
-      return VectorType::get({chunkSize / laneDataSize, laneDataSize},
-                             getElementType());
-    return VectorType::get({laneDataSize}, getElementType());
+    return VectorType::get({chunkSize}, getElementType());
   }

   // Case 2: block loads/stores

@@ -376,8 +375,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   // tensorSize must be adjusted for array_length.
   tensorSize *= getArrayLength();

-  return VectorType::get({tensorSize / (sgSize * laneDataSize), laneDataSize},
-                         getElementType());
+  return VectorType::get({tensorSize / sgSize}, getElementType());
 }

 } // namespace xegpu
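For reference, a minimal standalone sketch (assumed example shapes, not taken from the patch or its tests) of what the block load/store branch now returns: the per-lane result collapses to the flat 1D shape {tensorSize / sgSize}, whereas the previous representation was the 2D shape {tensorSize / (sgSize * laneDataSize), laneDataSize}.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // Assumed example: a tensor_desc of shape 8x16 with lane_layout = [1, 16].
  std::vector<int64_t> tdescShape = {8, 16};
  std::vector<int64_t> laneLayout = {1, 16};

  // Same reductions the patch performs with std::accumulate.
  int64_t tensorSize =
      std::accumulate(tdescShape.begin(), tdescShape.end(), int64_t{1},
                      std::multiplies<int64_t>());
  int64_t sgSize =
      std::accumulate(laneLayout.begin(), laneLayout.end(), int64_t{1},
                      std::multiplies<int64_t>());

  // Each of the 16 lanes receives 128 / 16 = 8 elements, i.e. the SIMT value
  // is the 1D vector<8xf16>; with lane_data = [1, 1] the old 2D form would
  // have been vector<8x1xf16>.
  assert(tensorSize / sgSize == 8);
  return 0;
}
```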