Skip to content

Commit d0d4a06

Browse files
authored
API for renderer to cache OptiX ptx (#1938)
This change adds a caching layer to the OptiX PTX pipeline that can skip LLVM generation and optimization. This helps improve scene build times for re-renders of scenes with large shader counts. We generate a hash key from the optimized shadergroup, and depend on the renderer to provide a cache backend implementation. There is no overhead if the renderer doesn't explicitly opt-in to ptx caching. A simple backend will be added to testshade/testrender in a follow-up PR. As background information: The sequence at runtime is group oso ->1-> runtime optimize by liboslexec ->2-> JIT to PTX via LLVM ->3-> driver converts PTX to actual executable GPU code on that hardware. The "OptiX Cache" (part of OptiX & driver) speed up step (3) by not having the last step for optimized/JITed shaders it's encountered before. This PR adds another cache to step (2), managed by OSL and/or the renderer internals, to allow you to skip the bulk of the work for that step for optimized shaders you've encountered already. You still pay full price the very first time a shader is encountered, and that leads to terrible TTFP (time to first pixel). But this should take a big bite out of that in practice since it's very common to have encountered most shader configuration before. If the implementation the renderer provides is to store the cache persistently on disk or in a real database, it will be shared from run to run and possibly from user to user. --------- Signed-off-by: Chris Hellmuth <[email protected]>
1 parent 4ce094f commit d0d4a06

File tree

10 files changed

+160
-37
lines changed

10 files changed

+160
-37
lines changed

src/include/OSL/rendererservices.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,19 @@ class OSLEXECPUBLIC RendererServices {
580580
}
581581
};
582582

583+
// Default no-op implementations of the caching api.
584+
// Currently used for caching optix ptx before llvm generation.
585+
virtual void cache_insert(string_view cachename, string_view key,
586+
string_view value) const
587+
{
588+
}
589+
590+
virtual bool cache_get(string_view cachename, string_view key,
591+
std::string& value) const
592+
{
593+
return false;
594+
}
595+
583596
/// A renderer may choose to support batched execution by providing pointers
584597
/// to objects satisfying the BatchedRendererServices<WidthOf<#>> interface
585598
/// for specific batch sizes.

src/liboslexec/backendllvm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ class BackendLLVM final : public OSOProcessorBase {
524524

525525
/// Return whether or not we are compiling for an OptiX-based renderer.
526526
bool use_optix() { return m_use_optix; }
527+
bool use_optix_cache() { return shadingsys().use_optix_cache(); }
527528

528529
/// Return if we should compile against free function versions of Renderer Service.
529530
bool use_rs_bitcode() { return m_use_rs_bitcode; }

src/liboslexec/instance.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,25 @@ ShaderGroup::setup_interactive_arena(cspan<uint8_t> paramblock)
848848

849849

850850

851+
void
852+
ShaderGroup::generate_optix_cache_key(string_view code)
853+
{
854+
const uint64_t ir_key = Strutil::strhash(code);
855+
856+
std::string safegroup;
857+
safegroup = Strutil::replace(name(), "/", "_", true);
858+
safegroup = Strutil::replace(safegroup, ":", "_", true);
859+
860+
// Cache key includes the groupname in addition to the serialized IR.
861+
// This is because the groupname makes its way into the ptx's direct callable name,
862+
// but isn't included in the serialization.
863+
std::string cache_key = fmtformat("cache-osl-ptx-{}-{}", safegroup, ir_key);
864+
865+
m_optix_cache_key = cache_key;
866+
}
867+
868+
869+
851870
std::string
852871
ShaderGroup::serialize() const
853872
{

src/liboslexec/llvm_gen.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,13 @@ LLVMGEN(llvm_gen_printf_legacy)
426426
}
427427
#endif
428428

429-
// Some ops prepend things
430-
if (op.opname() == op_error || op.opname() == op_warning) {
431-
s = fmtformat("Shader {} [{}]: {}", op.opname(),
432-
rop.inst()->shadername(), s);
429+
// TODO: optix cache should handle ustrings generated during llvm-gen
430+
if (!rop.use_optix_cache()) {
431+
// Some ops prepend things
432+
if (op.opname() == op_error || op.opname() == op_warning) {
433+
s = fmtformat("Shader {} [{}]: {}", op.opname(),
434+
rop.inst()->shadername(), s);
435+
}
433436
}
434437

435438
// Now go back and put the new format string in its place
@@ -709,10 +712,12 @@ LLVMGEN(llvm_gen_print_fmt)
709712
}
710713
}
711714
}
712-
// Some ops prepend things
713-
if (op.opname() == op_error || op.opname() == op_warning) {
714-
s = fmtformat("Shader {} [{}]: {}", op.opname(),
715-
rop.inst()->shadername(), s);
715+
if (!rop.use_optix_cache()) {
716+
// Some ops prepend things
717+
if (op.opname() == op_error || op.opname() == op_warning) {
718+
s = fmtformat("Shader {} [{}]: {}", op.opname(),
719+
rop.inst()->shadername(), s);
720+
}
716721
}
717722
ustring s_ustring(s.c_str());
718723
call_args.push_back(rop.llvm_const_hash(s_ustring));

src/liboslexec/llvm_instance.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,6 +2481,14 @@ BackendLLVM::run()
24812481
group().llvm_compiled_layer(nlayers - 1));
24822482
}
24832483

