Skip to content

Commit b184e9c

Browse files
authored
Merge branch 'main' into triton-gen-sub-group-block-memaccess
2 parents a0ea9f6 + 413e738 commit b184e9c

File tree

46 files changed

+831
-787
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+831
-787
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
name: Build on Windows
2+
3+
on:
4+
workflow_dispatch:
5+
6+
pull_request:
7+
branches:
8+
- main
9+
push:
10+
branches:
11+
- main
12+
13+
permissions: read-all
14+
15+
env:
16+
NEW_WORKSPACE: C:\gh${{ github.run_id }}
17+
18+
jobs:
19+
build:
20+
name: Build
21+
runs-on: avc336
22+
steps:
23+
- name: Enable long paths
24+
run: |
25+
git config --system core.longPaths true
26+
27+
- name: Checkout repository
28+
uses: actions/checkout@v4
29+
30+
- name: Install Python
31+
uses: actions/setup-python@v5
32+
with:
33+
python-version: '3.9'
34+
35+
# Copy workspace to a temporary location with a shorter name.
36+
- name: Copy workspace
37+
run: |
38+
Copy-Item -Path ${{ github.workspace }} -Destination ${{ env.NEW_WORKSPACE }} -Recurse
39+
40+
# We need ninja >= 1.12.0 to support long names on Windows. At the moment there is no required
41+
# version in pypi, so instead of installing ninja with pip we use a preinstalled 1.12.1 on the
42+
# runner.
43+
- name: Build Triton
44+
run: |
45+
cd ${{ env.NEW_WORKSPACE }}
46+
cd python
47+
pip install -U wheel pybind11 certifi cython cmake
48+
python -m certifi
49+
pip install --no-build-isolation '.[build]'
50+
51+
- name: Clean
52+
if: ${{ always() }}
53+
run: |
54+
Remove-Item -LiteralPath ${{ env.NEW_WORKSPACE }} -Force -Recurse -ErrorAction Ignore

CMakeLists.txt

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,24 @@ endif()
88

99
include(ExternalProject)
1010

11-
set(CMAKE_CXX_STANDARD 17)
12-
1311
set(CMAKE_INCLUDE_CURRENT_DIR ON)
1412

1513
project(triton CXX)
1614
include(CTest)
1715

18-
if(NOT WIN32)
19-
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
20-
endif()
21-
22-
16+
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
2317

2418
# Options
19+
if(WIN32)
20+
set(DEFAULT_BUILD_PROTON OFF)
21+
else()
22+
set(DEFAULT_BUILD_PROTON ON)
23+
endif()
24+
25+
# Define the option with the determined default value
26+
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ${DEFAULT_BUILD_PROTON})
2527
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
2628
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
27-
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
2829
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
2930
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
3031
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
@@ -49,10 +50,21 @@ endif()
4950
# used conditionally in this file and by lit tests
5051

5152
# Customized release build type with assertions: TritonRelBuildWithAsserts
52-
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
53-
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
54-
set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1")
55-
set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1")
53+
if(NOT MSVC)
54+
set(CMAKE_CXX_STANDARD 17)
55+
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
56+
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
57+
set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1")
58+
set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1")
59+
else()
60+
set(CMAKE_CXX_STANDARD 20)
61+
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
62+
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
63+
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
64+
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
65+
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
66+
set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
67+
endif()
5668

5769
# Default build type
5870
if(NOT CMAKE_BUILD_TYPE)
@@ -70,7 +82,15 @@ endif()
7082

7183
# Compiler flags
7284
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
73-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
85+
if(NOT MSVC)
86+
if(NOT WIN32)
87+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC")
88+
else()
89+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -Wno-deprecated")
90+
endif()
91+
else()
92+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
93+
endif()
7494

7595

7696
# #########
@@ -124,7 +144,11 @@ endfunction()
124144

125145

126146
# Disable warnings that show up in external code (gtest;pybind11)
127-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
147+
if(NOT MSVC)
148+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
149+
else()
150+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-")
151+
endif()
128152

129153
include_directories(".")
130154
include_directories(${MLIR_INCLUDE_DIRS})
@@ -134,7 +158,8 @@ include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
134158
include_directories(${PROJECT_SOURCE_DIR}/third_party)
135159
include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files
136160

137-
# link_directories(${LLVM_LIBRARY_DIR})
161+
link_directories(${LLVM_LIBRARY_DIR})
162+
138163
add_subdirectory(include)
139164
add_subdirectory(lib)
140165

