Skip to content

Commit 0e679c1

Browse files
[Intel] Add permute to TargetInfo
Signed-off-by: Whitney Tsang <[email protected]>
1 parent a7b1123 commit 0e679c1

File tree

4 files changed

+31
-0
lines changed

4 files changed

+31
-0
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
8585
return LLVM::intel::shuffleIdx(loc, rewriter, val, i);
8686
}
8787

88+
Value TargetInfo::permute(RewriterBase &rewriter, Location loc, Value a,
89+
Value b, Value selector) const {
90+
// Warning: The `a` and `b` operands are ordered to align with Nvidia's `prmt`
91+
// Both use little-endian ordering, but AMD puts the MSBs of the data in the
92+
// 0-th operand.
93+
return LLVM::intel::permute(loc, rewriter, b, a, selector);
94+
}
95+
8896
Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
8997
ModuleOp moduleOp, ProgramIDDim axis) const {
9098
Value blockId =

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
4141
Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
4242
Value i) const override;
4343

44+
Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
45+
Value selector) const override;
46+
4447
Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp,
4548
ProgramIDDim axis) const override;
4649

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,23 @@ Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) {
101101
return shuffleCommon(loc, rewriter, val, i, mlir::gpu::ShuffleMode::IDX);
102102
}
103103

104+
Value permute(Location loc, RewriterBase &rewriter, Value x, Value y,
105+
Value selector) {
106+
auto b = TritonLLVMOpBuilder(loc, rewriter);
107+
Value prmt_mask = selector;
108+
// convert from nybble mask to byte mask:
109+
prmt_mask =
110+
b.or_(b.and_(prmt_mask, b.i32_val(0x000000ff)),
111+
b.shl(b.and_(prmt_mask, b.i32_val(0x0000ff00)), b.i32_val(8)));
112+
prmt_mask =
113+
b.or_(b.and_(prmt_mask, b.i32_val(0x000f000f)),
114+
b.shl(b.and_(prmt_mask, b.i32_val(0x00f000f0)), b.i32_val(4)));
115+
Value args[] = {x, y, prmt_mask};
116+
auto op = createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.perm", i32_ty,
117+
args);
118+
return op.getResult(0);
119+
}
120+
104121
LLVM::RoundingMode
105122
convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding) {
106123
LLVM::RoundingMode roundingMode;

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i);
1919
Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i);
2020
Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i);
2121

22+
Value permute(Location loc, RewriterBase &rewriter, Value a, Value b,
23+
Value selector);
24+
2225
/// Create a predicated block, using \p cond as the condition and \p ops for the
2326
/// values supplied by the conditional branch to the exit block. The \p
2427
/// thenOpsFn function is used to inject operations in the 'then' branch:

0 commit comments

Comments
 (0)