Skip to content

Commit df6894a

Browse files
committed
Merge remote-tracking branch 'origin/main' into vplan-runtime-checks
2 parents e5b8af3 + 30af6fb commit df6894a

File tree

18 files changed

+616
-147
lines changed

18 files changed

+616
-147
lines changed

flang/include/flang/Runtime/CUDA/allocatable.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@ int RTDECL(CUFAllocatableAllocate)(Descriptor &, bool hasStat = false,
2525
/// Perform allocation of the descriptor without synchronization. Assign data
2626
/// from source.
2727
int RTDEF(CUFAllocatableAllocateSource)(Descriptor &alloc,
28-
const Descriptor &source, bool hasStat, const Descriptor *errMsg,
29-
const char *sourceFile, int sourceLine);
28+
const Descriptor &source, bool hasStat = false,
29+
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
30+
int sourceLine = 0);
3031

3132
/// Perform allocation of the descriptor with synchronization of it when
3233
/// necessary. Assign data from source.
3334
int RTDEF(CUFAllocatableAllocateSourceSync)(Descriptor &alloc,
34-
const Descriptor &source, bool hasStat, const Descriptor *errMsg,
35-
const char *sourceFile, int sourceLine);
35+
const Descriptor &source, bool hasStat = false,
36+
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
37+
int sourceLine = 0);
3638

3739
/// Perform deallocation of the descriptor with synchronization of it when
3840
/// necessary.

llvm/lib/Target/AMDGPU/VOP3Instructions.td

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,7 @@ def VOP3_CVT_SCALE_FP4FP8BF8_F32_Profile : VOP3_Profile<VOPProfile<[i32, f32, f3
899899
let HasOMod = 0;
900900
}
901901

902-
def VOP3_CVT_SCALE_PK_F32_FP4FP8BF8_Profile : VOP3_Profile<VOPProfile<[v2f32, i32, f32, untyped]>,
902+
class VOP3_CVT_SCALE_PK_F16BF16F32_FP4FP8BF8_Profile<ValueType DstTy> : VOP3_Profile<VOPProfile<[DstTy, i32, f32, untyped]>,
903903
VOP3_OPSEL> {
904904
let InsVOP3OpSel = (ins FP32InputMods:$src0_modifiers, Src0RC64:$src0,
905905
FP32InputMods:$src1_modifiers, Src1RC64:$src1,
@@ -929,7 +929,7 @@ let SubtargetPredicate = HasFP8ConversionScaleInsts, mayRaiseFPException = 0 in
929929
defm V_CVT_SCALEF32_F16_FP8 : VOP3Inst<"v_cvt_scalef32_f16_fp8", VOP3_CVT_SCALE_F1632_FP8BF8_Profile<f16>>;
930930
defm V_CVT_SCALEF32_F32_FP8 : VOP3Inst<"v_cvt_scalef32_f32_fp8", VOP3_CVT_SCALE_F1632_FP8BF8_Profile<f32>>;
931931
defm V_CVT_SCALEF32_PK_FP8_F32 : VOP3Inst<"v_cvt_scalef32_pk_fp8_f32", VOP3_CVT_SCALE_FP4FP8BF8_F32_Profile>;
932-
defm V_CVT_SCALEF32_PK_F32_FP8 : VOP3Inst<"v_cvt_scalef32_pk_f32_fp8", VOP3_CVT_SCALE_PK_F32_FP4FP8BF8_Profile>;
932+
defm V_CVT_SCALEF32_PK_F32_FP8 : VOP3Inst<"v_cvt_scalef32_pk_f32_fp8", VOP3_CVT_SCALE_PK_F16BF16F32_FP4FP8BF8_Profile<v2f32>>;
933933
defm V_CVT_SCALEF32_PK_FP8_F16 : VOP3Inst<"v_cvt_scalef32_pk_fp8_f16", VOP3_CVT_SCALE_PK_FP8BF8_F16BF16_Profile>;
934934
defm V_CVT_SCALEF32_PK_FP8_BF16 : VOP3Inst<"v_cvt_scalef32_pk_fp8_bf16", VOP3_CVT_SCALE_PK_FP8BF8_F16BF16_Profile>;
935935
}
@@ -938,14 +938,16 @@ let SubtargetPredicate = HasBF8ConversionScaleInsts, mayRaiseFPException = 0 in
938938
defm V_CVT_SCALEF32_F16_BF8 : VOP3Inst<"v_cvt_scalef32_f16_bf8", VOP3_CVT_SCALE_F1632_FP8BF8_Profile<f16>>;
939939
defm V_CVT_SCALEF32_F32_BF8 : VOP3Inst<"v_cvt_scalef32_f32_bf8", VOP3_CVT_SCALE_F1632_FP8BF8_Profile<f32>>;
940940
defm V_CVT_SCALEF32_PK_BF8_F32 : VOP3Inst<"v_cvt_scalef32_pk_bf8_f32", VOP3_CVT_SCALE_FP4FP8BF8_F32_Profile>;
941-
defm V_CVT_SCALEF32_PK_F32_BF8 : VOP3Inst<"v_cvt_scalef32_pk_f32_bf8", VOP3_CVT_SCALE_PK_F32_FP4FP8BF8_Profile>;
941+
defm V_CVT_SCALEF32_PK_F32_BF8 : VOP3Inst<"v_cvt_scalef32_pk_f32_bf8", VOP3_CVT_SCALE_PK_F16BF16F32_FP4FP8BF8_Profile<v2f32>>;
942942
defm V_CVT_SCALEF32_PK_BF8_F16 : VOP3Inst<"v_cvt_scalef32_pk_bf8_f16", VOP3_CVT_SCALE_PK_FP8BF8_F16BF16_Profile>;
943943
defm V_CVT_SCALEF32_PK_BF8_BF16 : VOP3Inst<"v_cvt_scalef32_pk_bf8_bf16", VOP3_CVT_SCALE_PK_FP8BF8_F16BF16_Profile>;
944944
}
945945

