-
Notifications
You must be signed in to change notification settings - Fork 76
[TritonIntelGPUToLLVM] Detect more sub-group shuffle convert_layout
#2573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
e580308
fccfb5a
a7277b1
b41ea82
c749f19
9f6975c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -446,6 +446,62 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion | |
| : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { | ||
| } | ||
|
|
||
// Build the register bases expected for a sub-group transpose conversion.
//
// The result is a list of two-component bases of the form:
//   [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2],
//    [laneSize, 0], ..., [registerSize / 2, 0]]
// The first group walks the powers of two below `laneSize` in the second
// component (registers mapped to lanes); the second group walks the powers of
// two multiples of `laneSize` below `registerSize` in the first component,
// each keeping its own value.
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // Entries [0, 2**k] for every 2**k < laneSize.
  for (int32_t bit = 1; bit < laneSize; bit *= 2)
    result.push_back({0, bit});
  // Entries [laneSize * 2**k, 0] while the value stays below registerSize.
  for (int32_t pos = laneSize; pos < registerSize; pos *= 2)
    result.push_back({pos, 0});
  return result;
}
|
|
||
// Build the register bases expected for a sub-group shuffle conversion.
//
// The result is a list of two-component bases of the form:
//   [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2],
//    [1, 0], [2, 0], ..., [registerSize / (2 * laneSize), 0]]
// The first group walks the powers of two below `laneSize` in the second
// component (registers mapped to lanes); the second group restarts counting
// from 1 in the first component, adding one entry per doubling of `laneSize`
// that stays below `registerSize` (the pattern repeats).
static std::vector<std::vector<int32_t>>
buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // Entries [0, 2**k] for every 2**k < laneSize.
  for (int32_t bit = 1; bit < laneSize; bit *= 2)
    result.push_back({0, bit});
  // One entry [2**k, 0] for each doubling of laneSize below registerSize.
  int32_t repeat = 1;
  for (int32_t pos = laneSize; pos < registerSize; pos *= 2) {
    result.push_back({repeat, 0});
    repeat *= 2;
  }
  return result;
}
|
|
||
// Build the lane bases expected for a sub-group transpose conversion.
//
// The result is a list of two-component bases of the form:
//   [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]]
// i.e., one entry per power of two below `laneSize`, stored in the first
// component (lanes mapped to registers).
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeLaneBases(int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  for (int32_t bit = 1; bit < laneSize; bit *= 2)
    result.push_back({bit, 0});
  return result;
}
|
|
||
| bool isSubGroupTranspose(const LinearLayout &srcLayout, | ||
| const LinearLayout &dstLayout) const { | ||
| MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext(); | ||
|
|
@@ -476,35 +532,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion | |
| // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))] | ||
| // | ||
| // With N >= M. | ||
| const auto buildBasis = [&](int32_t size, std::size_t index) { | ||
| std::vector<std::vector<int32_t>> basis; | ||
| std::vector<int32_t> curr(2); | ||
| for (int32_t i = 1; i < size; i *= 2) { | ||
| curr[index] = i; | ||
| basis.push_back(curr); | ||
| } | ||
| return basis; | ||
| }; | ||
| constexpr std::size_t laneIndex = 0; | ||
| constexpr std::size_t registerIndex = 1; | ||
| int32_t laneSize = conversion->getInDimSize(kLane); | ||
| std::vector<std::vector<int32_t>> registerBases = | ||
| buildBasis(laneSize, registerIndex); | ||
| { | ||
| // Populate register bases for N > M. | ||
| std::vector<int32_t> base(2); | ||
| for (int32_t i = laneSize, | ||
| registerSize = conversion->getInDimSize(kRegister); | ||
| i < registerSize; i *= 2) { | ||
| base[laneIndex] = i; | ||
| registerBases.push_back(base); | ||
| } | ||
| } | ||
| std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2> | ||
| bases{{{kRegister, std::move(registerBases)}, | ||
| {kLane, buildBasis(laneSize, laneIndex)}}}; | ||
| std::array<StringAttr, 2> outDimNames{kRegister, kLane}; | ||
| return conversion == LinearLayout(bases, outDimNames); | ||
| int32_t registerInDimSize = conversion->getInDimSize(kRegister); | ||
| int32_t laneInDimSize = conversion->getInDimSize(kLane); | ||
| return conversion->getBases().lookup(kRegister) == | ||
| buildSubGroupTransposeRegisterBases(registerInDimSize, | ||
| laneInDimSize) && | ||
| conversion->getBases().lookup(kLane) == | ||
| buildSubGroupTransposeLaneBases(laneInDimSize); | ||
|
Comment on lines
+535
to
+541
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Refactored this check and protected it against some illegal layout creation. |
||
| } | ||
|
|
||
| LogicalResult | ||
|
|
@@ -619,32 +653,27 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion | |
| // Expected conversion is: | ||
| // - register=1 -> (0, 1) | ||
| // ... | ||
| // register=i -> (0, i) | ||
| // - register=2**i -> (0, 2**i) | ||
| // ... | ||
| // - register=M -> (0, 2**M) | ||
| // ... | ||
| // - register=2**k -> (2**(k-M), 0) | ||
| // ... | ||
| // register=N -> (0, N) | ||
| // - register=2**N -> (2**(N-M), 0) | ||
| // - lane=1 -> (0, 0) | ||
| // ... | ||
| // lane=i -> (0, 0) | ||
| // - lane=2**j -> (0, 0) | ||
| // ... | ||
| // lane=N -> (0, 0) | ||
| // where out dims are: [register (size 1), lane (size N)] | ||
| std::vector<std::vector<int32_t>> registerBases; | ||
| { | ||
| constexpr std::size_t registerIndex = 1; | ||
| std::vector<int32_t> base(2); | ||
| for (int32_t i = 1, n = conversion->getInDimSize(kLane); i < n; i *= 2) { | ||
| base[registerIndex] = i; | ||
| registerBases.push_back(base); | ||
| } | ||
| } | ||
|
|
||
| std::vector<std::vector<int32_t>> laneBases( | ||
| conversion->getInDimSizeLog2(kLane), std::vector<int32_t>{0, 0}); | ||
| std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2> | ||
| bases{{{kRegister, std::move(registerBases)}, | ||
| {kLane, std::move(laneBases)}}}; | ||
| std::array<StringAttr, 2> outDimNames{kRegister, kLane}; | ||
| return conversion == LinearLayout(bases, outDimNames); | ||
| // lane=2**M -> (0, 0) | ||
| // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))] | ||
| // | ||
| // With N >= M. | ||
| int32_t registerInDimSize = conversion->getInDimSize(kRegister); | ||
| int32_t laneOutDimSize = conversion->getOutDimSize(kLane); | ||
| return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) && | ||
| conversion->getBases().lookup(kRegister) == | ||
| buildSubGroupShuffleRegisterBases(registerInDimSize, | ||
| laneOutDimSize); | ||
|
Comment on lines
+674
to
+676
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I will investigate whether there is a better way to express this in the future when generalizing. The same applies to the transpose case. |
||
| } | ||
|
|
||
| bool isSupportedSubGroupShuffle(ConvertLayoutOp, OpAdaptor) const { | ||
|
|
@@ -674,7 +703,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion | |
|
|
||
| SmallVector<Value> inVals = | ||
| unpackLLElements(loc, adaptor.getSrc(), rewriter); | ||
| assert(inVals.size() == 1 && "Expecting single element"); | ||
|
|
||
| // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR | ||
| // upstream level. We are not enabling support for all types here as that | ||
|
|
@@ -703,7 +731,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion | |
| }); | ||
|
|
||
| SmallVector<Value> outVals = | ||
| performSubGroupShuffle(loc, inVals.front(), subGroupSize, rewriter); | ||
| performSubGroupShuffle(loc, inVals, subGroupSize, rewriter); | ||
|
|
||
| // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR | ||
| // upstream level. We are not enabling support for all types here as that | ||
|
|
@@ -734,16 +762,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion | |
| } | ||
|
|
||
| SmallVector<Value> | ||
| performSubGroupShuffle(Location loc, Value val, int32_t subGroupSize, | ||
| performSubGroupShuffle(Location loc, ArrayRef<Value> inVals, | ||
| int32_t subGroupSize, | ||
| ConversionPatternRewriter &rewriter) const { | ||
| SmallVector<Value> res; | ||
| Value width = i32_val(subGroupSize); | ||
| for (int32_t i = 0; i < subGroupSize; ++i) | ||
| res.push_back( | ||
| rewriter | ||
| .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width, | ||
| mlir::gpu::ShuffleMode::IDX) | ||
| .getShuffleResult()); | ||
| for (Value val : inVals) { | ||
| for (int32_t i = 0; i < subGroupSize; ++i) | ||
| res.push_back( | ||
| rewriter | ||
| .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width, | ||
| mlir::gpu::ShuffleMode::IDX) | ||
| .getShuffleResult()); | ||
| } | ||
| return res; | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.