Skip to content

Commit e4e6b69

Browse files
committed
[NVPTX] Add intrinsics for the szext instruction
1 parent bbafa52 commit e4e6b69

File tree

5 files changed

+217
-41
lines changed

5 files changed

+217
-41
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,99 @@ to left-shift the found bit into the most-significant bit position, otherwise
568568
the result is the shift amount needed to right-shift the found bit into the
569569
least-significant bit position. 0xffffffff is returned if no 1 bit is found.
570570

571+
'``llvm.nvvm.zext.inreg.clamp``' Intrinsic
572+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
573+
574+
Syntax:
575+
"""""""
576+
577+
.. code-block:: llvm
578+
579+
declare i32 @llvm.nvvm.zext.inreg.clamp(i32 %a, i32 %b)
580+
581+
Overview:
582+
"""""""""
583+
584+
The '``llvm.nvvm.zext.inreg.clamp``' intrinsic extracts the low bits of the
585+
input value, and zero-extends them back to the original width.
586+
587+
Semantics:
588+
""""""""""
589+
590+
The '``llvm.nvvm.zext.inreg.clamp``' intrinsic zero-extends the N lowest bits
591+
of operand %a. N is the value of operand %b clamped to the range [0, 32]. If N
592+
is 0, the result is 0.
593+
594+
'``llvm.nvvm.zext.inreg.wrap``' Intrinsic
595+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
596+
597+
Syntax:
598+
"""""""
599+
600+
.. code-block:: llvm
601+
602+
declare i32 @llvm.nvvm.zext.inreg.wrap(i32 %a, i32 %b)
603+
604+
Overview:
605+
"""""""""
606+
607+
The '``llvm.nvvm.zext.inreg.wrap``' intrinsic extracts the low bits of the
608+
input value, and zero-extends them back to the original width.
609+
610+
Semantics:
611+
""""""""""
612+
613+
The '``llvm.nvvm.zext.inreg.wrap``' intrinsic zero-extends the N lowest bits
614+
of operand %a. N is the value of operand %b modulo 32. If N is 0, the result
615+
is 0.
616+
617+
'``llvm.nvvm.sext.inreg.clamp``' Intrinsic
618+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
619+
620+
Syntax:
621+
"""""""
622+
623+
.. code-block:: llvm
624+
625+
declare i32 @llvm.nvvm.sext.inreg.clamp(i32 %a, i32 %b)
626+
627+
Overview:
628+
"""""""""
629+
630+
The '``llvm.nvvm.sext.inreg.clamp``' intrinsic extracts the low bits of the
631+
input value, and sign-extends them back to the original width.
632+
633+
Semantics:
634+
""""""""""
635+
636+
The '``llvm.nvvm.sext.inreg.clamp``' intrinsic sign-extends the N lowest bits
637+
of operand %a. N is the value of operand %b clamped to the range [0, 32]. If N
638+
is 0, the result is 0.
639+
640+
641+
'``llvm.nvvm.sext.inreg.wrap``' Intrinsic
642+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
643+
644+
Syntax:
645+
"""""""
646+
647+
.. code-block:: llvm
648+
649+
declare i32 @llvm.nvvm.sext.inreg.wrap(i32 %a, i32 %b)
650+
651+
Overview:
652+
"""""""""
653+
654+
The '``llvm.nvvm.sext.inreg.wrap``' intrinsic extracts the low bits of the
655+
input value, and sign-extends them back to the original width.
656+
657+
Semantics:
658+
""""""""""
659+
660+
The '``llvm.nvvm.sext.inreg.wrap``' intrinsic sign-extends the N lowest bits
661+
of operand %a. N is the value of operand %b modulo 32. If N is 0, the result
662+
is 0.
663+
571664
TMA family of Intrinsics
572665
------------------------
573666

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,17 @@ let TargetPrefix = "nvvm" in {
13561356
[llvm_anyint_ty, llvm_i1_ty],
13571357
[IntrNoMem, IntrSpeculatable, IntrWillReturn, ImmArg<ArgIndex<1>>]>;
13581358

1359+
1360+
//
1361+
// szext
1362+
//
1363+
foreach ext = ["sext", "zext"] in
1364+
foreach mode = ["wrap", "clamp"] in
1365+
def int_nvvm_ # ext # _inreg_ # mode :
1366+
DefaultAttrsIntrinsic<[llvm_i32_ty],
1367+
[llvm_i32_ty, llvm_i32_ty],
1368+
[IntrNoMem, IntrSpeculatable]>;
1369+
13591370
//
13601371
// Convert
13611372
//

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -240,26 +240,33 @@ def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
240240
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
241241

242242

243+
multiclass I3Inst<string op_str, SDPatternOperator op_node, RegTyInfo t,
244+
bit commutative, list<Predicate> requires = []> {
245+
defvar asmstr = op_str # " \t$dst, $a, $b;";
246+
247+
def rr :
248+
NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b),
249+
asmstr,
250+
[(set t.Ty:$dst, (op_node t.Ty:$a, t.Ty:$b))]>,
251+
Requires<requires>;
252+
def ri :
253+
NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.Imm:$b),
254+
asmstr,
255+
[(set t.Ty:$dst, (op_node t.RC:$a, imm:$b))]>,
256+
Requires<requires>;
257+
if !not(commutative) then
258+
def ir :
259+
NVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, t.RC:$b),
260+
asmstr,
261+
[(set t.Ty:$dst, (op_node imm:$a, t.RC:$b))]>,
262+
Requires<requires>;
263+
}
264+
243265
// Template for instructions which take three int64, int32, or int16 args.
244266
// The instructions are named "<op_str><Width>" (e.g. "add.s64").
245-
multiclass I3<string OpcStr, SDNode OpNode, bit commutative> {
246-
foreach t = [I16RT, I32RT, I64RT] in {
247-
defvar asmstr = OpcStr # t.Size # " \t$dst, $a, $b;";
248-
249-
def t.Ty # rr :
250-
NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b),
251-
asmstr,
252-
[(set t.Ty:$dst, (OpNode t.Ty:$a, t.Ty:$b))]>;
253-
def t.Ty # ri :
254-
NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.Imm:$b),
255-
asmstr,
256-
[(set t.Ty:$dst, (OpNode t.RC:$a, imm:$b))]>;
257-
if !not(commutative) then
258-
def t.Ty # ir :
259-
NVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, t.RC:$b),
260-
asmstr,
261-
[(set t.Ty:$dst, (OpNode imm:$a, t.RC:$b))]>;
262-
}
267+
multiclass I3<string op_str, SDPatternOperator op_node, bit commutative> {
268+
foreach t = [I16RT, I32RT, I64RT] in
269+
defm t.Ty# : I3Inst<op_str # t.Size, op_node, t, commutative>;
263270
}
264271

