Commit 206adb3

Address PR feedback
Signed-off-by: Lukas Sommer <[email protected]>
1 parent cd47c3e commit 206adb3

3 files changed: 105 additions, 62 deletions

SPIRVSubgroupOps.h (new file)

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+//===- SPIRVSubgroupOps.h - Mapping for SPIR-V Reduction --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the mapping from operations in the 'arith' dialect to the
+// corresponding SPIR-V subgroup reduction operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITONINTELGPUTOLLVM_SPIRVSUBGROUPOPS_H
+#define TRITONINTELGPUTOLLVM_SPIRVSUBGROUPOPS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+using namespace mlir;
+
+namespace mlir::triton::intel {
+
+template <typename OpTy> struct SPIRVArithmeticGroupOp {};
+
+template <> struct SPIRVArithmeticGroupOp<arith::AddFOp> {
+  using type = spirv::GroupNonUniformFAddOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::AddIOp> {
+  using type = spirv::GroupNonUniformIAddOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::MulFOp> {
+  using type = spirv::GroupNonUniformFMulOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::MulIOp> {
+  using type = spirv::GroupNonUniformIMulOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::MaxNumFOp> {
+  using type = spirv::GroupNonUniformFMaxOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::MinNumFOp> {
+  using type = spirv::GroupNonUniformFMinOp;
+};
+
+template <typename OpTy>
+using SPIRVArithmeticGroupOpTy = typename SPIRVArithmeticGroupOp<OpTy>::type;
+
+template <typename OpTy> struct SPIRVBitwiseGroupOp {};
+
+template <> struct SPIRVBitwiseGroupOp<arith::AndIOp> {
+  using type = spirv::GroupNonUniformBitwiseAndOp;
+};
+template <> struct SPIRVBitwiseGroupOp<arith::OrIOp> {
+  using type = spirv::GroupNonUniformBitwiseOrOp;
+};
+template <> struct SPIRVBitwiseGroupOp<arith::XOrIOp> {
+  using type = spirv::GroupNonUniformBitwiseXorOp;
+};
+
+template <typename OpTy>
+using SPIRVBitwiseGroupOpTy = typename SPIRVBitwiseGroupOp<OpTy>::type;
+
+template <typename OpTy> struct SPIRVLogicalGroupOp {};
+
+template <> struct SPIRVLogicalGroupOp<arith::AndIOp> {
+  using type = spirv::GroupNonUniformLogicalAndOp;
+};
+template <> struct SPIRVLogicalGroupOp<arith::OrIOp> {
+  using type = spirv::GroupNonUniformLogicalOrOp;
+};
+template <> struct SPIRVLogicalGroupOp<arith::XOrIOp> {
+  using type = spirv::GroupNonUniformLogicalXorOp;
+};
+
+template <typename OpTy>
+using SPIRVLogicalGroupOpTy = typename SPIRVLogicalGroupOp<OpTy>::type;
+
+} // namespace mlir::triton::intel
+
+#endif // TRITONINTELGPUTOLLVM_SPIRVSUBGROUPOPS_H
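
For orientation, the traits above are meant to be looked up from inside an llvm::TypeSwitch over the combine operation, with decltype recovering the concrete arith op type in a single generic case. The following sketch is illustrative only: the helper name emitSubgroupReduce is hypothetical and not part of this commit (the actual call sites are createSPIRVGroupOp in TargetInfo.cpp and replaceOpWithNewOp in TritonOpsToLLVM.cpp), but the builder arguments mirror the ones used there.

// Hypothetical usage sketch for SPIRVSubgroupOps.h; not part of the commit.
#include "SPIRVSubgroupOps.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

static Value emitSubgroupReduce(OpBuilder &builder, Location loc,
                                Operation *combineOp, Value acc) {
  Type resultType = acc.getType();
  return llvm::TypeSwitch<Operation *, Value>(combineOp)
      .Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
            arith::MaxNumFOp, arith::MinNumFOp>([&](auto op) {
        // decltype(op) is the concrete arith op matched by this case; the
        // trait maps it to the corresponding spirv::GroupNonUniform* op.
        using GroupOp = triton::intel::SPIRVArithmeticGroupOpTy<decltype(op)>;
        return builder
            .create<GroupOp>(loc, resultType, spirv::Scope::Subgroup,
                             spirv::GroupOperation::Reduce, acc,
                             /*clusterSize=*/Value())
            .getResult();
      })
      // A null Value signals an unsupported combine op to the caller.
      .Default([](Operation *) { return Value(); });
}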

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 16 additions & 43 deletions
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TargetInfo.h"
+#include "SPIRVSubgroupOps.h"
 #include "Utility.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -133,46 +134,21 @@ Value warpReduceHelper(RewriterBase &rewriter, Location loc, Value acc,
                        unsigned warpSize) {
   auto resultType = reduceOp->getResult(0).getType();
   Value warpReduce =
-      llvm::TypeSwitch<mlir::Operation *, Value>(reduceOp)
-          .Case<arith::AddFOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformFAddOp>(
-                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
-          })
-          .Case<arith::AddIOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformIAddOp>(
-                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
-          })
-          .Case<arith::MulFOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformFMulOp>(
-                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
-          })
-          .Case<arith::MulIOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformIMulOp>(
-                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
-          })
-          .Case<arith::MaxNumFOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformFMaxOp>(
-                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
-          })
-          .Case<arith::MinNumFOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformFMinOp>(
-                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
-          })
-          .Case<arith::AndIOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformBitwiseAndOp>(
-                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
-          })
-          .Case<arith::OrIOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformBitwiseOrOp>(
+      TypeSwitch<mlir::Operation *, Value>(reduceOp)
+          .Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
+                arith::MaxNumFOp, arith::MinNumFOp>([&](auto groupOp) {
+            return createSPIRVGroupOp<
+                SPIRVArithmeticGroupOpTy<decltype(groupOp)>>(
                 rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
           })
-          .Case<arith::XOrIOp>([&](auto) {
-            return createSPIRVGroupOp<spirv::GroupNonUniformBitwiseXorOp>(
+          .Case<arith::AndIOp, arith::OrIOp, arith::XOrIOp>([&](auto groupOp) {
+            if (resultType.isInteger(1)) {
+              return createSPIRVGroupOp<
+                  SPIRVLogicalGroupOpTy<decltype(groupOp)>>(
+                  rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
+            }
+            return createSPIRVGroupOp<SPIRVBitwiseGroupOpTy<decltype(groupOp)>>(
                 rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
-          })
-          .Default([](auto) {
-            llvm_unreachable("Unsupported reduction");
-            return Value();
           });
   return warpReduce;
 }
@@ -206,12 +182,9 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
       reduceOp->getOperand(1) != block.getArgument(1))
     return false;
 
-  auto supportedOp =
-      llvm::TypeSwitch<mlir::Operation *, bool>(reduceOp)
-          .Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
-                arith::MaxNumFOp, arith::MinNumFOp, arith::AndIOp, arith::OrIOp,
-                arith::XOrIOp>([&](auto) { return true; })
-          .Default([](auto) { return false; });
+  auto supportedOp = isa<arith::AddFOp, arith::AddIOp, arith::MulFOp,
+                         arith::MulIOp, arith::MaxNumFOp, arith::MinNumFOp,
+                         arith::AndIOp, arith::OrIOp, arith::XOrIOp>(reduceOp);
 
   if (!supportedOp)
     return false;
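
One detail worth flagging in the hunk above: for i1 accumulators, arith.andi, arith.ori and arith.xori are lowered to the GroupNonUniformLogical* ops instead of the Bitwise* ones, since the SPIR-V logical group operations act on booleans while the bitwise variants expect ordinary integers. A minimal standalone sketch of that dispatch (the helper name reduceBoolOrInt is hypothetical, and the includes from the sketch after SPIRVSubgroupOps.h are assumed):

// Hypothetical helper illustrating the i1 special case; not part of the commit.
template <typename ArithOpTy>
static Value reduceBoolOrInt(OpBuilder &builder, Location loc, Type resultType,
                             Value acc) {
  if (resultType.isInteger(1)) {
    // i1 maps to a SPIR-V boolean, so emit the Logical* group op.
    using LogicalOp = triton::intel::SPIRVLogicalGroupOpTy<ArithOpTy>;
    return builder
        .create<LogicalOp>(loc, resultType, spirv::Scope::Subgroup,
                           spirv::GroupOperation::Reduce, acc, Value())
        .getResult();
  }
  // Wider integer types use the Bitwise* group op.
  using BitwiseOp = triton::intel::SPIRVBitwiseGroupOpTy<ArithOpTy>;
  return builder
      .create<BitwiseOp>(loc, resultType, spirv::Scope::Subgroup,
                         spirv::GroupOperation::Reduce, acc, Value())
      .getResult();
}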

third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp

Lines changed: 9 additions & 19 deletions
@@ -1,5 +1,6 @@
 #include "Dialect/TritonIntelGPU/IR/Utils.h"
 #include "PatternTritonGPUOpToLLVM.h"
+#include "SPIRVSubgroupOps.h"
 
 #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
@@ -583,28 +584,17 @@ class ReduceOpConversion : public ConvertTritonGPUOpToLLVMPattern<ReduceOp> {
     Operation *combine = &*combineOp.front().getOperations().begin();
 
     // FIXME: support all possible reduction modes
-    using AllReduceOperation = mlir::gpu::AllReduceOperation;
-    AllReduceOperation redKind;
-    if (isa<arith::AddFOp>(combine))
-      replaceWithSPIRVOp<mlir::spirv::GroupNonUniformFAddOp>(op, adaptor,
-                                                             rewriter);
-    else if (isa<arith::MaxNumFOp>(combine))
-      replaceWithSPIRVOp<mlir::spirv::GroupNonUniformFMaxOp>(op, adaptor,
-                                                             rewriter);
-    else
-      llvm_unreachable("Unhandled reduction kind");
+    TypeSwitch<Operation *>(combine).Case<arith::AddFOp, arith::MaxNumFOp>(
+        [&](auto reduce) {
+          rewriter.replaceOpWithNewOp<
+              intel::SPIRVArithmeticGroupOpTy<decltype(reduce)>>(
+              op, typeConverter->convertType(op.getType(0)),
+              spirv::Scope::Subgroup, spirv::GroupOperation::Reduce,
+              adaptor.getSrcs()[0], Value());
+        });
 
     return success();
   }
-
-private:
-  template <typename ReplaceOp>
-  void replaceWithSPIRVOp(ReduceOp op, ReduceOpAdaptor adaptor,
-                          ConversionPatternRewriter &rewriter) const {
-    rewriter.replaceOpWithNewOp<ReplaceOp>(
-        op, typeConverter->convertType(op.getType(0)), spirv::Scope::Subgroup,
-        spirv::GroupOperation::Reduce, adaptor.getSrcs()[0], Value());
-  }
 };
 
 class TransposedReduceOpConversion