From b6dfe53cf7476cf64f2edf21361e43e4a3f3a9ff Mon Sep 17 00:00:00 2001 From: "S. Bharadwaj Yadavalli" Date: Wed, 27 Nov 2024 11:25:34 -0500 Subject: [PATCH 1/3] [Rebase] Propagate shader flags mask of callees to callers Add tests to verify propagation of shader flags --- llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 104 ++++++++--- llvm/lib/Target/DirectX/DXILShaderFlags.h | 20 +-- .../DirectX/ShaderFlags/double-extensions.ll | 7 + .../propagate-function-flags-test.ll | 167 ++++++++++++++++++ 4 files changed, 259 insertions(+), 39 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 2edfc707ce6c7..e956189f8ecd4 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -13,9 +13,13 @@ #include "DXILShaderFlags.h" #include "DirectX.h" +#include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/DXILResource.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsDirectX.h" @@ -27,15 +31,24 @@ using namespace llvm; using namespace llvm::dxil; -static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I, - DXILResourceTypeMap &DRTM) { +/// Update the shader flags mask based on the given instruction. +/// \param CSF Shader flags mask to update. +/// \param I Instruction to check. +void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, + const Instruction &I, + DXILResourceTypeMap &DRTM) { if (!CSF.Doubles) CSF.Doubles = I.getType()->isDoubleTy(); if (!CSF.Doubles) { - for (Value *Op : I.operands()) - CSF.Doubles |= Op->getType()->isDoubleTy(); + for (const Value *Op : I.operands()) { + if (Op->getType()->isDoubleTy()) { + CSF.Doubles = true; + break; + } + } } + if (CSF.Doubles) { switch (I.getOpcode()) { case Instruction::FDiv: @@ -43,8 +56,6 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I, case Instruction::SIToFP: case Instruction::FPToUI: case Instruction::FPToSI: - // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma - // https://github.com/llvm/llvm-project/issues/114554 CSF.DX11_1_DoubleExtensions = true; break; } @@ -62,27 +73,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I, } } } + // Handle call instructions + if (auto *CI = dyn_cast(&I)) { + const Function *CF = CI->getCalledFunction(); + // Merge-in shader flags mask of the called function in the current module + if (FunctionFlags.contains(CF)) { + CSF.merge(FunctionFlags[CF]); + } + // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic + // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 + } } -void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) { - - // Collect shader flags for each of the functions - for (const auto &F : M.getFunctionList()) { - if (F.isDeclaration()) { - assert(!F.getName().starts_with("dx.op.") && - "DXIL Shader Flag analysis should not be run post-lowering."); - continue; +/// Construct ModuleShaderFlags for module Module M +void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) { + CallGraph CG(M); + + // Compute Shader Flags Mask for all functions using post-order visit of SCC + // of the call graph. + for (scc_iterator SCCI = scc_begin(&CG); !SCCI.isAtEnd(); + ++SCCI) { + const std::vector &CurSCC = *SCCI; + + // Union of shader masks of all functions in CurSCC + ComputedShaderFlags SCCSF; + // List of functions in CurSCC that are neither external nor declarations + // and hence whose flags are collected + SmallVector CurSCCFuncs; + for (CallGraphNode *CGN : CurSCC) { + Function *F = CGN->getFunction(); + if (!F) + continue; + + if (F->isDeclaration()) { + assert(!F->getName().starts_with("dx.op.") && + "DXIL Shader Flag analysis should not be run post-lowering."); + continue; + } + + ComputedShaderFlags CSF; + for (const auto &BB : *F) + for (const auto &I : BB) + updateFunctionFlags(CSF, I, DRTM); + // Update combined shader flags mask for all functions in this SCC + SCCSF.merge(CSF); + + CurSCCFuncs.push_back(F); } - ComputedShaderFlags CSF; - for (const auto &BB : F) - for (const auto &I : BB) - updateFunctionFlags(CSF, I, DRTM); - // Insert shader flag mask for function F - FunctionFlags.push_back({&F, CSF}); - // Update combined shader flags mask - CombinedSFMask.merge(CSF); + + // Update combined shader flags mask for all functions of the module + CombinedSFMask.merge(SCCSF); + + // Shader flags mask of each of the functions in an SCC of the call graph is + // the union of all functions in the SCC. Update shader flags masks of + // functions in CurSCC accordingly. This is trivially true if SCC contains + // one function. + for (Function *F : CurSCCFuncs) + // Merge SCCSF with that of F + FunctionFlags[F].merge(SCCSF); } - llvm::sort(FunctionFlags); } void ComputedShaderFlags::print(raw_ostream &OS) const { @@ -106,12 +155,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const { /// Return the shader flags mask of the specified function Func. const ComputedShaderFlags & ModuleShaderFlags::getFunctionFlags(const Function *Func) const { - const auto Iter = llvm::lower_bound( - FunctionFlags, Func, - [](const std::pair FSM, - const Function *FindFunc) { return (FSM.first < FindFunc); }); + auto Iter = FunctionFlags.find(Func); assert((Iter != FunctionFlags.end() && Iter->first == Func) && - "No Shader Flags Mask exists for function"); + "Get Shader Flags : No Shader Flags Mask exists for function"); return Iter->second; } @@ -142,7 +188,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, for (const auto &F : M.getFunctionList()) { if (F.isDeclaration()) continue; - auto SFMask = FlagsInfo.getFunctionFlags(&F); + const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F); OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(), (uint64_t)(SFMask)); } diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h index 67ddab39d0f34..e6c6d56402c1a 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.h +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h @@ -71,13 +71,11 @@ struct ComputedShaderFlags { return FeatureFlags; } - void merge(const uint64_t IVal) { + void merge(const ComputedShaderFlags CSF) { #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ - FlagName |= (IVal & getMask(DxilModuleBit)); -#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ - FlagName |= (IVal & getMask(DxilModuleBit)); + FlagName |= CSF.FlagName; +#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName; #include "llvm/BinaryFormat/DXContainerConstants.def" - return; } void print(raw_ostream &OS = dbgs()) const; @@ -85,17 +83,19 @@ struct ComputedShaderFlags { }; struct ModuleShaderFlags { - void initialize(const Module &, DXILResourceTypeMap &DRTM); + void initialize(Module &, DXILResourceTypeMap &DRTM); const ComputedShaderFlags &getFunctionFlags(const Function *) const; const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; } private: - /// Vector of sorted Function-Shader Flag mask pairs representing properties - /// of each of the functions in the module. Shader Flags of each function - /// represent both module-level and function-level flags - SmallVector> FunctionFlags; + /// Map of Function-Shader Flag Mask pairs representing properties of each of + /// the functions in the module. Shader Flags of each function represent both + /// module-level and function-level flags + DenseMap FunctionFlags; /// Combined Shader Flag Mask of all functions of the module ComputedShaderFlags CombinedSFMask{}; + void updateFunctionFlags(ComputedShaderFlags &, const Instruction &, + DXILResourceTypeMap &); }; class ShaderFlagsAnalysis : public AnalysisInfoMixin { diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll index 6332ef806a0d8..d6df67626be5a 100644 --- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll @@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library" ; CHECK-NEXT: ; ; CHECK-NEXT: ; Shader Flags for Module Functions +;CHECK: ; Function top_level : 0x00000044 +define double @top_level() #0 { + %r = call double @test_uitofp_i64(i64 5) + ret double %r +} + + ; CHECK: ; Function test_fdiv_double : 0x00000044 define double @test_fdiv_double(double %a, double %b) #0 { %res = fdiv double %a, %b diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll new file mode 100644 index 0000000000000..e7a2cf4d5b20f --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll @@ -0,0 +1,167 @@ +; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.7-library" + +; CHECK: ; Combined Shader Flags for Module +; CHECK-NEXT: ; Shader Flags Value: 0x00000044 +; CHECK-NEXT: ; +; CHECK-NEXT: ; Note: shader requires additional functionality: +; CHECK-NEXT: ; Double-precision floating point +; CHECK-NEXT: ; Double-precision extensions for 11.1 +; CHECK-NEXT: ; Note: extra DXIL module flags: +; CHECK-NEXT: ; +; CHECK-NEXT: ; Shader Flags for Module Functions + +; Call Graph of test source +; main -> [get_fptoui_flag, get_sitofp_fdiv_flag] +; get_fptoui_flag -> [get_sitofp_uitofp_flag, call_get_uitofp_flag] +; get_sitofp_uitofp_flag -> [call_get_fptoui_flag, call_get_sitofp_flag] +; call_get_fptoui_flag -> [get_fptoui_flag] +; get_sitofp_fdiv_flag -> [get_no_flags, get_all_doubles_flags] +; get_all_doubles_flags -> [call_get_sitofp_fdiv_flag] +; call_get_sitofp_fdiv_flag -> [get_sitofp_fdiv_flag] +; call_get_sitofp_flag -> [get_sitofp_flag] +; call_get_uitofp_flag -> [get_uitofp_flag] +; get_sitofp_flag -> [] +; get_uitofp_flag -> [] +; get_no_flags -> [] +; +; Strongly Connected Component in the CG +; [get_fptoui_flag, get_sitofp_uitofp_flag, call_get_fptoui_flag] +; [get_sitofp_fdiv_flag, get_all_doubles_flags, call_get_sitofp_fdiv_flag] + +; +; CHECK: ; Function get_sitofp_flag : 0x00000044 +define double @get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 { + %2 = sitofp i32 %0 to double + ret double %2 +} + +; CHECK: ; Function call_get_sitofp_flag : 0x00000044 +define double @call_get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 { + %2 = tail call double @get_sitofp_flag(i32 noundef %0) + ret double %2 +} + +; CHECK: ; Function get_uitofp_flag : 0x00000044 +define double @get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 { + %2 = uitofp i32 %0 to double + ret double %2 +} + +; CHECK: ; Function call_get_uitofp_flag : 0x00000044 +define double @call_get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 { + %2 = tail call double @get_uitofp_flag(i32 noundef %0) + ret double %2 +} + +; CHECK: ; Function call_get_fptoui_flag : 0x00000044 +define double @call_get_fptoui_flag(double noundef %0) local_unnamed_addr #0 { + %2 = tail call double @get_fptoui_flag(double noundef %0) + ret double %2 +} + +; CHECK: ; Function get_fptoui_flag : 0x00000044 +define double @get_fptoui_flag(double noundef %0) local_unnamed_addr #0 { + %2 = fcmp ugt double %0, 5.000000e+00 + br i1 %2, label %6, label %3 + +3: ; preds = %1 + %4 = fptoui double %0 to i64 + %5 = tail call double @get_sitofp_uitofp_flag(i64 noundef %4) + br label %9 + +6: ; preds = %1 + %7 = fptoui double %0 to i32 + %8 = tail call double @call_get_uitofp_flag(i32 noundef %7) + br label %9 + +9: ; preds = %6, %3 + %10 = phi double [ %5, %3 ], [ %8, %6 ] + ret double %10 +} + +; CHECK: ; Function get_sitofp_uitofp_flag : 0x00000044 +define double @get_sitofp_uitofp_flag(i64 noundef %0) local_unnamed_addr #0 { + %2 = icmp ult i64 %0, 6 + br i1 %2, label %3, label %7 + +3: ; preds = %1 + %4 = add nuw nsw i64 %0, 1 + %5 = uitofp i64 %4 to double + %6 = tail call double @call_get_fptoui_flag(double noundef %5) + br label %10 + +7: ; preds = %1 + %8 = trunc i64 %0 to i32 + %9 = tail call double @call_get_sitofp_flag(i32 noundef %8) + br label %10 + +10: ; preds = %7, %3 + %11 = phi double [ %6, %3 ], [ %9, %7 ] + ret double %11 +} + +; CHECK: ; Function get_no_flags : 0x00000000 +define i32 @get_no_flags(i32 noundef %0) local_unnamed_addr #0 { + %2 = mul nsw i32 %0, %0 + ret i32 %2 +} + +; CHECK: ; Function call_get_sitofp_fdiv_flag : 0x00000044 +define i32 @call_get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 { + %2 = icmp eq i32 %0, 0 + br i1 %2, label %5, label %3 + +3: ; preds = %1 + %4 = mul nsw i32 %0, %0 + br label %7 + +5: ; preds = %1 + %6 = tail call double @get_sitofp_fdiv_flag(i32 noundef 0) + br label %7 + +7: ; preds = %5, %3 + %8 = phi i32 [ %4, %3 ], [ 0, %5 ] + ret i32 %8 +} + +; CHECK: ; Function get_sitofp_fdiv_flag : 0x00000044 +define double @get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 { + %2 = icmp sgt i32 %0, 5 + br i1 %2, label %3, label %6 + +3: ; preds = %1 + %4 = tail call i32 @get_no_flags(i32 noundef %0) + %5 = sitofp i32 %4 to double + br label %9 + +6: ; preds = %1 + %7 = tail call double @get_all_doubles_flags(i32 noundef %0) + %8 = fdiv double %7, 3.000000e+00 + br label %9 + +9: ; preds = %6, %3 + %10 = phi double [ %5, %3 ], [ %8, %6 ] + ret double %10 +} + +; CHECK: ; Function get_all_doubles_flags : 0x00000044 +define double @get_all_doubles_flags(i32 noundef %0) local_unnamed_addr #0 { + %2 = tail call i32 @call_get_sitofp_fdiv_flag(i32 noundef %0) + %3 = icmp eq i32 %2, 0 + %4 = select i1 %3, double 1.000000e+01, double 1.000000e+02 + ret double %4 +} + +; CHECK: ; Function main : 0x00000044 +define i32 @main() local_unnamed_addr #0 { + %1 = tail call double @get_fptoui_flag(double noundef 1.000000e+00) + %2 = tail call double @get_sitofp_fdiv_flag(i32 noundef 4) + %3 = fadd double %1, %2 + %4 = fcmp ogt double %3, 0.000000e+00 + %5 = zext i1 %4 to i32 + ret i32 %5 +} + +attributes #0 = { convergent norecurse nounwind "hlsl.export"} From e53cd26d830fe57c9760ad18144760265d82c72a Mon Sep 17 00:00:00 2001 From: "S. Bharadwaj Yadavalli" Date: Tue, 14 Jan 2025 11:10:00 -0500 Subject: [PATCH 2/3] Delete unnecessary #include --- llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index e956189f8ecd4..4bcc01a90b170 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -14,7 +14,6 @@ #include "DXILShaderFlags.h" #include "DirectX.h" #include "llvm/ADT/SCCIterator.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/DXILResource.h" From c1e134321ada333d908fb1f9c20b201d5b8798d6 Mon Sep 17 00:00:00 2001 From: "S. Bharadwaj Yadavalli" Date: Tue, 14 Jan 2025 12:10:47 -0500 Subject: [PATCH 3/3] Delete braces around single-statement if expression --- llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 4bcc01a90b170..b1ff975d4dae9 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -76,9 +76,9 @@ void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, if (auto *CI = dyn_cast(&I)) { const Function *CF = CI->getCalledFunction(); // Merge-in shader flags mask of the called function in the current module - if (FunctionFlags.contains(CF)) { + if (FunctionFlags.contains(CF)) CSF.merge(FunctionFlags[CF]); - } + // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 }