Skip to content

Commit 108a65e

Browse files
chellmuthlgritz
authored andcommitted
API for renderer to cache OptiX ptx (#1938)
This change adds a caching layer to the OptiX PTX pipeline that can skip LLVM generation and optimization. This helps improve scene build times for re-renders of scenes with large shader counts. We generate a hash key from the optimized shadergroup, and depend on the renderer to provide a cache backend implementation. There is no overhead if the renderer doesn't explicitly opt in to ptx caching. A simple backend will be added to testshade/testrender in a follow-up PR. As background information: The sequence at runtime is group oso ->1-> runtime optimize by liboslexec ->2-> JIT to PTX via LLVM ->3-> driver converts PTX to actual executable GPU code on that hardware. The "OptiX Cache" (part of OptiX & driver) speeds up step (3) by skipping that last step for optimized/JITed shaders it has encountered before. This PR adds another cache to step (2), managed by OSL and/or the renderer internals, to allow you to skip the bulk of the work for that step for optimized shaders you've encountered already. You still pay full price the very first time a shader is encountered, and that leads to terrible TTFP (time to first pixel). But this should take a big bite out of that in practice, since it's very common to have encountered most shader configurations before. If the implementation the renderer provides stores the cache persistently on disk or in a real database, it will be shared from run to run and possibly from user to user. --------- Signed-off-by: Chris Hellmuth <chellmuth@gmail.com>
1 parent c34f6b0 commit 108a65e

File tree

10 files changed

+160
-37
lines changed

10 files changed

+160
-37
lines changed

src/include/OSL/rendererservices.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,19 @@ class OSLEXECPUBLIC RendererServices {
595595
}
596596
};
597597

598+
// Default no-op implementations of the caching api.
599+
// Currently used for caching optix ptx before llvm generation.
600+
/// Offer `value` to the renderer for storage under `key` in the cache
/// named `cachename`. The base-class version deliberately does nothing,
/// so a renderer that never overrides it simply discards the entry.
virtual void cache_insert(string_view cachename, string_view key,
                          string_view value) const
{
    // Intentionally empty: caching is opt-in for the renderer.
}
604+
605+
/// Look up `key` in the cache named `cachename`, writing the stored
/// entry into `value` when one is found. The base-class version never
/// touches `value` and always returns false, so every lookup behaves as
/// a miss unless the renderer overrides this method.
virtual bool cache_get(string_view cachename, string_view key,
                       std::string& value) const
{
    const bool found = false;  // default backend stores nothing
    return found;
}
610+
598611
/// A renderer may choose to support batched execution by providing pointers
599612
/// to objects satisfying the BatchedRendererServices<WidthOf<#>> interface
600613
/// for specific batch sizes.

src/liboslexec/backendllvm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ class BackendLLVM final : public OSOProcessorBase {
500500

501501
/// Return whether or not we are compiling for an OptiX-based renderer.
502502
bool use_optix() { return m_use_optix; }
503+
bool use_optix_cache() { return shadingsys().use_optix_cache(); }
503504

504505
/// Return if we should compile against free function versions of Renderer Service.
505506
bool use_rs_bitcode() { return m_use_rs_bitcode; }

src/liboslexec/instance.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,25 @@ ShaderGroup::setup_interactive_arena(cspan<uint8_t> paramblock)
848848

849849

850850

851+
void
852+
ShaderGroup::generate_optix_cache_key(string_view code)
853+
{
854+
const uint64_t ir_key = Strutil::strhash(code);
855+
856+
std::string safegroup;
857+
safegroup = Strutil::replace(name(), "/", "_", true);
858+
safegroup = Strutil::replace(safegroup, ":", "_", true);
859+
860+
// Cache key includes the groupname in addition to the serialized IR.
861+
// This is because the groupname makes its way into the ptx's direct callable name,
862+
// but isn't included in the serialization.
863+
std::string cache_key = fmtformat("cache-osl-ptx-{}-{}", safegroup, ir_key);
864+
865+
m_optix_cache_key = cache_key;
866+
}
867+
868+
869+
851870
std::string
852871
ShaderGroup::serialize() const
853872
{

src/liboslexec/llvm_gen.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,13 @@ LLVMGEN(llvm_gen_printf_legacy)
426426
}
427427
#endif
428428

429-
// Some ops prepend things
430-
if (op.opname() == op_error || op.opname() == op_warning) {
431-
s = fmtformat("Shader {} [{}]: {}", op.opname(),
432-
rop.inst()->shadername(), s);
429+
// TODO: optix cache should handle ustrings generated during llvm-gen
430+
if (!rop.use_optix_cache()) {
431+
// Some ops prepend things
432+
if (op.opname() == op_error || op.opname() == op_warning) {
433+
s = fmtformat("Shader {} [{}]: {}", op.opname(),
434+
rop.inst()->shadername(), s);
435+
}
433436
}
434437

435438
// Now go back and put the new format string in its place
@@ -709,10 +712,12 @@ LLVMGEN(llvm_gen_print_fmt)
709712
}
710713
}
711714
}
712-
// Some ops prepend things
713-
if (op.opname() == op_error || op.opname() == op_warning) {
714-
s = fmtformat("Shader {} [{}]: {}", op.opname(),
715-
rop.inst()->shadername(), s);
715+
if (!rop.use_optix_cache()) {
716+
// Some ops prepend things
717+
if (op.opname() == op_error || op.opname() == op_warning) {
718+
s = fmtformat("Shader {} [{}]: {}", op.opname(),
719+
rop.inst()->shadername(), s);
720+
}
716721
}
717722
ustring s_ustring(s.c_str());
718723
call_args.push_back(rop.llvm_const_hash(s_ustring));

src/liboslexec/llvm_instance.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,6 +2281,14 @@ BackendLLVM::run()
22812281
group().llvm_compiled_layer(nlayers - 1));
22822282
}
22832283