946946
let SubtargetPredicate = HasFP4ConversionScaleInsts, mayRaiseFPException = 0 in {
947-
defm V_CVT_SCALEF32_PK_F32_FP4 : VOP3Inst<"v_cvt_scalef32_pk_f32_fp4", VOP3_CVT_SCALE_PK_F32_FP4FP8BF8_Profile>;
947+
defm V_CVT_SCALEF32_PK_F32_FP4 : VOP3Inst<"v_cvt_scalef32_pk_f32_fp4", VOP3_CVT_SCALE_PK_F16BF16F32_FP4FP8BF8_Profile<v2f32>>;
948948
defm V_CVT_SCALEF32_PK_FP4_F32 : VOP3Inst<"v_cvt_scalef32_pk_fp4_f32", VOP3_CVT_SCALE_FP4FP8BF8_F32_Profile>;
949+
defm V_CVT_SCALEF32_PK_F16_FP4 : VOP3Inst<"v_cvt_scalef32_pk_f16_fp4", VOP3_CVT_SCALE_PK_F16BF16F32_FP4FP8BF8_Profile<v2f16>>;
950+
defm V_CVT_SCALEF32_PK_BF16_FP4 : VOP3Inst<"v_cvt_scalef32_pk_bf16_fp4", VOP3_CVT_SCALE_PK_F16BF16F32_FP4FP8BF8_Profile<v2bf16>>;
949951
}
950952

951953
let SubtargetPredicate = isGFX10Plus in {
@@ -1889,4 +1891,6 @@ defm V_CVT_SCALEF32_PK_BF8_BF16: VOP3OpSel_Real_gfx9 <0x245>;
18891891
let OtherPredicates = [HasFP4ConversionScaleInsts] in {
18901892
defm V_CVT_SCALEF32_PK_F32_FP4 : VOP3OpSel_Real_gfx9 <0x23f>;
18911893
defm V_CVT_SCALEF32_PK_FP4_F32 : VOP3OpSel_Real_gfx9 <0x23d>;
1894+
defm V_CVT_SCALEF32_PK_F16_FP4 : VOP3OpSel_Real_gfx9 <0x250>;
1895+
defm V_CVT_SCALEF32_PK_BF16_FP4 : VOP3OpSel_Real_gfx9 <0x251>;
18921896
}

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ bool DXContainerGlobals::runOnModule(Module &M) {
7878
}
7979

8080
GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
81-
const uint64_t FeatureFlags =
82-
static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
83-
.getShaderFlags()
84-
.getFeatureFlags());
81+
uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
82+
.getShaderFlags()
83+
.getCombinedFlags()
84+
.getFeatureFlags();
8585

8686
Constant *FeatureFlagsConstant =
87-
ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
87+
ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
8888
return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
8989
}
9090

