@@ -132,6 +132,7 @@ CudaRtcKernel::CudaRtcKernel(const std::string& name, const std::string& kernel_
132132
133133 std::vector<std::string> opts = {
134134 fmt::format (" --gpu-architecture=compute_{}" , cc),
135+ fmt::format (" -DTCNN_HALF_PRECISION={}" , TCNN_HALF_PRECISION),
135136 fmt::format (" -DTCNN_MIN_GPU_ARCH={}" , cc),
136137 " --std=c++14" ,
137138#ifdef TCNN_RTC_USE_FAST_MATH
@@ -178,8 +179,6 @@ CudaRtcKernel::CudaRtcKernel(const std::string& name, const std::string& kernel_
178179 {OPTS}
179180 */
180181
181- #define TCNN_HALF_PRECISION {TCNN_HALF_PRECISION}
182-
183182 // NVRTC does not come with the C++ standard library out of the box and
184183 // it would be troublesome to bundle it or require users to have it installed
185184 // in readily available paths. So we instead include a minimal custom
@@ -192,8 +191,7 @@ CudaRtcKernel::CudaRtcKernel(const std::string& name, const std::string& kernel_
192191 " KERNEL_NAME" _a = name,
193192 " PREAMBLE" _a = generate_device_code_preamble (),
194193 " OPTS" _a = join (opts, " \n " ),
195- " KERNEL_CODE" _a = kernel_code,
196- " TCNN_HALF_PRECISION" _a = TCNN_HALF_PRECISION
194+ " KERNEL_CODE" _a = kernel_code
197195 );
198196
199197 size_t code_hash = hash_combine (0 , complete_code);
0 commit comments