
Commit 79410a5

Fix rebase issues
1 parent 53a73a7 commit 79410a5

File tree

- test/Conversion/intel/sub-group-transpose.mlir
- third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

2 files changed: +40 −35 lines

test/Conversion/intel/sub-group-transpose.mlir

Lines changed: 3 additions & 3 deletions
```diff
@@ -305,7 +305,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: llvm.func spir_kernelcc @test(
   // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
   tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
@@ -348,7 +348,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: llvm.func spir_kernelcc @test(
   // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
   tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
@@ -391,7 +391,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: llvm.func spir_kernelcc @test(
   // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
   tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> {
```
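The two encodings in each test swap `sizePerThread` and `threadsPerWarp`: under `#blocked` a lane owns a column of a 16x16 tile and its registers walk the rows, while under `#blocked1` a lane owns a row and its registers walk the columns, i.e. a sub-group transpose. As a sanity check of that reading, here is a small standalone sketch (plain C++; the coordinate formulas are my interpretation of the blocked parameters for the first test's 32x16 tensor, not code from this commit):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch (not part of the commit): where element (row, col) of
// the 32x16 tensor lives under each encoding, per the reading above.
struct Pos {
  int32_t reg, lane;
};

// #blocked: sizePerThread = [16, 1], threadsPerWarp = [1, 16]. Lane l owns
// column l; registers walk the 32 rows (two 16-row tiles).
static Pos src(int32_t row, int32_t col) { return {row, col}; }

// #blocked1: sizePerThread = [1, 16], threadsPerWarp = [16, 1]. Lane l owns
// row l of a 16-row tile; registers walk the 16 columns, with one extra high
// register bit selecting the tile.
static Pos dst(int32_t row, int32_t col) {
  return {col + 16 * (row / 16), row % 16};
}

int main() {
  for (int32_t row = 0; row < 32; ++row)
    for (int32_t col = 0; col < 16; ++col) {
      Pos s = src(row, col);
      Pos d = dst(row, col);
      // The conversion swaps the low register bits with the lane index and
      // keeps the high register bit in the register dimension.
      assert(d.lane == (s.reg & 15));
      assert(d.reg == s.lane + (s.reg & 16));
    }
}
```

The low four register bits and the lane index swap, while the high register bit (the second 16-row tile) stays in the register dimension; this is exactly the basis pattern the C++ change below checks for.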

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 37 additions & 32 deletions
```diff
@@ -461,30 +461,49 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     assert(conversion && "Expecting valid conversion");
     // Expected conversion is:
     // - register=1 -> (0, 1)
-    //   register=2 -> (0, 2)
-    //   register=4 -> (0, 4)
-    //   register=8 -> (0, 8)
-    //   register=N -> (N, 0)
-    //   ...
-    // - lane=1 -> (1, 0)
-    //   lane=2 -> (2, 0)
-    //   lane=4 -> (4, 0)
-    //   lane=8 -> (8, 0)
-    // where out dims are: [register (size 2*N), lane (size 16)]
-    std::vector<std::vector<int32_t>> registerBases{
-        {0, 1}, {0, 2}, {0, 4}, {0, 8}};
+    //   ...
+    //   register=2**i -> (0, 2**i)
+    //   ...
+    //   register=2**M -> (0, 2**M)
+    //   ...
+    //   register=2**k -> (2**k, 0)
+    //   ...
+    //   register=2**N -> (2**N, 0)
+    // - lane=1 -> (1, 0)
+    //   ...
+    //   lane=2**j -> (2**j, 0)
+    //   ...
+    //   lane=2**M -> (2**M, 0)
+    // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
+    //
+    // With N >= M.
+    const auto buildBasis = [&](int32_t size, std::size_t index) {
+      std::vector<std::vector<int32_t>> basis;
+      std::vector<int32_t> curr(2);
+      for (int32_t i = 1; i < size; i *= 2) {
+        curr[index] = i;
+        basis.push_back(curr);
+      }
+      return basis;
+    };
+    constexpr std::size_t laneIndex = 0;
+    constexpr std::size_t registerIndex = 1;
+    int32_t laneSize = conversion->getInDimSize(kLane);
+    std::vector<std::vector<int32_t>> registerBases =
+        buildBasis(laneSize, registerIndex);
     {
-      // Populate register bases for N > 8.
+      // Populate register bases for N > M.
       std::vector<int32_t> base(2);
-      for (int32_t i = 16, n = conversion->getInDimSize(kRegister); i < n;
-           i *= 2) {
-        base.front() = i;
+      for (int32_t i = laneSize,
+                   registerSize = conversion->getInDimSize(kRegister);
+           i < registerSize; i *= 2) {
+        base[laneIndex] = i;
         registerBases.push_back(base);
       }
     }
     std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
         bases{{{kRegister, std::move(registerBases)},
-               {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}}};
+               {kLane, buildBasis(laneSize, laneIndex)}}};
     std::array<StringAttr, 2> outDimNames{kRegister, kLane};
     return conversion == LinearLayout(bases, outDimNames);
   }
```
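Read outside the diff markers, the new check derives both basis sets from the sub-group size instead of hard-coding them for 16 lanes. A standalone sketch of what `buildBasis` produces, assuming `laneSize = 16` (so M = 3) and `registerSize = 32` (so N = 4), the sizes exercised by the first test above; the real code queries both sizes from the `LinearLayout`:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Same construction as the hunk above, with assumed sizes substituted for
  // the values the real code reads via getInDimSize.
  const auto buildBasis = [](int32_t size, std::size_t index) {
    std::vector<std::vector<int32_t>> basis;
    std::vector<int32_t> curr(2);
    for (int32_t i = 1; i < size; i *= 2) {
      curr[index] = i; // the other component stays 0
      basis.push_back(curr);
    }
    return basis;
  };
  constexpr std::size_t laneIndex = 0;     // out component written by lane bases
  constexpr std::size_t registerIndex = 1; // out component written by register bases
  int32_t laneSize = 16;     // 2**(M + 1), i.e. threads per warp
  int32_t registerSize = 32; // 2**(N + 1)

  // Register bases: (0, 1), (0, 2), (0, 4), (0, 8), then (16, 0) for N > M.
  std::vector<std::vector<int32_t>> registerBases =
      buildBasis(laneSize, registerIndex);
  std::vector<int32_t> base(2);
  for (int32_t i = laneSize; i < registerSize; i *= 2) {
    base[laneIndex] = i;
    registerBases.push_back(base);
  }
  // Lane bases: (1, 0), (2, 0), (4, 0), (8, 0).
  std::vector<std::vector<int32_t>> laneBases = buildBasis(laneSize, laneIndex);

  for (const auto &b : registerBases)
    std::printf("register basis -> (%d, %d)\n", b[0], b[1]);
  for (const auto &b : laneBases)
    std::printf("lane basis     -> (%d, %d)\n", b[0], b[1]);
}
```

This prints exactly the vectors the old code hard-coded, but the construction now holds for any power-of-two sub-group size.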
````diff
@@ -853,18 +872,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
           llvm::transform(
               outVals, std::begin(outVals),
               [&](Value val) -> Value { return inttoptr(ptrTy, val); });
-        })As a follow up to #2266, extend work in #2531 to detect more complex broadcast shuffles.
-
-Cases with more than 1 warp in the "sliced" dimension are problematic here, e.g.:
-
-```mlir
-#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 16, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 16, 1, 1], order = [3, 4, 0, 1, 2]}>
-// ...
-triton_gpu.convert_layout %arg : tensor<16x1x16x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #blocked}>}>> -> tensor<16x1x16x16x1xf32, #blocked1>
-```
-
-Is lowered to a shufle via
+        })
         .Default([](auto) { llvm_unreachable("Unsupported type"); });
 
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
````
```diff
@@ -967,9 +975,6 @@ Is lowered to a shufle via
     // TODO(jlebar): Implement me.
     return failure();
   }
-
-private:
-  const triton::intel::TargetInfo &targetInfo;
 };
 
 } // namespace
```
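For reference, a `LinearLayout` is linear over GF(2): each set bit of an input index XORs in that bit's basis vector, and the contributions of the register and lane inputs combine by XOR as well. A minimal sketch evaluating the transpose layout above at one point (hypothetical helper; same assumed sizes as before, laneSize = 16 and registerSize = 32):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

using Basis = std::array<int32_t, 2>; // (register out, lane out)

// XOR together the basis vectors of every set bit of `in`.
static Basis eval(const std::vector<Basis> &bases, int32_t in) {
  Basis out{0, 0};
  for (std::size_t b = 0; b < bases.size(); ++b)
    if (in & (int32_t(1) << b)) {
      out[0] ^= bases[b][0];
      out[1] ^= bases[b][1];
    }
  return out;
}

int main() {
  // The transpose bases built above for laneSize = 16, registerSize = 32.
  std::vector<Basis> registerBases{{0, 1}, {0, 2}, {0, 4}, {0, 8}, {16, 0}};
  std::vector<Basis> laneBases{{1, 0}, {2, 0}, {4, 0}, {8, 0}};

  // register = 21 = 16 + 4 + 1, lane = 3; contributions combine by XOR.
  Basis r = eval(registerBases, 21);
  Basis l = eval(laneBases, 3);
  std::printf("(register=21, lane=3) -> (register=%d, lane=%d)\n",
              r[0] ^ l[0], r[1] ^ l[1]);
}
```

The element held in register 21 of lane 3 moves to register 19 of lane 5: the low register bits and the lane index have swapped, while the high register bit (16) stays in the register dimension.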
