Skip to content

Commit 9a07e06

Browse files
committed
API for renderer to cache OptiX ptx
Signed-off-by: Chris Hellmuth <chellmuth@gmail.com>
1 parent df49a40 commit 9a07e06

File tree

6 files changed

+126
-27
lines changed

6 files changed

+126
-27
lines changed

src/include/OSL/rendererservices.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,19 @@ class OSLEXECPUBLIC RendererServices {
595595
}
596596
};
597597

598+
// Default no-op implementations of the caching api.
599+
// Currently used for caching optix ptx before llvm generation.
600+
/// Store `value` under `key` in the renderer-side cache named `cachename`.
/// This default implementation is a no-op; a renderer that provides a
/// cache overrides it. Within this change the only caller is
/// BackendLLVM::run(), which passes cachename "optix_ptx" and the wrapped
/// PTX string, and only when supports("optix_ptx_cache") is true.
virtual void cache_insert(string_view cachename, string_view key,
601+
string_view value) const
602+
{
603+
}
604+
605+
/// Look up `key` in the renderer-side cache named `cachename`.
/// The caller in optimize_group treats a `true` return as a cache hit and
/// then reads the cached string from `value`. This default implementation
/// never touches `value` and always returns false (a miss).
virtual bool cache_get(string_view cachename, string_view key,
606+
std::string& value) const
607+
{
608+
// Default: no cache available, so every lookup is a miss.
return false;
609+
}
610+
598611
/// A renderer may choose to support batched execution by providing pointers
599612
/// to objects satisfying the BatchedRendererServices<WidthOf<#>> interface
600613
/// for specific batch sizes.

src/liboslexec/instance.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,14 +847,40 @@ ShaderGroup::setup_interactive_arena(cspan<uint8_t> paramblock)
847847
}
848848

849849

850+
std::string
851+
ShaderGroup::generate_optix_cache_key()
852+
{
853+
// Call serialize_internal() because caller optimize_group already
854+
// has the lock.
855+
const uint64_t ir_key = Strutil::strhash(serialize_internal());
856+
857+
std::string safegroup;
858+
safegroup = Strutil::replace(name(), "/", "_", true);
859+
safegroup = Strutil::replace(safegroup, ":", "_", true);
860+
861+
// Cache key includes the groupname in addition to the serialized IR.
862+
// This is because the groupname makes its way into the ptx's direct callable name,
863+
// but isn't included in the serialization.
864+
std::string cache_key = fmtformat("cache-osl-ptx-{}-{}", safegroup, ir_key);
865+
866+
m_optix_cache_key = cache_key;
867+
return m_optix_cache_key;
868+
}
850869

851870
// Public serialization entry point: acquires m_mutex for the duration of
// the call and delegates the actual work to serialize_internal().
std::string
852871
ShaderGroup::serialize() const
872+
{
873+
// Hold the group's mutex while serialize_internal() walks the layers.
lock_guard lock(m_mutex);
874+
return serialize_internal();
875+
}
876+
877+
std::string
878+
ShaderGroup::serialize_internal() const
853879
{
854880
std::ostringstream out;
855881
out.imbue(std::locale::classic()); // force C locale
856882
out.precision(9);
857-
lock_guard lock(m_mutex);
883+
858884
for (int i = 0, nl = nlayers(); i < nl; ++i) {
859885
const ShaderInstance* inst = m_layers[i].get();
860886

src/liboslexec/llvm_instance.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,6 +2281,14 @@ BackendLLVM::run()
22812281
group().llvm_compiled_layer(nlayers - 1));
22822282
}
22832283

2284+
if (use_optix() && renderer()->supports("optix_ptx_cache")) {
2285+
std::string cache_key = group().optix_cache_key();
2286+
renderer()->cache_insert(
2287+
"optix_ptx", cache_key,
2288+
optix_cache_wrap(group().m_llvm_ptx_compiled_version,
2289+
group().llvm_groupdata_size()));
2290+
}
2291+
22842292
// We are destroying the entire module below,
22852293
// no reason to bother destroying individual functions
22862294
#if 0

src/liboslexec/oslexec.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,29 @@ shadertype_from_name(string_view name)
5050
return ShaderType::Unknown;
5151
}
5252