265272
class I16x2<string OpcStr, SDNode OpNode> :
@@ -270,26 +277,11 @@ class I16x2<string OpcStr, SDNode OpNode> :
270277

271278
// Template for instructions which take 3 int args. The instructions are
272279
// named "<op_str>.s32" or "<op_str>.s64" (e.g. "addc.cc.s32").
273-
multiclass ADD_SUB_INT_CARRY<string OpcStr, SDNode OpNode> {
280+
multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
274281
let hasSideEffects = 1 in {
275-
def i32rr :
276-
NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
277-
!strconcat(OpcStr, ".s32 \t$dst, $a, $b;"),
278-
[(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
279-
def i32ri :
280-
NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
281-
!strconcat(OpcStr, ".s32 \t$dst, $a, $b;"),
282-
[(set i32:$dst, (OpNode i32:$a, imm:$b))]>;
283-
def i64rr :
284-
NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b),
285-
!strconcat(OpcStr, ".s64 \t$dst, $a, $b;"),
286-
[(set i64:$dst, (OpNode i64:$a, i64:$b))]>,
287-
Requires<[hasPTX<43>]>;
288-
def i64ri :
289-
NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b),
290-
!strconcat(OpcStr, ".s64 \t$dst, $a, $b;"),
291-
[(set i64:$dst, (OpNode i64:$a, imm:$b))]>,
292-
Requires<[hasPTX<43>]>;
282+
defm i32 : I3Inst<op_str # ".s32", op_node, I32RT, commutative>;
283+
defm i64 : I3Inst<op_str # ".s64", op_node, I64RT, commutative,
284+
requires = [hasPTX<43>]>;
293285
}
294286
}
295287

@@ -847,12 +839,12 @@ defm SUB : I3<"sub.s", sub, /*commutative=*/ false>;
847839
def ADD16x2 : I16x2<"add.s", add>;
848840

849841
// int32 and int64 addition and subtraction with carry-out.
850-
defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc>;
851-
defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc>;
842+
defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>;
843+
defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>;
852844

853845
// int32 and int64 addition and subtraction with carry-in and carry-out.
854-
defm ADDCCC : ADD_SUB_INT_CARRY<"addc.cc", adde>;
855-
defm SUBCCC : ADD_SUB_INT_CARRY<"subc.cc", sube>;
846+
defm ADDCCC : ADD_SUB_INT_CARRY<"addc.cc", adde, commutative = true>;
847+
defm SUBCCC : ADD_SUB_INT_CARRY<"subc.cc", sube, commutative = false>;
856848