2484+
if (shadingsys().use_optix_cache()) {
2485+
std::string cache_key = group().optix_cache_key();
2486+
renderer()->cache_insert(
2487+
"optix_ptx", cache_key,
2488+
optix_cache_wrap(group().m_llvm_ptx_compiled_version,
2489+
group().llvm_groupdata_size()));
2490+
}
2491+
24842492
// We are destroying the entire module below,
24852493
// no reason to bother destroying individual functions
24862494
#if 0

src/liboslexec/oslexec.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,32 @@ shadertype_from_name(string_view name)
5151
}
5252

5353

54+
55+
std::string
56+
optix_cache_wrap(string_view ptx, size_t groupdata_size)
57+
{
58+
// Cache string is the ptx file with groupdata size on top as a comment.
59+
// This way the cache string is a valid ptx program, which can be useful
60+
// for debugging.
61+
return fmtformat("// {}\n{}", groupdata_size, ptx);
62+
}
63+
64+
65+
66+
void
67+
optix_cache_unwrap(string_view cache_value, std::string& ptx,
68+
size_t& groupdata_size)
69+
{
70+
size_t groupdata_end_index = cache_value.find('\n');
71+
if (groupdata_end_index != std::string::npos) {
72+
constexpr int offset = 3; // Account for the "// " prefix
73+
std::string groupdata_string
74+
= cache_value.substr(offset, groupdata_end_index - offset);
75+
groupdata_size = std::stoll(groupdata_string);
76+
77+
ptx = cache_value.substr(groupdata_end_index + 1);
78+
}
79+
}
80+
5481
}; // namespace pvt
5582
OSL_NAMESPACE_END

