Skip to content

Commit da8d54d

Browse files
Automerge: [MLIR][NVVM] Add Permute Op (#169793)
This patch adds the `permute` op. Lit tests are added to verify the lowering to the intrinsics. Negative tests are also added to check the error-handling of invalid combinations. PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt Signed-off-by: Dharuni R Acharya <[email protected]>
2 parents 926e8d9 + 63163b4 commit da8d54d

File tree

4 files changed

+284
-0
lines changed

4 files changed

+284
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,6 +1567,133 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
15671567
}];
15681568
}
15691569

1570+
//===----------------------------------------------------------------------===//
1571+
// Permute Bytes (Prmt)
1572+
//===----------------------------------------------------------------------===//
1573+
1574+
// Attributes for the permute operation modes supported by PTX.
1575+
def PermuteModeDefault : I32EnumAttrCase<"DEFAULT", 0, "default">;
1576+
def PermuteModeF4E : I32EnumAttrCase<"F4E", 1, "f4e">;
1577+
def PermuteModeB4E : I32EnumAttrCase<"B4E", 2, "b4e">;
1578+
def PermuteModeRC8 : I32EnumAttrCase<"RC8", 3, "rc8">;
1579+
def PermuteModeECL : I32EnumAttrCase<"ECL", 4, "ecl">;
1580+
def PermuteModeECR : I32EnumAttrCase<"ECR", 5, "ecr">;
1581+
def PermuteModeRC16 : I32EnumAttrCase<"RC16", 6, "rc16">;
1582+
1583+
def PermuteMode : I32EnumAttr<"PermuteMode", "NVVM permute mode",
1584+
[PermuteModeDefault, PermuteModeF4E,
1585+
PermuteModeB4E, PermuteModeRC8, PermuteModeECL,
1586+
PermuteModeECR, PermuteModeRC16]> {
1587+
let genSpecializedAttr = 0;
1588+
let cppNamespace = "::mlir::NVVM";
1589+
}
1590+
1591+
def PermuteModeAttr : EnumAttr<NVVM_Dialect, PermuteMode, "permute_mode"> {
1592+
let assemblyFormat = "`<` $value `>`";
1593+
}
1594+
1595+
def NVVM_PermuteOp : NVVM_Op<"prmt", [Pure]>,
1596+
Results<(outs I32:$res)>,
1597+
Arguments<(ins I32:$lo, Optional<I32>:$hi, I32:$selector,
1598+
PermuteModeAttr:$mode)> {
1599+
let summary = "Permute bytes from two 32-bit registers";
1600+
let description = [{
1601+
The `nvvm.prmt` operation constructs a permutation of the
1602+
bytes of the `lo` and (optional) `hi` source operands, selecting based on
1603+
the `selector` operand: its nibbles in `default` mode, or its 2 least significant bits in every other mode.
1604+
1605+
The bytes in the first one or two source operands are numbered.
1606+
The first source operand (%lo) is numbered {b3, b2, b1, b0};
1607+
in the case of the '``default``', '``f4e``' and '``b4e``' variants,
1608+
the second source operand (%hi) is numbered {b7, b6, b5, b4}.
1609+
1610+
Modes:
1611+
- `default`: Index mode - each nibble in `selector` selects a byte from the 8-byte pool
1612+
- `f4e` : Forward 4 extract - extracts 4 contiguous bytes starting from position in `selector`
1613+
- `b4e` : Backward 4 extract - extracts 4 contiguous bytes in reverse order
1614+
- `rc8` : Replicate 8 - replicates the byte of `lo` chosen by `selector` into all four bytes of the 32-bit result
1615+
- `ecl` : Edge clamp left - clamps out-of-range indices to the leftmost valid byte
1616+
- `ecr` : Edge clamp right - clamps out-of-range indices to the rightmost valid byte
1617+
- `rc16` : Replicate 16 - replicates the 16-bit half of `lo` chosen by `selector` into both halves of the 32-bit result
1618+
1619+
For the non-default modes, depending on the 2 least significant bits of the %selector operand, the result
1620+
of the permutation is defined as follows:
1621+
1622+
+------------+----------------+--------------+
1623+
| Mode | %selector[1:0] | Output |
1624+
+------------+----------------+--------------+
1625+
| '``f4e``' | 0 | {3, 2, 1, 0} |
1626+
| +----------------+--------------+
1627+
| | 1 | {4, 3, 2, 1} |
1628+
| +----------------+--------------+
1629+
| | 2 | {5, 4, 3, 2} |
1630+
| +----------------+--------------+
1631+
| | 3 | {6, 5, 4, 3} |
1632+
+------------+----------------+--------------+
1633+
| '``b4e``' | 0 | {5, 6, 7, 0} |
1634+
| +----------------+--------------+
1635+
| | 1 | {6, 7, 0, 1} |
1636+
| +----------------+--------------+
1637+
| | 2 | {7, 0, 1, 2} |
1638+
| +----------------+--------------+
1639+
| | 3 | {0, 1, 2, 3} |
1640+
+------------+----------------+--------------+
1641+
| '``rc8``' | 0 | {0, 0, 0, 0} |
1642+
| +----------------+--------------+
1643+
| | 1 | {1, 1, 1, 1} |
1644+
| +----------------+--------------+
1645+
| | 2 | {2, 2, 2, 2} |
1646+
| +----------------+--------------+
1647+
| | 3 | {3, 3, 3, 3} |
1648+
+------------+----------------+--------------+
1649+
| '``ecl``' | 0 | {3, 2, 1, 0} |
1650+
| +----------------+--------------+
1651+
| | 1 | {3, 2, 1, 1} |
1652+
| +----------------+--------------+
1653+
| | 2 | {3, 2, 2, 2} |
1654+
| +----------------+--------------+
1655+
| | 3 | {3, 3, 3, 3} |
1656+
+------------+----------------+--------------+
1657+
| '``ecr``' | 0 | {0, 0, 0, 0} |
1658+
| +----------------+--------------+
1659+
| | 1 | {1, 1, 1, 0} |
1660+
| +----------------+--------------+
1661+
| | 2 | {2, 2, 1, 0} |
1662+
| +----------------+--------------+
1663+
| | 3 | {3, 2, 1, 0} |
1664+
+------------+----------------+--------------+
1665+
| '``rc16``' | 0 | {1, 0, 1, 0} |
1666+
| +----------------+--------------+
1667+
| | 1 | {3, 2, 3, 2} |
1668+
| +----------------+--------------+
1669+
| | 2 | {1, 0, 1, 0} |
1670+
| +----------------+--------------+
1671+
| | 3 | {3, 2, 3, 2} |
1672+
+------------+----------------+--------------+
1673+
1674+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt)
1676+
}];
1677+
1678+
let assemblyFormat = [{
1679+
$mode $selector `,` $lo (`,` $hi^)? attr-dict `:` type($res)
1680+
}];
1681+
1682+
let hasVerifier = 1;
1683+
1684+
let extraClassDeclaration = [{
1685+
static mlir::NVVM::IDArgPair
1686+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1687+
llvm::IRBuilderBase &builder);
1688+
}];
1689+
1690+
string llvmBuilder = [{
1691+
auto [id, args] = NVVM::PermuteOp::getIntrinsicIDAndArgs(
1692+
*op, moduleTranslation, builder);
1693+
$res = createIntrinsicCall(builder, id, args);
1694+
}];
1695+
}
1696+
15701697
def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
15711698
def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
15721699
def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,31 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
448448
return success();
449449
}
450450