llvm/lib/Target/DirectX/DXILShaderFlags.cpp

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,36 +13,54 @@
1313

1414
#include "DXILShaderFlags.h"
1515
#include "DirectX.h"
16+
#include "llvm/ADT/STLExtras.h"
1617
#include "llvm/IR/Instruction.h"
1718
#include "llvm/IR/Module.h"
1819
#include "llvm/Support/FormatVariadic.h"
20+
#include "llvm/Support/raw_ostream.h"
1921

2022
using namespace llvm;
2123
using namespace llvm::dxil;
2224

23-
static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
24-
Type *Ty = I.getType();
25-
if (Ty->isDoubleTy()) {
26-
Flags.Doubles = true;
25+
static void updateFunctionFlags(ComputedShaderFlags &CSF,
26+
const Instruction &I) {
27+
if (!CSF.Doubles)
28+
CSF.Doubles = I.getType()->isDoubleTy();
29+
30+
if (!CSF.Doubles) {
31+
for (Value *Op : I.operands())
32+
CSF.Doubles |= Op->getType()->isDoubleTy();
33+
}
34+
if (CSF.Doubles) {
2735
switch (I.getOpcode()) {
2836
case Instruction::FDiv:
2937
case Instruction::UIToFP:
3038
case Instruction::SIToFP:
3139
case Instruction::FPToUI:
3240
case Instruction::FPToSI:
33-
Flags.DX11_1_DoubleExtensions = true;
41+
// TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
42+
// https://github.com/llvm/llvm-project/issues/114554
43+
CSF.DX11_1_DoubleExtensions = true;
3444
break;
3545
}
3646
}
3747
}
3848

39-
ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
40-
ComputedShaderFlags Flags;
41-
for (const auto &F : M)
49+
void ModuleShaderFlags::initialize(const Module &M) {
50+
// Collect shader flags for each of the functions
51+
for (const auto &F : M.getFunctionList()) {
52+
if (F.isDeclaration())
53+
continue;
54+
ComputedShaderFlags CSF;
4255
for (const auto &BB : F)
4356
for (const auto &I : BB)
44-
updateFlags(Flags, I);
45-
return Flags;
57+
updateFunctionFlags(CSF, I);
58+
// Insert shader flag mask for function F
59+
FunctionFlags.push_back({&F, CSF});
60+
// Update combined shader flags mask
61+
CombinedSFMask.merge(CSF);
62+
}
63+
llvm::sort(FunctionFlags);
4664
}
4765

4866
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -63,20 +81,58 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
6381
OS << ";\n";
6482
}
6583

84+
/// Return the shader flags mask of the specified function Func.
85+
const ComputedShaderFlags &
86+
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
87+
const auto Iter = llvm::lower_bound(
88+
FunctionFlags, Func,
89+
[](const std::pair<const Function *, ComputedShaderFlags> FSM,
90+
const Function *FindFunc) { return (FSM.first < FindFunc); });
91+
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
92+
"No Shader Flags Mask exists for function");
93+
return Iter->second;
94+
}
95+
96+
//===----------------------------------------------------------------------===//
97+
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
98+
99+
// Define the static analysis key that identifies ShaderFlagsAnalysis.
66100
AnalysisKey ShaderFlagsAnalysis::Key;
67101

68-
ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
69-
ModuleAnalysisManager &AM) {
70-
return ComputedShaderFlags::computeFlags(M);
102+
ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
103+
ModuleAnalysisManager &AM) {
104+
ModuleShaderFlags MSFI;
105+
MSFI.initialize(M);
106+
return MSFI;
71107
}
72108

73109
PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
74110
ModuleAnalysisManager &AM) {
75-
ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
76-
Flags.print(OS);
111+
const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
112+
// Print description of combined shader flags for all module functions
113+
OS << "; Combined Shader Flags for Module\n";
114+
FlagsInfo.getCombinedFlags().print(OS);
115+
// Print shader flags mask for each of the module functions
116+
OS << "; Shader Flags for Module Functions\n";
117+
for (const auto &F : M.getFunctionList()) {
118+
if (F.isDeclaration())
119+
continue;
120+
auto SFMask = FlagsInfo.getFunctionFlags(&F);
121+
OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
122+
(uint64_t)(SFMask));
123+
}
124+
77125
return PreservedAnalyses::all();
78126
}
79127

