Skip to content

Commit a86e5a0

Browse files
Merge commit 'a19f32454271ff9565ab957834bdf1e5d4ddce57'
2 parents 64b232e + a19f324 commit a86e5a0

File tree

5 files changed

+40
-12
lines changed

5 files changed

+40
-12
lines changed

bin/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,6 @@ target_link_libraries(triton-reduce PRIVATE
5555
mlir_check_all_link_libraries(triton-reduce)
5656

5757
add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED)
58-
mlir_check_all_link_libraries(triton-lsp)
5958

6059
llvm_update_compile_flags(triton-lsp)
6160
target_link_libraries(triton-lsp PRIVATE

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -710,6 +710,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
710710
//
711711
def TT_ReduceOp: TT_Op<"reduce",
712712
[Pure,
713+
SameOperandsShape,
713714
SameOperandsEncoding,
714715
SingleBlock,
715716
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -538,17 +538,14 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
538538
if (loadOpToIndLevelAndUse.empty())
539539
return {};
540540

541-
for (auto iter = loadOpToIndLevelAndUse.begin();
542-
iter != loadOpToIndLevelAndUse.end();) {
543-
auto iterNext = iter + 1;
544-
if (std::get<1>(*iter) >= numStages - 1)
545-
// We assume loads with different dist are assigned to different stages.
546-
// If numStages is 2, we will have no stage available for indirect loads
547-
// with dist >= 1. In general, when dist is equal to numStages - 1, we
548-
// should not pipeline it.
549-
loadOpToIndLevelAndUse.erase(iter);
550-
iter = iterNext;
551-
}
541+
// We assume loads with different dist are assigned to different stages.
542+
// If numStages is 2, we will have no stage available for indirect loads
543+
// with dist >= 1. In general, when dist is equal to numStages - 1, we
544+
// should not pipeline it.
545+
auto it = llvm::remove_if(loadOpToIndLevelAndUse, [=](auto op) {
546+
return std::get<1>(op) >= numStages - 1;
547+
});
548+
loadOpToIndLevelAndUse.erase(it, loadOpToIndLevelAndUse.end());
552549

553550
// Check which loads are good for pipelining, and assign them
554551
// memory layouts.

lib/Tools/LinearLayout.cpp

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,24 @@
1616
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
1717
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
1818

19+
#if defined(_MSC_VER) && !defined(__clang__)
// MSVC does not provide GCC/Clang's __builtin_ctz* intrinsics; emulate them
// with the BitScanForward family so the count-trailing-zeros call sites below
// compile unchanged.
// Adapted from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0
#include <intrin.h>

// Count trailing zero bits of `x`.
// Precondition: x != 0. _BitScanForward leaves `r` unset when x == 0, which
// matches the GCC builtin, whose result is likewise undefined for zero input.
static int __builtin_ctz(unsigned x) {
  unsigned long r;
  _BitScanForward(&r, x);
  return static_cast<int>(r);
}

// 64-bit variant; same x != 0 precondition as above.
// NOTE(review): _BitScanForward64 exists only on 64-bit MSVC targets
// (x64/ARM64) — confirm 32-bit MSVC builds are not supported.
static int __builtin_ctzll(unsigned long long x) {
  unsigned long r;
  _BitScanForward64(&r, x);
  return static_cast<int>(r);
}

#endif
36+
1937
namespace mlir::triton {
2038

2139
namespace {

test/Triton/invalid.mlir

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,19 @@ tt.func public @fn(%v: tensor<4x128xf64>) {
108108

109109
// -----
110110

111+
// Negative test: tt.reduce now carries SameOperandsShape, so operands whose
// shapes differ (32x32x64 vs 16x32x64) must be rejected by the verifier.
tt.func @reduce_different_input_shapes(%arg0: tensor<32x32x64xf32>, %arg1: tensor<16x32x64xf32>) -> (tensor<32x64xf32>, tensor<16x64xf32>) {
  // expected-error @below {{op requires the same shape for all operands}}
  %0:2 = "tt.reduce" (%arg0, %arg1) <{axis = 1 : i32}> ({
  ^bb0(%acc0: f32, %acc1: f32, %cur0: f32, %cur1: f32):
    %1 = arith.addf %acc0, %cur0 : f32
    %2 = arith.addf %acc1, %cur1 : f32
    tt.reduce.return %1, %2 : f32, f32
  }) : (tensor<32x32x64xf32>, tensor<16x32x64xf32>) -> (tensor<32x64xf32>, tensor<16x64xf32>)
  tt.return %0#0, %0#1 : tensor<32x64xf32>, tensor<16x64xf32>
}
121+
122+
// -----
123+
111124
tt.func public @fn(%v: tensor<4x128xf32>) {
112125
// expected-error @+1 {{requires the same shape}}
113126
%a = "tt.scan" (%v) ({

0 commit comments

Comments (0)