2284+
if (shadingsys().use_optix_cache()) {
2285+
std::string cache_key = group().optix_cache_key();
2286+
renderer()->cache_insert(
2287+
"optix_ptx", cache_key,
2288+
optix_cache_wrap(group().m_llvm_ptx_compiled_version,
2289+
group().llvm_groupdata_size()));
2290+
}
2291+
22842292
// We are destroying the entire module below,
22852293
// no reason to bother destroying individual functions
22862294
#if 0

src/liboslexec/oslexec.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,32 @@ shadertype_from_name(string_view name)
5151
}
5252

5353

54+
55+
std::string
56+
optix_cache_wrap(string_view ptx, size_t groupdata_size)
57+
{
58+
// Cache string is the ptx file with groupdata size on top as a comment.
59+
// This way the cache string is a valid ptx program, which can be useful
60+
// for debugging.
61+
return fmtformat("// {}\n{}", groupdata_size, ptx);
62+
}
63+
64+
65+
66+
/// Split a cache entry of the form "// <groupdata_size>\n<ptx>" (the
/// layout written by optix_cache_wrap) back into its two parts.
///
/// @param cache_value     The full cached string.
/// @param ptx             Receives everything after the first newline.
/// @param groupdata_size  Receives the decimal size from the header line.
///
/// Both outputs are left untouched when the entry is malformed: no
/// newline, a header that does not start with "// ", a size that is empty
/// or contains non-digit characters, or a size that does not fit in
/// size_t. Cache contents come from a renderer-supplied backend, so a
/// corrupt entry must not throw or read out of range here.
void
optix_cache_unwrap(string_view cache_value, std::string& ptx,
                   size_t& groupdata_size)
{
    constexpr size_t prefix_len = 3;  // length of the "// " prefix

    const size_t header_end = cache_value.find('\n');
    if (header_end == std::string::npos)
        return;  // no header line at all

    // The header needs the full prefix plus at least one digit. Checking
    // this first also prevents `header_end - prefix_len` style underflow.
    if (header_end <= prefix_len || cache_value[0] != '/'
        || cache_value[1] != '/' || cache_value[2] != ' ')
        return;

    // Parse the size by hand so that bad input can neither throw nor wrap
    // silently into a bogus size_t.
    constexpr size_t size_max = static_cast<size_t>(-1);
    size_t parsed            = 0;
    for (size_t i = prefix_len; i < header_end; ++i) {
        const char c = cache_value[i];
        if (c < '0' || c > '9')
            return;  // not a plain decimal number
        const size_t digit = static_cast<size_t>(c - '0');
        if (parsed > (size_max - digit) / 10)
            return;  // would overflow size_t
        parsed = parsed * 10 + digit;
    }

    // Only commit the results once the whole header has been validated.
    groupdata_size         = parsed;
    const string_view body = cache_value.substr(header_end + 1);
    ptx.assign(body.data(), body.size());
}
80+
5481
}; // namespace pvt
5582
OSL_NAMESPACE_EXIT