53+
std::string
optix_cache_wrap(const std::string& ptx, size_t groupdata_size)
{
    // The cached string is the PTX text preceded by one comment line that
    // carries the groupdata size. Because the header is a PTX comment, the
    // whole cache string is still a valid PTX program, which helps when
    // debugging cache contents.
    std::string wrapped = "// ";
    wrapped += std::to_string(groupdata_size);
    wrapped += '\n';
    wrapped += ptx;
    return wrapped;
}
61+
62+
// Split a cache string of the form "// <groupdata_size>\n<ptx>" (as
// produced by optix_cache_wrap) into its two parts.
//
//   cache_value     the string retrieved from the renderer's cache
//   ptx             receives everything after the first newline
//   groupdata_size  receives the decimal number from the header line
//
// The cache contents come from the renderer, so the header is validated
// rather than trusted. If it is malformed (no newline, missing "// "
// prefix, no digits, non-numeric trailing characters, or a value that does
// not fit in size_t) the function returns without modifying either output
// and without throwing. Trailing blanks or a '\r' after the number are
// tolerated so that a CRLF-terminated header still parses.
void
optix_cache_unwrap(const std::string& cache_value, std::string& ptx,
                   size_t& groupdata_size)
{
    static const char prefix[]  = "// ";
    constexpr size_t prefix_len = sizeof(prefix) - 1;

    // The header line must be long enough to hold the prefix plus at least
    // one digit, and must actually start with the prefix.
    const size_t header_end = cache_value.find('\n');
    if (header_end == std::string::npos || header_end <= prefix_len
        || cache_value.compare(0, prefix_len, prefix) != 0)
        return;

    // Parse the unsigned decimal number by hand: no exceptions, no locale,
    // and an explicit overflow check.
    constexpr size_t max_size = static_cast<size_t>(-1);
    size_t parsed             = 0;
    size_t pos                = prefix_len;
    for (; pos < header_end; ++pos) {
        const char c = cache_value[pos];
        if (c < '0' || c > '9')
            break;
        const size_t digit = static_cast<size_t>(c - '0');
        if (parsed > (max_size - digit) / 10)
            return;  // value would overflow size_t
        parsed = parsed * 10 + digit;
    }
    if (pos == prefix_len)
        return;  // header carries no digits at all

    // Anything left on the header line may only be whitespace.
    for (; pos < header_end; ++pos) {
        const char c = cache_value[pos];
        if (c != ' ' && c != '\t' && c != '\r')
            return;
    }

    // Commit both outputs only once the whole header has been validated.
    groupdata_size = parsed;
    ptx            = cache_value.substr(header_end + 1);
}
5376

5477
}; // namespace pvt
5578
OSL_NAMESPACE_EXIT

