Skip to content

Commit 09ee3d6

Browse files
authored
fix(rtc): move TCNN_HALF_PRECISION definition to compiler opts
1 parent 2208f7d commit 09ee3d6

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

src/rtc_kernel.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ CudaRtcKernel::CudaRtcKernel(const std::string& name, const std::string& kernel_
132132

133133
std::vector<std::string> opts = {
134134
fmt::format("--gpu-architecture=compute_{}", cc),
135+
fmt::format("-DTCNN_HALF_PRECISION={}", TCNN_HALF_PRECISION),
135136
fmt::format("-DTCNN_MIN_GPU_ARCH={}", cc),
136137
"--std=c++14",
137138
#ifdef TCNN_RTC_USE_FAST_MATH
@@ -178,8 +179,6 @@ CudaRtcKernel::CudaRtcKernel(const std::string& name, const std::string& kernel_
178179
{OPTS}
179180
*/
180181
181-
#define TCNN_HALF_PRECISION {TCNN_HALF_PRECISION}
182-
183182
// NVRTC does not come with the C++ standard library out of the box and
184183
// it would be troublesome to bundle it or require users to have it installed
185184
// in readily available paths. So we instead include a minimal custom
@@ -192,8 +191,7 @@ CudaRtcKernel::CudaRtcKernel(const std::string& name, const std::string& kernel_
192191
"KERNEL_NAME"_a = name,
193192
"PREAMBLE"_a = generate_device_code_preamble(),
194193
"OPTS"_a = join(opts, "\n"),
195-
"KERNEL_CODE"_a = kernel_code,
196-
"TCNN_HALF_PRECISION"_a = TCNN_HALF_PRECISION
194+
"KERNEL_CODE"_a = kernel_code
197195
);
198196

199197
size_t code_hash = hash_combine(0, complete_code);

0 commit comments

Comments
 (0)