Skip to content

Commit 0a23fcd

Browse files
Merge commit '83cf4362bd87c74bf57e79a7e213b4301fa3f25c'
2 parents 3fcbdc8 + 83cf436 commit 0a23fcd

File tree

13 files changed

+701
-65
lines changed

13 files changed

+701
-65
lines changed

.github/workflows/integration-tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ jobs:
236236
- name: Install pip dependencies
237237
run: |
238238
python3 -m pip install --upgrade pip
239-
python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
239+
python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit
240240
- name: Install Triton
241241
env:
242242
CUDA_HOME: "/usr/local/cuda"
@@ -569,7 +569,7 @@ jobs:
569569
python3 -m venv ~/.venv
570570
source ~/.venv/bin/activate
571571
python3 -m pip install --upgrade pip
572-
python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11
572+
python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit pybind11
573573
- name: Install Triton
574574
env:
575575
TRITON_BUILD_WITH_O1: "true"

.github/workflows/integration-tests.yml.in

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ jobs:
268268
- name: Install pip dependencies
269269
run: |
270270
python3 -m pip install --upgrade pip
271-
python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
271+
python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit
272272

273273
- name: Install Triton
274274
env:
@@ -481,7 +481,7 @@ jobs:
481481
python3 -m venv ~/.venv
482482
source ~/.venv/bin/activate
483483
python3 -m pip install --upgrade pip
484-
python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11
484+
python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit pybind11
485485
- name: Install Triton
486486
env:
487487
TRITON_BUILD_WITH_O1: "true"

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
185185
include_directories(${PYTHON_SRC_PATH})
186186

187187
# Python Interpreter is used to run lit tests
188-
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
188+
find_package(Python3 REQUIRED COMPONENTS Development.Module Interpreter)
189189
find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}")
190190

191191
if (DEFINED TRITON_PLUGIN_DIRS)

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
3535
// Return the minClusterId and maxClusterId for the given ForOp.
3636
std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
3737
std::pair<int, int> getStageCluster(Operation *op);
38-
void setStageCluster(scf::ForOp &forOp, Operation *op, int stage, int cluster);
38+
void setStageCluster(Operation *op, int stage, int cluster);
3939
} // namespace triton
4040
} // namespace mlir
4141

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,7 @@ class OpBuilderWithStage : public OpBuilder {
6464
OpTy createWithStage(Location location, int stage, int cluster,
6565
Args &&...args) {
6666
OpTy op = OpBuilder::create<OpTy>(location, std::forward<Args>(args)...);
67-
auto ctx = getContext();
68-
op->setAttr(mlir::triton::kLoopStageAttrName,
69-
IntegerAttr::get(IntegerType::get(ctx, 32), stage));
70-
op->setAttr(mlir::triton::kLoopClusterAttrName,
71-
IntegerAttr::get(IntegerType::get(ctx, 32), cluster));
67+
tt::setStageCluster(op, stage, cluster);
7268
return op;
7369
}
7470
using OpBuilder::create;
@@ -204,9 +200,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
204200
// Prefetch load if is not MMAV3 and is used by the dot.
205201
if (loadToInfo[loadOp].usedByDot) {
206202
assert(stageForFirstUse >= 1);
207-
tt::setStageCluster(forOp, wait, stageForFirstUse - 1, maxClusterId + 1);
208-
tt::setStageCluster(forOp, viewLoad, stageForFirstUse - 1,
209-
maxClusterId + 1);
203+
tt::setStageCluster(wait, stageForFirstUse - 1, maxClusterId + 1);
204+
tt::setStageCluster(viewLoad, stageForFirstUse - 1, maxClusterId + 1);
210205
retCode = stageForFirstUse - 1;
211206
}
212207
}

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,8 @@ std::pair<int, int> mlir::triton::getStageCluster(Operation *op) {
188188
return std::make_pair(stage, clusterId);
189189
}
190190

191-
void mlir::triton::setStageCluster(scf::ForOp &forOp, Operation *op, int stage,
192-
int cluster) {
193-
auto ctx = forOp.getContext();
191+
void mlir::triton::setStageCluster(Operation *op, int stage, int cluster) {
192+
auto ctx = op->getContext();
194193
op->setAttr(mlir::triton::kLoopStageAttrName,
195194
IntegerAttr::get(IntegerType::get(ctx, 32), stage));
196195
op->setAttr(mlir::triton::kLoopClusterAttrName,

lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ void tt::CoarseSchedule::dump() {
9494
// Set <stage, cluster> based on CoarseSchedule.
9595
void tt::CoarseSchedule::serialize(scf::ForOp &forOp) {
9696
for (auto [op, stage, cluster] : getOpsInOrder(forOp)) {
97-
tt::setStageCluster(forOp, op, stage, *cluster);
97+
tt::setStageCluster(op, stage, *cluster);
9898
}
9999
}
100100

python/setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,8 @@ def get_git_commit_hash(length=8):
806806
"isort",
807807
"numpy",
808808
"pytest",
809+
"pytest-forked",
810+
"pytest-xdist",
809811
"scipy>=1.7.1",
810812
"llnl-hatchet",
811813
],

0 commit comments

Comments
 (0)