Skip to content

Commit 541ff05

Browse files
Merge commit '89c0b0abdfac05804b2cbcfb393c5efdb368b70b'
2 parents 7d0818a + 89c0b0a commit 541ff05

File tree

7 files changed

+116
-32
lines changed

7 files changed

+116
-32
lines changed

lib/Instrumentation/CMakeLists.txt

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,21 @@ foreach( plugin ${GPU_INSTRUMENTATION_PASSES} )
2222
LLVMTransformUtils
2323
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
2424
)
25-
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
26-
# build. It is empty if building directly from the root
27-
# CMakeLists.txt file. Therefore if not building from Python just
28-
# use the default CMake shared lib path otherwise this causes a hard
29-
# build error
30-
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
31-
set_target_properties(${plugin} PROPERTIES
32-
LIBRARY_OUTPUT_DIRECTORY
33-
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation")
34-
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
25+
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
26+
# build. It is empty if building directly from the root
27+
# CMakeLists.txt file. Therefore if not building from Python just
28+
# use the default CMake shared lib path otherwise this causes a hard
29+
# build error
30+
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
31+
set_target_properties(${plugin} PROPERTIES
32+
LIBRARY_OUTPUT_DIRECTORY
33+
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation")
34+
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
3535

36-
# This is set to -fvisibility=hidden in the top level CMake file
37-
# which causes the llvmGetPassPluginInfo symbol to be hidden and
38-
# an "entry point not found" error. Reset it just for this target
39-
if(NOT MSVC)
40-
target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti)
41-
endif()
36+
# This is set to -fvisibility=hidden in the top level CMake file
37+
# which causes the llvmGetPassPluginInfo symbol to be hidden and
38+
# an "entry point not found" error. Reset it just for this target
39+
if(NOT MSVC)
40+
target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti)
41+
endif()
4242
endforeach()