451+
/// Checks that the optional `hi` operand is supplied exactly when the
/// selected permute mode uses it: `default`, `f4e` and `b4e` require `hi`,
/// whereas `rc8`, `ecl`, `ecr` and `rc16` must not be given one.
LogicalResult PermuteOp::verify() {
  const NVVM::PermuteMode mode = getMode();
  const bool hiPresent = static_cast<bool>(getHi());

  // The two-source modes are the only ones that read bytes from `hi`.
  const bool hiRequired = mode == NVVM::PermuteMode::DEFAULT ||
                          mode == NVVM::PermuteMode::F4E ||
                          mode == NVVM::PermuteMode::B4E;

  if (hiRequired && !hiPresent)
    return emitError("mode '")
           << stringifyPermuteMode(mode) << "' requires 'hi' operand.";

  if (!hiRequired && hiPresent)
    return emitError("mode '")
           << stringifyPermuteMode(mode) << "' does not accept 'hi' operand.";

  return success();
}
475+
451476
//===----------------------------------------------------------------------===//
452477
// Stochastic Rounding Conversion Ops
453478
//===----------------------------------------------------------------------===//
@@ -3855,6 +3880,31 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
38553880
return {intrinsicID, args};
38563881
}
38573882

3883+
/// Picks the LLVM intrinsic matching the op's permute mode and gathers the
/// translated call arguments in the order (lo, [hi,] selector). The `hi`
/// value is forwarded only for the `default`, `f4e` and `b4e` modes.
mlir::NVVM::IDArgPair
PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::IRBuilderBase &builder) {
  auto permuteOp = cast<NVVM::PermuteOp>(op);
  const NVVM::PermuteMode mode = permuteOp.getMode();

  // Indexed by the numeric value of the permute mode, so the entries must
  // stay in the same order as the enum cases.
  static constexpr llvm::Intrinsic::ID intrinsicForMode[] = {
      llvm::Intrinsic::nvvm_prmt,     llvm::Intrinsic::nvvm_prmt_f4e,
      llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
      llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
      llvm::Intrinsic::nvvm_prmt_rc16};

  // Only the two-source modes take the `hi` operand.
  const bool forwardsHi = mode == NVVM::PermuteMode::DEFAULT ||
                          mode == NVVM::PermuteMode::F4E ||
                          mode == NVVM::PermuteMode::B4E;

  llvm::SmallVector<llvm::Value *> callArgs;
  callArgs.push_back(mt.lookupValue(permuteOp.getLo()));
  if (forwardsHi)
    callArgs.push_back(mt.lookupValue(permuteOp.getHi()));
  callArgs.push_back(mt.lookupValue(permuteOp.getSelector()));

  return {intrinsicForMode[static_cast<unsigned>(mode)], callArgs};
}
3907+
38583908
//===----------------------------------------------------------------------===//
38593909
// NVVM tcgen05.mma functions
38603910
//===----------------------------------------------------------------------===//
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
2+
3+
llvm.func @invalid_default_missing_hi(%sel: i32, %lo: i32) -> i32 {
4+
// expected-error @below {{mode 'default' requires 'hi' operand.}}
5+
%r = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo : i32
6+
llvm.return %r : i32
7+
}
8+
9+
llvm.func @invalid_f4e_missing_hi(%sel: i32, %lo: i32) -> i32 {
10+
// expected-error @below {{mode 'f4e' requires 'hi' operand.}}
11+
%r = nvvm.prmt #nvvm.permute_mode<f4e> %sel, %lo : i32
12+
llvm.return %r : i32
13+
}
14+
15+
llvm.func @invalid_b4e_missing_hi(%sel: i32, %lo: i32) -> i32 {
16+
// expected-error @below {{mode 'b4e' requires 'hi' operand.}}
17+
%r = nvvm.prmt #nvvm.permute_mode<b4e> %sel, %lo : i32
18+
llvm.return %r : i32
19+
}
20+
21+
llvm.func @invalid_rc8_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
22+
// expected-error @below {{mode 'rc8' does not accept 'hi' operand.}}
23+
%r = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %lo, %hi : i32
24+
llvm.return %r : i32
25+
}
26+
27+
llvm.func @invalid_ecl_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
28+
// expected-error @below {{mode 'ecl' does not accept 'hi' operand.}}
29+
%r = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %lo, %hi : i32
30+
llvm.return %r : i32
31+
}
32+
33+
llvm.func @invalid_ecr_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
34+
// expected-error @below {{mode 'ecr' does not accept 'hi' operand.}}
35+
%r = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %lo, %hi : i32
36+
llvm.return %r : i32
37+
}
38+
39+
llvm.func @invalid_rc16_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
40+
// expected-error @below {{mode 'rc16' does not accept 'hi' operand.}}
41+
%r = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %lo, %hi : i32
42+
llvm.return %r : i32
43+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: @test_prmt_default
4+
llvm.func @test_prmt_default(%sel: i32, %lo: i32, %hi: i32) -> i32 {
5+
// CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
6+
%result = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32
7+
llvm.return %result : i32
8+
}
9+
10+
// CHECK-LABEL: @test_prmt_f4e
11+
llvm.func @test_prmt_f4e(%pos: i32, %lo: i32, %hi: i32) -> i32 {
12+
// CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
13+
%result = nvvm.prmt #nvvm.permute_mode<f4e> %pos, %lo, %hi : i32
14+
llvm.return %result : i32
15+
}
16+
17+
// CHECK-LABEL: @test_prmt_b4e
18+
llvm.func @test_prmt_b4e(%pos: i32, %lo: i32, %hi: i32) -> i32 {
19+
// CHECK: call i32 @llvm.nvvm.prmt.b4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
20+
%result = nvvm.prmt #nvvm.permute_mode<b4e> %pos, %lo, %hi : i32
21+
llvm.return %result : i32
22+
}
23+
24+
// CHECK-LABEL: @test_prmt_rc8
25+
llvm.func @test_prmt_rc8(%sel: i32, %val: i32) -> i32 {
26+
// CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
27+
%result = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %val : i32
28+
llvm.return %result : i32
29+
}
30+
31+
// CHECK-LABEL: @test_prmt_ecl
32+
llvm.func @test_prmt_ecl(%sel: i32, %val: i32) -> i32 {
33+
// CHECK: call i32 @llvm.nvvm.prmt.ecl(i32 %{{.*}}, i32 %{{.*}})
34+
%result = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %val : i32
35+
llvm.return %result : i32
36+
}
37+
38+
// CHECK-LABEL: @test_prmt_ecr
39+
llvm.func @test_prmt_ecr(%sel: i32, %val: i32) -> i32 {
40+
// CHECK: call i32 @llvm.nvvm.prmt.ecr(i32 %{{.*}}, i32 %{{.*}})
41+
%result = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %val : i32
42+
llvm.return %result : i32
43+
}
44+
45+
// CHECK-LABEL: @test_prmt_rc16
46+
llvm.func @test_prmt_rc16(%sel: i32, %val: i32) -> i32 {
47+
// CHECK: call i32 @llvm.nvvm.prmt.rc16(i32 %{{.*}}, i32 %{{.*}})
48+
%result = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %val : i32
49+
llvm.return %result : i32
50+
}
51+
52+
// CHECK-LABEL: @test_prmt_mixed
53+
llvm.func @test_prmt_mixed(%sel: i32, %lo: i32, %hi: i32) -> i32 {
54+
// CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
55+
%r1 = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32
56+
57+
// CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
58+
%r2 = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %r1 : i32
59+
60+
// CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
61+
%r3 = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %r2, %sel : i32
62+
63+
llvm.return %r3 : i32
64+
}

0 commit comments

Comments
 (0)