@@ -163,6 +188,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
163188
# using pip install.
164189
include_directories(${PYTHON_INCLUDE_DIRS})
165190
include_directories(${PYBIND11_INCLUDE_DIR})
191+
message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
192+
link_directories(${PYTHON_LIB_DIRS})
166193
else()
167194
# Otherwise, we might be building from top CMakeLists.txt directly.
168195
# Try to find Python and pybind11 packages.
@@ -245,7 +272,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
245272
LLVMAArch64CodeGen
246273
LLVMAArch64AsmParser
247274
)
248-
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
275+
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64")
249276
list(APPEND TRITON_LIBRARIES
250277
LLVMX86CodeGen
251278
LLVMX86AsmParser
@@ -280,6 +307,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
280307
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
281308
if(WIN32)
282309
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
310+
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
311+
set_target_properties(triton PROPERTIES PREFIX "lib")
283312
else()
284313
target_link_libraries(triton PRIVATE z)
285314
endif()
@@ -306,6 +335,10 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
306335
add_subdirectory(third_party/${CODEGEN_BACKEND})
307336
endforeach()
308337
endif()
338+
if(WIN32)
339+
option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
340+
option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
341+
endif()
309342

310343
add_subdirectory(third_party/f2reduce)
311344
add_subdirectory(bin)

include/triton/Analysis/Allocation.h

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ namespace mlir {
1818
namespace triton {
1919
class AllocationAnalysis;
2020

21+
/// Callback to allow backends to specify target-specific scratch sizes for
22+
/// some operations.
23+
using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
24+
25+
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
26+
2127
// To convert a tensor from one layout to another, we need to allocate a
2228
// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
2329
// require multiple iterations, with each iteration involving multiple
@@ -141,7 +147,8 @@ class Allocation {
141147
explicit Allocation(Operation *operation) : operation(operation) {}
142148

143149
/// Runs allocation analysis on the given top-level operation.
144-
template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);
150+
void run(FuncAllocMapT &funcAllocMap,
151+
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);
145152

146153
/// Returns the operation this analysis was constructed from.
147154
Operation *getOperation() const { return operation; }
@@ -242,9 +249,6 @@ class Allocation {
242249
size_t sharedMemorySize = 0;
243250
};
244251

245-
template <>
246-
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);
247-
248252
/// Static analysis that computes the allocation of shared memory buffers
249253
/// of the entire call graph.
250254
/// The allocation is performed in a post-order walk of the call graph.
@@ -255,19 +259,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
255259
public:
256260
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
257261

258-
template <typename AllocationAnalysis = triton::AllocationAnalysis>
259-
static ModuleAllocation get(ModuleOp moduleOp) {
260-
ModuleAllocation res(moduleOp);
261-
res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
262+
ModuleAllocation(ModuleOp moduleOp,
263+
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
264+
triton::defaultAllocationAnalysisScratchSizeFn)
265+
: CallGraph<Allocation>(moduleOp) {
266+
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
262267
// Pre-order edge walk callback
263268
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
264269
// Post-order node walk callback
265270
[&](FunctionOpInterface funcOp) {
266-
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
271+
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
267272
if (inserted)
268-
iter->second.template run<AllocationAnalysis>(res.funcMap);
273+
iter->second.run(funcMap, scratchSizeGetter);
269274
});
270-
return res;
271275
}
272276

273277
size_t getSharedMemorySize() {
@@ -292,9 +296,6 @@ class ModuleAllocation : public CallGraph<Allocation> {
292296
}
293297

294298
private:
295-
explicit ModuleAllocation(ModuleOp moduleOp)
296-
: CallGraph<Allocation>(moduleOp) {}
297-
298299
FuncOffsetMapT sharedMemoryValue;
299300
};
300301

include/triton/Analysis/AxisInfo.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ class AxisInfo {
2727
public:
2828
AxisInfo() : AxisInfo({}, {}, {}) {}
2929

30-
AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy)
30+
AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
31+
ArrayRef<int64_t> constancy)
3132
: AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}
3233

33-
AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy,
34-
std::optional<int64_t> constantValue)
34+
AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
35+
ArrayRef<int64_t> constancy, std::optional<int64_t> constantValue)
3536
: contiguity(contiguity), divisibility(divisibility),
3637
constancy(constancy), constantValue(constantValue) {
3738
assert(divisibility.size() == contiguity.size());

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
778778
(ins "int":$opIdx,
779779
"int":$kWidth)>,
780780

781+
InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
782+
"SmallVector<unsigned>",
783+
"getRepOrderForOperand",
784+
(ins "int":$opIdx)>,
785+
781786
InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
782787
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
783788
"Type":$eltTy,

0 commit comments

Comments
 (0)