src/liboslexec/oslexec_pvt.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ struct PerThreadInfo {
8181

8282
namespace pvt {
8383

84+
void
85+
optix_cache_unwrap(string_view cache_value, std::string& ptx,
86+
size_t& groupdata_size);
87+
std::string
88+
optix_cache_wrap(string_view ptx, size_t groupdata_size);
89+
8490
// forward definitions
8591
class ShadingSystemImpl;
8692
class ShaderInstance;
@@ -632,6 +638,7 @@ class ShadingSystemImpl {
632638
TextureSystem* texturesys() const { return m_texturesys; }
633639

634640
bool use_optix() const { return m_use_optix; }
641+
bool use_optix_cache() const { return m_use_optix_cache; }
635642
bool debug_nan() const { return m_debugnan; }
636643
bool debug_uninit() const { return m_debug_uninit; }
637644
bool lockgeom_default() const { return m_lockgeom_default; }
@@ -954,9 +961,10 @@ class ShadingSystemImpl {
954961
std::vector<ustring> m_raytypes; ///< Names of ray types
955962
std::vector<ustring> m_renderer_outputs; ///< Names of renderer outputs
956963
std::vector<SymLocationDesc> m_symlocs;
957-
int m_max_local_mem_KB; ///< Local storage can a shader use
958-
int m_compile_report; ///< Print compilation report?
959-
bool m_use_optix; ///< This is an OptiX-based renderer
964+
int m_max_local_mem_KB; ///< Local storage can a shader use
965+
int m_compile_report; ///< Print compilation report?
966+
bool m_use_optix; ///< This is an OptiX-based renderer
967+
bool m_use_optix_cache; ///< Renderer-enabled caching for OptiX ptx
960968
int m_max_optix_groupdata_alloc; ///< Maximum OptiX groupdata buffer allocation
961969
bool m_buffer_printf; ///< Buffer/batch printf output?
962970
bool m_no_noise; ///< Substitute trivial noise calls
@@ -1843,6 +1851,10 @@ class ShaderGroup {
18431851
void name(ustring name) { m_name = name; }
18441852
ustring name() const { return m_name; }
18451853

1854+
// Generate and memoize the cache key so we don't calculate it twice
1855+
void generate_optix_cache_key(string_view code);
1856+
std::string optix_cache_key() const { return m_optix_cache_key; }
1857+
18461858
std::string serialize() const;
18471859

18481860
void lock() const { m_mutex.lock(); }
@@ -2046,6 +2058,8 @@ class ShaderGroup {
20462058
atomic_ll m_executions { 0 }; ///< Number of times the group executed
20472059
atomic_ll m_stat_total_shading_time_ticks { 0 }; // Shading time (ticks)
20482060

2061+
std::string m_optix_cache_key;
2062+
20492063
// PTX assembly for compiled ShaderGroup
20502064
std::string m_llvm_ptx_compiled_version;
20512065

src/liboslexec/runtimeoptimize.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3124,6 +3124,21 @@ RuntimeOptimizer::printinst(std::ostream& out) const
31243124

31253125

31263126

3127+
std::string
3128+
RuntimeOptimizer::serialize()
3129+
{
3130+
std::ostringstream ss {};
3131+
int nlayers = (int)group().nlayers();
3132+
for (int layer = 0; layer < nlayers; ++layer) {
3133+
set_inst(layer);
3134+
printinst(ss);
3135+
}
3136+
3137+
return ss.str();
3138+
}
3139+
3140+
3141+
31273142
void
31283143
RuntimeOptimizer::run()
31293144
{

src/liboslexec/runtimeoptimize.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,8 @@ class RuntimeOptimizer final : public OSOProcessorBase {
460460
fmtformat(fmt, std::forward<Args>(args)...));
461461
}
462462

463+
std::string serialize();
464+
463465
private:
464466
int m_optimize; ///< Current optimization level
465467
bool m_opt_simplify_param; ///< Turn instance params into const?

src/liboslexec/shadingsys.cpp

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,7 @@ ShadingSystemImpl::ShadingSystemImpl(RendererServices* renderer,
11291129
, m_max_local_mem_KB(2048)
11301130
, m_compile_report(0)
11311131
, m_use_optix(renderer->supports("OptiX"))
1132+
, m_use_optix_cache(m_use_optix && renderer->supports("optix_ptx_cache"))
11321133
, m_max_optix_groupdata_alloc(0)
11331134
, m_buffer_printf(true)
11341135
, m_no_noise(false)
@@ -3801,6 +3802,9 @@ ShadingSystemImpl::optimize_group(ShaderGroup& group, ShadingContext* ctx,
38013802
}
38023803
group.m_optimized = true;
38033804

3805+
if (use_optix_cache())
3806+
group.generate_optix_cache_key(rop.serialize());
3807+
38043808
spin_lock stat_lock(m_stat_mutex);
38053809
if (!need_jit) {
38063810
m_stat_opt_locking_time += locking_time;
@@ -3812,34 +3816,49 @@ ShadingSystemImpl::optimize_group(ShaderGroup& group, ShadingContext* ctx,
38123816
}
38133817

38143818
if (need_jit) {
3815-
BackendLLVM lljitter(*this, group, ctx);
3816-
lljitter.run();
3817-
3818-
// NOTE: it is now possible to optimize and not JIT
3819-
// which would leave the cleanup to happen
3820-
// when the ShadingSystem is destroyed
3821-
3822-
// Only cleanup when are not batching or if
3823-
// the batch jit has already happened,
3824-
// as it requires the ops so we can't delete them yet!
3825-
if (((renderer()->batched(WidthOf<16>()) == nullptr)
3826-
&& (renderer()->batched(WidthOf<8>()) == nullptr)
3827-
&& (renderer()->batched(WidthOf<4>()) == nullptr))
3828-
|| group.batch_jitted()) {
3829-
group_post_jit_cleanup(group);
3819+
bool cached = false;
3820+
if (use_optix_cache()) {
3821+
std::string cache_key = group.optix_cache_key();
3822+
3823+
std::string cache_value;
3824+
if (renderer()->cache_get("optix_ptx", cache_key, cache_value)) {
3825+
cached = true;
3826+
optix_cache_unwrap(cache_value,
3827+
group.m_llvm_ptx_compiled_version,
3828+
group.m_llvm_groupdata_size);
3829+
}
38303830
}
38313831

3832-
group.m_jitted = true;
3833-
spin_lock stat_lock(m_stat_mutex);
3834-
m_stat_opt_locking_time += locking_time;
3835-
m_stat_optimization_time += timer();
3836-
m_stat_total_llvm_time += lljitter.m_stat_total_llvm_time;
3837-
m_stat_llvm_setup_time += lljitter.m_stat_llvm_setup_time;
3838-
m_stat_llvm_irgen_time += lljitter.m_stat_llvm_irgen_time;
3839-
m_stat_llvm_opt_time += lljitter.m_stat_llvm_opt_time;
3840-
m_stat_llvm_jit_time += lljitter.m_stat_llvm_jit_time;
3841-
m_stat_max_llvm_local_mem = std::max(m_stat_max_llvm_local_mem,
3842-
lljitter.m_llvm_local_mem);
3832+
if (!cached) {
3833+
BackendLLVM lljitter(*this, group, ctx);
3834+
lljitter.run();
3835+
3836+
// NOTE: it is now possible to optimize and not JIT
3837+
// which would leave the cleanup to happen
3838+
// when the ShadingSystem is destroyed
3839+
3840+
// Only cleanup when are not batching or if
3841+
// the batch jit has already happened,
3842+
// as it requires the ops so we can't delete them yet!
3843+
if (((renderer()->batched(WidthOf<16>()) == nullptr)
3844+
&& (renderer()->batched(WidthOf<8>()) == nullptr)
3845+
&& (renderer()->batched(WidthOf<4>()) == nullptr))
3846+
|| group.batch_jitted()) {
3847+
group_post_jit_cleanup(group);
3848+
}
3849+
3850+
group.m_jitted = true;
3851+
spin_lock stat_lock(m_stat_mutex);
3852+
m_stat_opt_locking_time += locking_time;
3853+
m_stat_optimization_time += timer();
3854+
m_stat_total_llvm_time += lljitter.m_stat_total_llvm_time;
3855+
m_stat_llvm_setup_time += lljitter.m_stat_llvm_setup_time;
3856+
m_stat_llvm_irgen_time += lljitter.m_stat_llvm_irgen_time;
3857+
m_stat_llvm_opt_time += lljitter.m_stat_llvm_opt_time;
3858+
m_stat_llvm_jit_time += lljitter.m_stat_llvm_jit_time;
3859+
m_stat_max_llvm_local_mem = std::max(m_stat_max_llvm_local_mem,
3860+
lljitter.m_llvm_local_mem);
3861+
}
38433862
}
38443863

38453864
if (ctx_allocated) {

0 commit comments

Comments
 (0)