Skip to content

Commit 8f6592f

Browse files
Fix build failures from cfe3dd0
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 699ff73 commit 8f6592f

File tree

3 files changed

+8
-16
lines changed

3 files changed

+8
-16
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/PrintOpToLLVM.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@ struct PrintOpConversion
2929
ConversionPatternRewriter &rewriter) const override {
3030
auto loc = op->getLoc();
3131

32-
auto getPid = [&](int axis) {
33-
return targetInfo.programId(rewriter, loc,
34-
op->getParentOfType<ModuleOp>(), axis);
35-
};
36-
std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};
32+
std::array<Value, 3> pid;
33+
auto module = op->getParentOfType<ModuleOp>();
34+
for (auto axis : {ProgramIDDim::X, ProgramIDDim::Y, ProgramIDDim::Z})
35+
pid[(int)axis] = targetInfo.programId(rewriter, loc, module, axis);
3736

3837
// Simple printf of a string without any tensors.
3938
if (op.getNumOperands() == 0) {

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,9 @@ Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
8686
}
8787

8888
Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
89-
ModuleOp moduleOp, int axis) const {
90-
assert(axis >= 0);
91-
assert(axis < 3);
92-
assert(moduleOp);
93-
94-
constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
95-
mlir::gpu::Dimension::y,
96-
mlir::gpu::Dimension::z};
97-
98-
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]);
89+
ModuleOp moduleOp, ProgramIDDim axis) const {
90+
Value blockId =
91+
rewriter.create<::mlir::gpu::BlockIdOp>(loc, mlir::gpu::Dimension(axis));
9992
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
10093
}
10194

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
4242
Value i) const override;
4343

4444
Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp,
45-
int axis) const override;
45+
ProgramIDDim axis) const override;
4646

4747
bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector<Value> &acc,
4848
triton::ReduceOp op, unsigned numLaneToReduce,

0 commit comments

Comments
 (0)