Skip to content

Commit fd10713

Browse files
authored
Merge branch 'main' into cuf_int2float
2 parents: 5d14255 + abe92a5 · commit fd10713

33 files changed

+1289 / -104 lines changed

flang/module/cudadevice.f90

Lines changed: 91 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -325,6 +325,13 @@ attributes(device) real(8) function rsqrt(x) bind(c,name='__nv_rsqrt')
325325
end function
326326
end interface
327327

328+
interface saturate
329+
attributes(device) real function __saturatef(r) bind(c, name='__nv_saturatef')
330+
!dir$ ignore_tkr (d) r
331+
real, value :: r
332+
end function
333+
end interface
334+
328335
interface __sad
329336
attributes(device) integer function __sad(i,j,k) bind(c, name='__nv_sad')
330337
!dir$ ignore_tkr (d) i, (d) j, (d) k
@@ -471,6 +478,90 @@ attributes(device) real(8) function sinpi(x) bind(c,name='__nv_sinpi')
471478
end function
472479
end interface
473480

481+
interface __float2int_rd
482+
attributes(device) integer function __float2int_rd(r) bind(c, name='__nv_float2int_rd')
483+
!dir$ ignore_tkr (d) r
484+
real, value :: r
485+
end function
486+
end interface
487+
488+
interface __float2int_rn
489+
attributes(device) integer function __float2int_rn(r) bind(c, name='__nv_float2int_rn')
490+
!dir$ ignore_tkr (d) r
491+
real, value :: r
492+
end function
493+
end interface
494+
495+
interface __float2int_ru
496+
attributes(device) integer function __float2int_ru(r) bind(c, name='__nv_float2int_ru')
497+
!dir$ ignore_tkr (d) r
498+
real, value :: r
499+
end function
500+
end interface
501+
502+
interface __float2int_rz
503+
attributes(device) integer function __float2int_rz(r) bind(c, name='__nv_float2int_rz')
504+
!dir$ ignore_tkr (d) r
505+
real, value :: r
506+
end function
507+
end interface
508+
509+
interface __float2uint_rd
510+
attributes(device) integer function __float2uint_rd(r) bind(c, name='__nv_float2uint_rd')
511+
!dir$ ignore_tkr (d) r
512+
real, value :: r
513+
end function
514+
end interface
515+
516+
interface __float2uint_rn
517+
attributes(device) integer function __float2uint_rn(r) bind(c, name='__nv_float2uint_rn')
518+
!dir$ ignore_tkr (d) r
519+
real, value :: r
520+
end function
521+
end interface
522+
523+
interface __float2uint_ru
524+
attributes(device) integer function __float2uint_ru(r) bind(c, name='__nv_float2uint_ru')
525+
!dir$ ignore_tkr (d) r
526+
real, value :: r
527+
end function
528+
end interface
529+
530+
interface __float2uint_rz
531+
attributes(device) integer function __float2uint_rz(r) bind(c, name='__nv_float2uint_rz')
532+
!dir$ ignore_tkr (d) r
533+
real, value :: r
534+
end function
535+
end interface
536+
537+
interface __float2ll_rd
538+
attributes(device) integer(8) function __float2ll_rd(r) bind(c, name='__nv_float2ll_rd')
539+
!dir$ ignore_tkr (d) r
540+
real, value :: r
541+
end function
542+
end interface
543+
544+
interface __float2ll_rn
545+
attributes(device) integer(8) function __float2ll_rn(r) bind(c, name='__nv_float2ll_rn')
546+
!dir$ ignore_tkr (d) r
547+
real, value :: r
548+
end function
549+
end interface
550+
551+
interface __float2ll_ru
552+
attributes(device) integer(8) function __float2ll_ru(r) bind(c, name='__nv_float2ll_ru')
553+
!dir$ ignore_tkr (d) r
554+
real, value :: r
555+
end function
556+
end interface
557+
558+
interface __float2ll_rz
559+
attributes(device) integer(8) function __float2ll_rz(r) bind(c, name='__nv_float2ll_rz')
560+
!dir$ ignore_tkr (d) r
561+
real, value :: r
562+
end function
563+
end interface
564+
474565
interface __half2float
475566
attributes(device) real function __half2float(i) bind(c, name='__nv_half2float')
476567
!dir$ ignore_tkr (d) i

flang/test/Lower/CUDA/cuda-libdevice.cuf

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,6 @@ end subroutine
131131
! CHECK: %{{.*}} = fir.call @__nv_double2ll_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64) -> i64
132132
! CHECK: %{{.*}} = fir.call @__nv_double2ll_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64) -> i64
133133