src/liboslexec/oslexec_pvt.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ struct PerThreadInfo {
7878

7979
namespace pvt {
8080

81+
// Split a cache string produced by optix_cache_wrap() back into the PTX
// text (`ptx`) and the groupdata size (`groupdata_size`). If `cache_value`
// contains no newline, both outputs are left unmodified.
void
82+
optix_cache_unwrap(const std::string& cache_value, std::string& ptx,
83+
size_t& groupdata_size);
84+
// Build the string stored in the renderer's cache: the PTX text preceded
// by a "// <groupdata_size>" comment line.
std::string
85+
optix_cache_wrap(const std::string& ptx, size_t groupdata_size);
86+
8187
// forward definitions
8288
class ShadingSystemImpl;
8389
class ShaderInstance;
@@ -1829,6 +1835,10 @@ class ShaderGroup {
18291835
void name(ustring name) { m_name = name; }
18301836
ustring name() const { return m_name; }
18311837

1838+
// Generate and memoize the cache key so we don't calculate it twice.
// generate_optix_cache_key() computes the key, stores it in
// m_optix_cache_key and returns it. It calls serialize_internal(), which
// does not take m_mutex, so the caller is expected to already hold the
// group's lock.
1839+
std::string generate_optix_cache_key();
1840+
// Return (by value) the key memoized by generate_optix_cache_key(); this
// is an empty string if that function has not been called yet.
std::string optix_cache_key() const { return m_optix_cache_key; }
1841+
18321842
std::string serialize() const;
18331843

18341844
void lock() const { m_mutex.lock(); }
@@ -1965,6 +1975,8 @@ class ShaderGroup {
19651975
}
19661976

19671977
private:
1978+
std::string serialize_internal() const;
1979+
19681980
// Put all the things that are read-only (after optimization) and
19691981
// needed on every shade execution at the front of the struct, as much
19701982
// together on one cache line as possible.
@@ -2016,6 +2028,8 @@ class ShaderGroup {
20162028
atomic_ll m_executions { 0 }; ///< Number of times the group executed
20172029
atomic_ll m_stat_total_shading_time_ticks { 0 }; // Shading time (ticks)
20182030

2031+
std::string m_optix_cache_key;
2032+
20192033
// PTX assembly for compiled ShaderGroup
20202034
std::string m_llvm_ptx_compiled_version;
20212035

src/liboslexec/shadingsys.cpp

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3823,34 +3823,49 @@ ShadingSystemImpl::optimize_group(ShaderGroup& group, ShadingContext* ctx,
38233823
}
38243824

38253825
if (need_jit) {
3826-
BackendLLVM lljitter(*this, group, ctx);
3827-
lljitter.run();
3828-
3829-
// NOTE: it is now possible to optimize and not JIT
3830-
// which would leave the cleanup to happen
3831-
// when the ShadingSystem is destroyed
3832-
3833-
// Only cleanup when are not batching or if
3834-
// the batch jit has already happened,
3835-
// as it requires the ops so we can't delete them yet!
3836-
if (((renderer()->batched(WidthOf<16>()) == nullptr)
3837-
&& (renderer()->batched(WidthOf<8>()) == nullptr)
3838-
&& (renderer()->batched(WidthOf<4>()) == nullptr))
3839-
|| group.batch_jitted()) {
3840-
group_post_jit_cleanup(group);
3826+
bool cached = false;
3827+
if (use_optix() && renderer()->supports("optix_ptx_cache")) {
3828+
std::string cache_key = group.generate_optix_cache_key();
3829+
3830+
std::string cache_value;
3831+
if (renderer()->cache_get("optix_ptx", cache_key, cache_value)) {
3832+
cached = true;
3833+
optix_cache_unwrap(cache_value,
3834+
group.m_llvm_ptx_compiled_version,
3835+
group.m_llvm_groupdata_size);
3836+
}
38413837
}
38423838

3843-
group.m_jitted = true;
3844-
spin_lock stat_lock(m_stat_mutex);
3845-
m_stat_opt_locking_time += locking_time;
3846-
m_stat_optimization_time += timer();
3847-
m_stat_total_llvm_time += lljitter.m_stat_total_llvm_time;
3848-
m_stat_llvm_setup_time += lljitter.m_stat_llvm_setup_time;
3849-
m_stat_llvm_irgen_time += lljitter.m_stat_llvm_irgen_time;
3850-
m_stat_llvm_opt_time += lljitter.m_stat_llvm_opt_time;
3851-
m_stat_llvm_jit_time += lljitter.m_stat_llvm_jit_time;
3852-
m_stat_max_llvm_local_mem = std::max(m_stat_max_llvm_local_mem,
3853-
lljitter.m_llvm_local_mem);
3839+
if (!cached) {
3840+
BackendLLVM lljitter(*this, group, ctx);
3841+
lljitter.run();
3842+
3843+
// NOTE: it is now possible to optimize and not JIT
3844+
// which would leave the cleanup to happen
3845+
// when the ShadingSystem is destroyed
3846+
3847+
// Only cleanup when are not batching or if
3848+
// the batch jit has already happened,
3849+
// as it requires the ops so we can't delete them yet!
3850+
if (((renderer()->batched(WidthOf<16>()) == nullptr)
3851+
&& (renderer()->batched(WidthOf<8>()) == nullptr)
3852+
&& (renderer()->batched(WidthOf<4>()) == nullptr))
3853+
|| group.batch_jitted()) {
3854+
group_post_jit_cleanup(group);
3855+
}
3856+
3857+
group.m_jitted = true;
3858+
spin_lock stat_lock(m_stat_mutex);
3859+
m_stat_opt_locking_time += locking_time;
3860+
m_stat_optimization_time += timer();
3861+
m_stat_total_llvm_time += lljitter.m_stat_total_llvm_time;
3862+
m_stat_llvm_setup_time += lljitter.m_stat_llvm_setup_time;
3863+
m_stat_llvm_irgen_time += lljitter.m_stat_llvm_irgen_time;
3864+
m_stat_llvm_opt_time += lljitter.m_stat_llvm_opt_time;
3865+
m_stat_llvm_jit_time += lljitter.m_stat_llvm_jit_time;
3866+
m_stat_max_llvm_local_mem = std::max(m_stat_max_llvm_local_mem,
3867+
lljitter.m_llvm_local_mem);
3868+
}
38543869
}
38553870

38563871
if (ctx_allocated) {

0 commit comments

Comments
 (0)