python/triton/language/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,8 +1647,10 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
16471647
def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None):
16481648
"""
16491649
Returns the matrix product of two blocks in microscaling format.
1650+
16501651
lhs and rhs use microscaling formats described here:
16511652
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
1653+
16521654
:param lhs: The first tensor to be multiplied.
16531655
:type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
16541656
:param lhs_scale: Scale factor for lhs tensor.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s
2+
3+
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
4+
tt.func @conditional_barrier() {
5+
// CHECK-LABEL: llvm.func @conditional_barrier
6+
7+
// CHECK: %[[CMP0:.+]] = llvm.icmp "ne" %3, %1 : i32
8+
// CHECK: %[[CMP1:.+]] = llvm.icmp "eq" %3, %1 : i32
9+
// CHECK: llvm.cond_br %[[CMP0]], ^bb1, ^bb2
10+
// CHECK: ^bb1:
11+
// CHECK: rocdl.s.barrier
12+
// CHECK: llvm.br ^bb2
13+
// CHECK: ^bb2:
14+
// CHECK: llvm.add
15+
// CHECK: llvm.cond_br %[[CMP1]], ^bb3, ^bb4
16+
// CHECK: ^bb3:
17+
// CHECK: rocdl.s.barrier
18+
// CHECK: llvm.br ^bb4
19+
// CHECK: ^bb4:
20+
// CHECK: llvm.return
21+
22+
%c256_i32 = arith.constant 256 : i32
23+
%c0_i32 = arith.constant 0 : i32
24+
%0 = rocdl.workitem.id.x : i32
25+
%1 = arith.divsi %0, %c256_i32 : i32
26+
%2 = arith.cmpi ne, %1, %c0_i32 : i32
27+
%3 = arith.cmpi eq, %1, %c0_i32 : i32
28+
amdgpu.cond_barrier %2
29+
%4 = arith.addi %0, %c256_i32 : i32
30+
amdgpu.cond_barrier %3
31+
tt.return
32+
}
33+
}

test/lib/Instrumentation/CMakeLists.txt

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,21 @@ foreach( plugin ${GPU_INSTRUMENTATION_PASSES} )
2020
LLVMCore
2121
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
2222
)
23-
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
24-
# build. It is empty if building directly from the root
25-
# CMakeLists.txt file. Therefore if not building from Python just
26-
# use the default CMake shared lib path otherwise this causes a hard
27-
# build error
28-
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
29-
set_target_properties(${plugin} PROPERTIES
30-
LIBRARY_OUTPUT_DIRECTORY
31-
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation")
32-
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
23+
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
24+
# build. It is empty if building directly from the root
25+
# CMakeLists.txt file. Therefore if not building from Python just
26+
# use the default CMake shared lib path otherwise this causes a hard
27+
# build error
28+
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
29+
set_target_properties(${plugin} PROPERTIES
30+
LIBRARY_OUTPUT_DIRECTORY
31+
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation")
32+
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
3333

34-
# This is set to -fvisibility=hidden in the top level CMake file
35-
# which causes the llvmGetPassPluginInfo symbol to be hidden and
36-
# an "entry point not found" error. Reset it just for this target
37-
if(NOT MSVC)
38-
target_compile_options(${plugin} PRIVATE -fvisibility=default)
39-
endif()
34+
# This is set to -fvisibility=hidden in the top level CMake file
35+
# which causes the llvmGetPassPluginInfo symbol to be hidden and
36+
# an "entry point not found" error. Reset it just for this target
37+
if(NOT MSVC)
38+
target_compile_options(${plugin} PRIVATE -fvisibility=default)
39+
endif()
4040
endforeach()

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,23 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
152152
let assemblyFormat = [{ attr-dict }];
153153
}
154154

155+
def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
156+
Arguments<(ins I1:$pred)> {
157+
let summary = "Conditionally set barriers to synchronize partial threads in a block";
158+
159+
let description = [{
160+
CondBarrierOp emits a barrier instruction only when the given predicate is true.
161+
This provides a way to synchronize a subset of the threads in a block, deliberately
162+
diverging the execution sequences. However, the user should guarantee all threads
163+
converge at the end by calling condBarrierOp(true) with the remaining threads.
164+
Conceptually, this is similar to having an execution barrier inside an if statement.
165+
This op allows us to avoid blocking the whole block where doing so helps scheduling.
166+
NB: This op does not insert any memory fence.
167+
}];
168+
169+
let assemblyFormat = "$pred attr-dict";
170+
}
171+
155172
//
156173
// AMD Buffer operations.
157174
//

third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
12
#include "PatternTritonGPUOpToLLVM.h"
23
#include "Utility.h"
4+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
35

46
using namespace mlir;
57

@@ -25,10 +27,37 @@ struct GetNumProgramsOpConversion
2527
}
2628
};
2729

30+
struct CondBarrierOpConversion
31+
: public ConvertOpToLLVMPattern<triton::amdgpu::CondBarrierOp> {
32+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
33+
34+
LogicalResult
35+
matchAndRewrite(triton::amdgpu::CondBarrierOp op, OpAdaptor adaptor,
36+
ConversionPatternRewriter &rewriter) const override {
37+
Location loc = op->getLoc();
38+
Block *currentBlock = rewriter.getInsertionBlock();
39+
Block *afterCondBarBlock =
40+
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
41+
Block *trueBlock = rewriter.createBlock(afterCondBarBlock);
42+
rewriter.setInsertionPointToEnd(currentBlock);
43+
rewriter.create<LLVM::CondBrOp>(loc, adaptor.getPred(), trueBlock,
44+
afterCondBarBlock);
45+
46+
// conditional barrier
47+
rewriter.setInsertionPointToStart(trueBlock);
48+
rewriter.create<ROCDL::SBarrierOp>(loc);
49+
rewriter.create<LLVM::BrOp>(loc, afterCondBarBlock);
50+
51+
rewriter.eraseOp(op);
52+
return success();
53+
}
54+
};
55+
2856
} // namespace
2957

3058
void mlir::triton::AMD::populateSPMDOpToLLVMPattern(
3159
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
3260
PatternBenefit benefit) {
3361
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
62+
patterns.add<CondBarrierOpConversion>(typeConverter, benefit);
3463
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,9 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
516516

517517
const auto &mmaInstructions =
518518
isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere;
519+
if (mmaInstructions.find(mmaType) == mmaInstructions.end()) {
520+
return emitError(loc, "Unsupported MMA instruction for the given mma type");
521+
}
519522
auto rank = dTensorTy.getRank();
520523
auto elemsPerThread = triton::gpu::getElemsPerThread(dTensorTy);
521524
auto batchOffset =

0 commit comments

Comments
 (0)