857849
defm MULT : I3<"mul.lo.s", mul, /*commutative=*/ true>;
858850

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,6 +1678,21 @@ foreach t = [I32RT, I64RT] in {
16781678
}
16791679
}
16801680

1681+
//
1682+
// szext
1683+
//
1684+
1685+
foreach sign = ["s", "u"] in {
1686+
foreach mode = ["wrap", "clamp"] in {
1687+
defvar ext = !if(!eq(sign, "s"), "sext", "zext");
1688+
defvar intrin = !cast<Intrinsic>("int_nvvm_" # ext # "_inreg_" # mode);
1689+
defm SZEXT_ # sign # _ # mode
1690+
: I3Inst<"szext." # mode # "." # sign # "32",
1691+
intrin, I32RT, commutative = false,
1692+
requires = [hasSM<70>, hasPTX<76>]>;
1693+
}
1694+
}
1695+
16811696
//
16821697
// Convert
16831698
//

llvm/test/CodeGen/NVPTX/szext.ll

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -o - < %s -mcpu=sm_70 -mattr=+ptx76 | FileCheck %s
3+
4+
target triple = "nvptx-unknown-cuda"
5+
6+
define i32 @szext_wrap_u32(i32 %a, i32 %b) {
7+
; CHECK-LABEL: szext_wrap_u32(
8+
; CHECK: {
9+
; CHECK-NEXT: .reg .b32 %r<4>;
10+
; CHECK-EMPTY:
11+
; CHECK-NEXT: // %bb.0:
12+
; CHECK-NEXT: ld.param.u32 %r1, [szext_wrap_u32_param_0];
13+
; CHECK-NEXT: ld.param.u32 %r2, [szext_wrap_u32_param_1];
14+
; CHECK-NEXT: szext.wrap.u32 %r3, %r1, %r2;
15+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
16+
; CHECK-NEXT: ret;
17+
%c = call i32 @llvm.nvvm.zext.inreg.wrap(i32 %a, i32 %b)
18+
ret i32 %c
19+
}
20+
21+
define i32 @szext_clamp_u32(i32 %a, i32 %b) {
22+
; CHECK-LABEL: szext_clamp_u32(
23+
; CHECK: {
24+
; CHECK-NEXT: .reg .b32 %r<4>;
25+
; CHECK-EMPTY:
26+
; CHECK-NEXT: // %bb.0:
27+
; CHECK-NEXT: ld.param.u32 %r1, [szext_clamp_u32_param_0];
28+
; CHECK-NEXT: ld.param.u32 %r2, [szext_clamp_u32_param_1];
29+
; CHECK-NEXT: szext.clamp.u32 %r3, %r1, %r2;
30+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
31+
; CHECK-NEXT: ret;
32+
%c = call i32 @llvm.nvvm.zext.inreg.clamp(i32 %a, i32 %b)
33+
ret i32 %c
34+
}
35+
36+
define i32 @szext_wrap_s32(i32 %a, i32 %b) {
37+
; CHECK-LABEL: szext_wrap_s32(
38+
; CHECK: {
39+
; CHECK-NEXT: .reg .b32 %r<4>;
40+
; CHECK-EMPTY:
41+
; CHECK-NEXT: // %bb.0:
42+
; CHECK-NEXT: ld.param.u32 %r1, [szext_wrap_s32_param_0];
43+
; CHECK-NEXT: ld.param.u32 %r2, [szext_wrap_s32_param_1];
44+
; CHECK-NEXT: szext.wrap.s32 %r3, %r1, %r2;
45+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
46+
; CHECK-NEXT: ret;
47+
%c = call i32 @llvm.nvvm.sext.inreg.wrap(i32 %a, i32 %b)
48+
ret i32 %c
49+
}
50+
51+
define i32 @szext_clamp_s32(i32 %a, i32 %b) {
52+
; CHECK-LABEL: szext_clamp_s32(
53+
; CHECK: {
54+
; CHECK-NEXT: .reg .b32 %r<4>;
55+
; CHECK-EMPTY:
56+
; CHECK-NEXT: // %bb.0:
57+
; CHECK-NEXT: ld.param.u32 %r1, [szext_clamp_s32_param_0];
58+
; CHECK-NEXT: ld.param.u32 %r2, [szext_clamp_s32_param_1];
59+
; CHECK-NEXT: szext.clamp.s32 %r3, %r1, %r2;
60+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
61+
; CHECK-NEXT: ret;
62+
%c = call i32 @llvm.nvvm.sext.inreg.clamp(i32 %a, i32 %b)
63+
ret i32 %c
64+
}
65+

0 commit comments

Comments
 (0)