src/liboslexec/oslexec_pvt.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ struct PerThreadInfo {
7878

7979
namespace pvt {
8080

81+
void
82+
optix_cache_unwrap(string_view cache_value, std::string& ptx,
83+
size_t& groupdata_size);
84+
std::string
85+
optix_cache_wrap(string_view ptx, size_t groupdata_size);
86+
8187
// forward definitions
8288
class ShadingSystemImpl;
8389
class ShaderInstance;
@@ -627,6 +633,7 @@ class ShadingSystemImpl {
627633
TextureSystem* texturesys() const { return m_texturesys; }
628634

629635
bool use_optix() const { return m_use_optix; }
636+
bool use_optix_cache() const { return m_use_optix_cache; }
630637
bool debug_nan() const { return m_debugnan; }
631638
bool debug_uninit() const { return m_debug_uninit; }
632639
bool lockgeom_default() const { return m_lockgeom_default; }
@@ -940,9 +947,10 @@ class ShadingSystemImpl {
940947
std::vector<ustring> m_raytypes; ///< Names of ray types
941948
std::vector<ustring> m_renderer_outputs; ///< Names of renderer outputs
942949
std::vector<SymLocationDesc> m_symlocs;
943-
int m_max_local_mem_KB; ///< Local storage can a shader use
944-
int m_compile_report; ///< Print compilation report?
945-
bool m_use_optix; ///< This is an OptiX-based renderer
950+
int m_max_local_mem_KB; ///< Local storage can a shader use
951+
int m_compile_report; ///< Print compilation report?
952+
bool m_use_optix; ///< This is an OptiX-based renderer
953+
bool m_use_optix_cache; ///< Renderer-enabled caching for OptiX ptx
946954
int m_max_optix_groupdata_alloc; ///< Maximum OptiX groupdata buffer allocation
947955
bool m_buffer_printf; ///< Buffer/batch printf output?
948956
bool m_no_noise; ///< Substitute trivial noise calls
@@ -1829,6 +1837,10 @@ class ShaderGroup {
18291837
void name(ustring name) { m_name = name; }
18301838
ustring name() const { return m_name; }
18311839

1840+
// Generate and memoize the cache key so we don't calculate it twice
1841+
void generate_optix_cache_key(string_view code);
1842+
std::string optix_cache_key() const { return m_optix_cache_key; }
1843+
18321844
std::string serialize() const;
18331845

18341846
void lock() const { m_mutex.lock(); }
@@ -2016,6 +2028,8 @@ class ShaderGroup {
20162028
atomic_ll m_executions { 0 }; ///< Number of times the group executed
20172029
atomic_ll m_stat_total_shading_time_ticks { 0 }; // Shading time (ticks)
20182030

2031+
std::string m_optix_cache_key;
2032+
20192033
// PTX assembly for compiled ShaderGroup
20202034
std::string m_llvm_ptx_compiled_version;
20212035

src/liboslexec/runtimeoptimize.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3116,6 +3116,21 @@ RuntimeOptimizer::printinst(std::ostream& out) const
31163116

31173117

31183118

3119+
std::string
3120+
RuntimeOptimizer::serialize()
3121+
{
3122+
std::ostringstream ss {};
3123+
int nlayers = (int)group().nlayers();
3124+
for (int layer = 0; layer < nlayers; ++layer) {
3125+
set_inst(layer);
3126+
printinst(ss);
3127+
}
3128+
3129+
return ss.str();
3130+
}
3131+
3132+
3133+
31193134
void
31203135
RuntimeOptimizer::run()
31213136
{

src/liboslexec/runtimeoptimize.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,8 @@ class RuntimeOptimizer final : public OSOProcessorBase {
460460
fmtformat(fmt, std::forward<Args>(args)...));
461461
}
462462

463+
std::string serialize();
464+
463465
private:
464466
int m_optimize; ///< Current optimization level
465467
bool m_opt_simplify_param; ///< Turn instance params into const?

src/liboslexec/shadingsys.cpp

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,7 @@ ShadingSystemImpl::ShadingSystemImpl(RendererServices* renderer,
11291129
, m_max_local_mem_KB(2048)
11301130
, m_compile_report(0)
11311131
, m_use_optix(renderer->supports("OptiX"))
1132+
, m_use_optix_cache(m_use_optix && renderer->supports("optix_ptx_cache"))
11321133
, m_max_optix_groupdata_alloc(0)
11331134
, m_buffer_printf(true)
11341135
, m_no_noise(false)
@@ -3812,6 +3813,9 @@ ShadingSystemImpl::optimize_group(ShaderGroup& group, ShadingContext* ctx,
38123813
}
38133814
group.m_optimized = true;
38143815

3816+
if (use_optix_cache())
3817+
group.generate_optix_cache_key(rop.serialize());
3818+
38153819
spin_lock stat_lock(m_stat_mutex);
38163820
if (!need_jit) {
38173821
m_stat_opt_locking_time += locking_time;
@@ -3823,34 +3827,49 @@ ShadingSystemImpl::optimize_group(ShaderGroup& group, ShadingContext* ctx,
38233827
}
38243828

38253829
if (need_jit) {
3826-
BackendLLVM lljitter(*this, group, ctx);
3827-
lljitter.run();
3828-
3829-
// NOTE: it is now possible to optimize and not JIT
3830-
// which would leave the cleanup to happen
3831-
// when the ShadingSystem is destroyed
3832-
3833-
// Only cleanup when are not batching or if
3834-
// the batch jit has already happened,
3835-
// as it requires the ops so we can't delete them yet!
3836-
if (((renderer()->batched(WidthOf<16>()) == nullptr)
3837-
&& (renderer()->batched(WidthOf<8>()) == nullptr)
3838-
&& (renderer()->batched(WidthOf<4>()) == nullptr))
3839-
|| group.batch_jitted()) {
3840-
group_post_jit_cleanup(group);
3830+
bool cached = false;
3831+
if (use_optix_cache()) {
3832+
std::string cache_key = group.optix_cache_key();
3833+
3834+
std::string cache_value;
3835+
if (renderer()->cache_get("optix_ptx", cache_key, cache_value)) {
3836+
cached = true;
3837+
optix_cache_unwrap(cache_value,
3838+
group.m_llvm_ptx_compiled_version,
3839+
group.m_llvm_groupdata_size);
3840+
}
38413841
}
38423842

3843-
group.m_jitted = true;
3844-
spin_lock stat_lock(m_stat_mutex);
3845-
m_stat_opt_locking_time += locking_time;
3846-
m_stat_optimization_time += timer();
3847-
m_stat_total_llvm_time += lljitter.m_stat_total_llvm_time;
3848-
m_stat_llvm_setup_time += lljitter.m_stat_llvm_setup_time;
3849-
m_stat_llvm_irgen_time += lljitter.m_stat_llvm_irgen_time;
3850-
m_stat_llvm_opt_time += lljitter.m_stat_llvm_opt_time;
3851-
m_stat_llvm_jit_time += lljitter.m_stat_llvm_jit_time;
3852-
m_stat_max_llvm_local_mem = std::max(m_stat_max_llvm_local_mem,
3853-
lljitter.m_llvm_local_mem);
3843+
if (!cached) {
3844+
BackendLLVM lljitter(*this, group, ctx);
3845+
lljitter.run();
3846+
3847+
// NOTE: it is now possible to optimize and not JIT
3848+
// which would leave the cleanup to happen
3849+
// when the ShadingSystem is destroyed
3850+
3851+
// Only cleanup when are not batching or if
3852+
// the batch jit has already happened,
3853+
// as it requires the ops so we can't delete them yet!
3854+
if (((renderer()->batched(WidthOf<16>()) == nullptr)
3855+
&& (renderer()->batched(WidthOf<8>()) == nullptr)
3856+
&& (renderer()->batched(WidthOf<4>()) == nullptr))
3857+
|| group.batch_jitted()) {
3858+
group_post_jit_cleanup(group);
3859+
}
3860+
3861+
group.m_jitted = true;
3862+
spin_lock stat_lock(m_stat_mutex);
3863+
m_stat_opt_locking_time += locking_time;
3864+
m_stat_optimization_time += timer();
3865+
m_stat_total_llvm_time += lljitter.m_stat_total_llvm_time;
3866+
m_stat_llvm_setup_time += lljitter.m_stat_llvm_setup_time;
3867+
m_stat_llvm_irgen_time += lljitter.m_stat_llvm_irgen_time;
3868+
m_stat_llvm_opt_time += lljitter.m_stat_llvm_opt_time;
3869+
m_stat_llvm_jit_time += lljitter.m_stat_llvm_jit_time;
3870+
m_stat_max_llvm_local_mem = std::max(m_stat_max_llvm_local_mem,
3871+
lljitter.m_llvm_local_mem);
3872+
}
38543873
}
38553874

38563875
if (ctx_allocated) {

0 commit comments

Comments
 (0)