134-
135134
attributes(global) subroutine test_drcp_rX()
136135
double precision :: res
137136
double precision :: r
@@ -162,6 +161,30 @@ end subroutine
162161
! CHECK: %{{.*}} = fir.call @__nv_double2ull_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64) -> i64
163162
! CHECK: %{{.*}} = fir.call @__nv_double2ull_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64) -> i64
164163

164+
attributes(global) subroutine test_saturatef()
165+
real :: res
166+
real :: r
167+
res = __saturatef(r)
168+
end subroutine
169+
170+
! CHECK-LABEL: _QPtest_saturatef
171+
! CHECK: %{{.*}} = fir.call @__nv_saturatef(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
172+
173+
attributes(global) subroutine test_float2ll_rX()
174+
integer(8) :: res
175+
real :: r
176+
res = __float2ll_rd(r)
177+
res = __float2ll_rn(r)
178+
res = __float2ll_ru(r)
179+
res = __float2ll_rz(r)
180+
end subroutine
181+
182+
! CHECK-LABEL: _QPtest_float2ll_rx
183+
! CHECK: %{{.*}} = fir.call @__nv_float2ll_rd(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i64
184+
! CHECK: %{{.*}} = fir.call @__nv_float2ll_rn(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i64
185+
! CHECK: %{{.*}} = fir.call @__nv_float2ll_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i64
186+
! CHECK: %{{.*}} = fir.call @__nv_float2ll_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i64
187+
165188
attributes(global) subroutine test_ll2float_rX()
166189
real :: res
167190
integer(8) :: i
@@ -191,3 +214,33 @@ end subroutine
191214
! CHECK: %{{.*}} = fir.call @__nv_int2float_rn(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> f32
192215
! CHECK: %{{.*}} = fir.call @__nv_int2float_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> f32
193216
! CHECK: %{{.*}} = fir.call @__nv_int2float_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> f32
217+
218+
attributes(global) subroutine test_float2int_rX()
219+
integer :: res
220+
real :: r
221+
res = __float2int_rd(r)
222+
res = __float2int_rn(r)
223+
res = __float2int_ru(r)
224+
res = __float2int_rz(r)
225+
end subroutine
226+
227+
! CHECK-LABEL: _QPtest_float2int_rx
228+
! CHECK: %{{.*}} = fir.call @__nv_float2int_rd(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i32
229+
! CHECK: %{{.*}} = fir.call @__nv_float2int_rn(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i32
230+
! CHECK: %{{.*}} = fir.call @__nv_float2int_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i32
231+
! CHECK: %{{.*}} = fir.call @__nv_float2int_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i32
232+
233+
attributes(global) subroutine test_float2uint_rX()
234+
integer :: res
235+
real :: r
236+
res = __float2uint_rd(r)
237+
res = __float2uint_rn(r)
238+
res = __float2uint_ru(r)
239+
res = __float2uint_rz(r)
240+
end subroutine
241+
242+
! CHECK-LABEL: _QPtest_float2uint_rx
243+
! CHECK: %{{.*}} = fir.call @__nv_float2uint_rd(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i32
244+
! CHECK: %{{.*}} = fir.call @__nv_float2uint_rn(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i32
245+
! CHECK: %{{.*}} = fir.call @__nv_float2uint_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i32
246+
! CHECK: %{{.*}} = fir.call @__nv_float2uint_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> i32

llvm/include/llvm/BinaryFormat/DXContainer.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -158,8 +158,6 @@ enum class FeatureFlags : uint64_t {
158158
static_assert((uint64_t)FeatureFlags::NextUnusedBit <= 1ull << 63,
159159
"Shader flag bits exceed enum size.");
160160

161-
LLVM_ABI ArrayRef<EnumEntry<llvm::dxil::ResourceClass>> getResourceClasses();
162-
163161
#define ROOT_SIGNATURE_FLAG(Num, Val) Val = Num,
164162
enum class RootFlags : uint32_t {
165163
#include "DXContainerConstants.def"

llvm/include/llvm/Support/DXILABI.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@
1717
#ifndef LLVM_SUPPORT_DXILABI_H
1818
#define LLVM_SUPPORT_DXILABI_H
1919

20+
#include "llvm/ADT/StringRef.h"
21+
#include "llvm/Support/ScopedPrinter.h"
2022
#include <cstdint>
2123

2224
namespace llvm {
@@ -99,6 +101,10 @@ enum class SamplerFeedbackType : uint32_t {
99101
const unsigned MinWaveSize = 4;
100102
const unsigned MaxWaveSize = 128;
101103

104+
LLVM_ABI ArrayRef<EnumEntry<ResourceClass>> getResourceClasses();
105+
106+
LLVM_ABI StringRef getResourceClassName(ResourceClass RC);
107+
102108
} // namespace dxil
103109
} // namespace llvm
104110

llvm/lib/Analysis/DXILResource.cpp

Lines changed: 1 addition & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/IR/Metadata.h"
2121
#include "llvm/IR/Module.h"
2222
#include "llvm/InitializePasses.h"
23+
#include "llvm/Support/DXILABI.h"
2324
#include "llvm/Support/FormatVariadic.h"
2425
#include <cstdint>
2526
#include <optional>
@@ -29,20 +30,6 @@
2930
using namespace llvm;
3031
using namespace dxil;
3132

32-
static StringRef getResourceClassName(ResourceClass RC) {
33-
switch (RC) {
34-
case ResourceClass::SRV:
35-
return "SRV";
36-
case ResourceClass::UAV:
37-
return "UAV";
38-
case ResourceClass::CBuffer:
39-
return "CBuffer";
40-
case ResourceClass::Sampler:
41-
return "Sampler";
42-
}
43-
llvm_unreachable("Unhandled ResourceClass");
44-
}
45-
4633
static StringRef getResourceKindName(ResourceKind RK) {
4734
switch (RK) {
4835
case ResourceKind::Texture1D:

llvm/lib/BinaryFormat/DXContainer.cpp

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -60,17 +60,6 @@ ArrayRef<EnumEntry<SigComponentType>> dxbc::getSigComponentTypes() {
6060
return ArrayRef(SigComponentTypes);
6161
}
6262

63-
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
64-
{"SRV", llvm::dxil::ResourceClass::SRV},
65-
{"UAV", llvm::dxil::ResourceClass::UAV},
66-
{"CBV", llvm::dxil::ResourceClass::CBuffer},
67-
{"Sampler", llvm::dxil::ResourceClass::Sampler},
68-
};
69-
70-
ArrayRef<EnumEntry<llvm::dxil::ResourceClass>> dxbc::getResourceClasses() {
71-
return ArrayRef(ResourceClassNames);
72-
}
73-
7463
static const EnumEntry<RootFlags> RootFlagNames[] = {
7564
#define ROOT_SIGNATURE_FLAG(Val, Enum) {#Enum, RootFlags::Enum},
7665
#include "llvm/BinaryFormat/DXContainerConstants.def"

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26018,7 +26018,10 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
2601826018
// Combine an extract of an extract into a single extract_subvector.
2601926019
// ext (ext X, C), 0 --> ext X, C
2602026020
if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
26021-
if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
26021+
// The index has to be a multiple of the new result type's known minimum
26022+
// vector length.
26023+
if (V.getConstantOperandVal(1) % NVT.getVectorMinNumElements() == 0 &&
26024+
TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
2602226025
V.getConstantOperandVal(1)) &&
2602326026
TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
2602426027
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, V.getOperand(0),

llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ static raw_ostream &operator<<(raw_ostream &OS,
9494

9595
static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
9696
OS << enumToStringRef(dxil::ResourceClass(llvm::to_underlying(Type)),
97-
dxbc::getResourceClasses());
97+
dxil::getResourceClasses());
9898

9999
return OS;
100100
}

llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
121121
IRBuilder<> Builder(Ctx);
122122
StringRef ResName =
123123
enumToStringRef(dxil::ResourceClass(to_underlying(Descriptor.Type)),
124-
dxbc::getResourceClasses());
124+
dxil::getResourceClasses());
125125
assert(!ResName.empty() && "Provided an invalid Resource Class");
126126
SmallString<7> Name({"Root", ResName});
127127
Metadata *Operands[] = {
@@ -163,7 +163,7 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause(
163163
IRBuilder<> Builder(Ctx);
164164
StringRef ResName =
165165
enumToStringRef(dxil::ResourceClass(to_underlying(Clause.Type)),
166-
dxbc::getResourceClasses());
166+
dxil::getResourceClasses());
167167
assert(!ResName.empty() && "Provided an invalid Resource Class");
168168
Metadata *Operands[] = {
169169
MDString::get(Ctx, ResName),

llvm/lib/Support/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,7 @@ add_llvm_component_library(LLVMSupport
182182
DivisionByConstantInfo.cpp
183183
DAGDeltaAlgorithm.cpp
184184
DJB.cpp
185+
DXILABI.cpp
185186
DynamicAPInt.cpp
186187
ELFAttributes.cpp
187188
ELFAttrParserCompact.cpp

0 commit comments

Comments (0)