Skip to content

Commit 2208f7d

Browse files
authored
Add TCNN_HALF_PRECISION definition to kernel
1 parent 38e50ed commit 2208f7d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/rtc_kernel.cu

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,8 @@ CudaRtcKernel::CudaRtcKernel(const std::string& name, const std::string& kernel_
178 178   {OPTS}
179 179   */
180 180
    181 + #define TCNN_HALF_PRECISION {TCNN_HALF_PRECISION}
    182 +
181 183   // NVRTC does not come with the C++ standard library out of the box and
182 184   // it would be troublesome to bundle it or require users to have it installed
183 185   // in readily available paths. So we instead include a minimal custom
@@ -190,7 +192,8 @@ CudaRtcKernel::CudaRtcKernel(const std::string& name, const std::string& kernel_
190 192   "KERNEL_NAME"_a = name,
191 193   "PREAMBLE"_a = generate_device_code_preamble(),
192 194   "OPTS"_a = join(opts, "\n"),
193     - "KERNEL_CODE"_a = kernel_code
    195 + "KERNEL_CODE"_a = kernel_code,
    196 + "TCNN_HALF_PRECISION"_a = TCNN_HALF_PRECISION
194 197   );
195 198
196 199   size_t code_hash = hash_combine(0, complete_code);

0 commit comments

Comments
 (0)