128+
//===----------------------------------------------------------------------===//
129+
// ShaderFlagsAnalysisWrapper
130+
131+
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
132+
MSFI.initialize(M);
133+
return false;
134+
}
135+
80136
char ShaderFlagsAnalysisWrapper::ID = 0;
81137

82138
INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",

llvm/lib/Target/DirectX/DXILShaderFlags.h

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
#ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
1515
#define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
1616

17+
#include "llvm/IR/Function.h"
1718
#include "llvm/IR/PassManager.h"
1819
#include "llvm/Pass.h"
1920
#include "llvm/Support/Compiler.h"
2021
#include "llvm/Support/Debug.h"
2122
#include "llvm/Support/raw_ostream.h"
2223
#include <cstdint>
24+
#include <memory>
2325

2426
namespace llvm {
2527
class Module;
@@ -43,15 +45,23 @@ struct ComputedShaderFlags {
4345
constexpr uint64_t getMask(int Bit) const {
4446
return Bit != -1 ? 1ull << Bit : 0;
4547
}
48+
49+
uint64_t getModuleFlags() const {
50+
uint64_t ModuleFlags = 0;
51+
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
52+
ModuleFlags |= FlagName ? getMask(DxilModuleBit) : 0ull;
53+
#include "llvm/BinaryFormat/DXContainerConstants.def"
54+
return ModuleFlags;
55+
}
56+
4657
operator uint64_t() const {
47-
uint64_t FlagValue = 0;
58+
uint64_t FlagValue = getModuleFlags();
4859
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
4960
FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull;
50-
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
51-
FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull;
5261
#include "llvm/BinaryFormat/DXContainerConstants.def"
5362
return FlagValue;
5463
}
64+
5565
uint64_t getFeatureFlags() const {
5666
uint64_t FeatureFlags = 0;
5767
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
@@ -60,21 +70,43 @@ struct ComputedShaderFlags {
6070
return FeatureFlags;
6171
}
6272

63-
static ComputedShaderFlags computeFlags(Module &M);
73+
void merge(const uint64_t IVal) {
74+
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
75+
FlagName |= (IVal & getMask(DxilModuleBit));
76+
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
77+
FlagName |= (IVal & getMask(DxilModuleBit));
78+
#include "llvm/BinaryFormat/DXContainerConstants.def"
79+
return;
80+
}
81+
6482
void print(raw_ostream &OS = dbgs()) const;
6583
LLVM_DUMP_METHOD void dump() const { print(); }
6684
};
6785

86+
struct ModuleShaderFlags {
87+
void initialize(const Module &);
88+
const ComputedShaderFlags &getFunctionFlags(const Function *) const;
89+
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
90+
91+
private:
92+
/// Vector of sorted Function-Shader Flag mask pairs representing properties
93+
/// of each of the functions in the module. Shader Flags of each function
94+
/// represent both module-level and function-level flags.
95+
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
96+
/// Combined Shader Flag Mask of all functions of the module
97+
ComputedShaderFlags CombinedSFMask{};
98+
};
99+
68100
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
69101
friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
70102
static AnalysisKey Key;
71103

72104
public:
73105
ShaderFlagsAnalysis() = default;
74106

75-
using Result = ComputedShaderFlags;
107+
using Result = ModuleShaderFlags;
76108

77-
ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
109+
ModuleShaderFlags run(Module &M, ModuleAnalysisManager &AM);
78110
};
79111

80112
/// Printer pass for ShaderFlagsAnalysis results.
@@ -92,19 +124,16 @@ class ShaderFlagsAnalysisPrinter
92124
/// This is required because the passes that will depend on this are codegen
93125
/// passes which run through the legacy pass manager.
94126
class ShaderFlagsAnalysisWrapper : public ModulePass {
95-
ComputedShaderFlags Flags;
127+
ModuleShaderFlags MSFI;
96128

97129
public:
98130
static char ID;
99131

100132
ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
101133

102-
const ComputedShaderFlags &getShaderFlags() { return Flags; }
134+
const ModuleShaderFlags &getShaderFlags() { return MSFI; }
103135

104-
bool runOnModule(Module &M) override {
105-
Flags = ComputedShaderFlags::computeFlags(M);
106-
return false;
107-
}
136+
bool runOnModule(Module &M) override;
108137

109138
void getAnalysisUsage(AnalysisUsage &AU) const override {
110139
AU.setPreservesAll();

0 commit comments

Comments
 (0)