Skip to content

Commit 85befbd

Browse files
authored
Improve cub::PtxVersionUncached (#7903)
1 parent de42fa9 commit 85befbd

File tree

1 file changed

+10
-33
lines changed

1 file changed

+10
-33
lines changed

cub/cub/util_device.cuh

Lines changed: 10 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ CUB_RUNTIME_FUNCTION inline int CurrentDevice()
8282

8383
//! @brief RAII helper which saves the current device and switches to the specified device on construction and switches
8484
//! to the saved device on destruction.
85-
class SwitchDevice
85+
class [[maybe_unused]] SwitchDevice
8686
{
8787
int target_device_;
8888
int original_device_;
@@ -278,39 +278,16 @@ CUB_RUNTIME_FUNCTION cudaError_t PtxVersionUncached(int& ptx_version)
278278
{
279279
// Instantiate `EmptyKernel<void>` in both host and device code to ensure
280280
// it can be called.
281-
using EmptyKernelPtr = void (*)();
282-
[[maybe_unused]] EmptyKernelPtr empty_kernel = detail::EmptyKernel<T>;
283-
284-
// Define a temporary macro that expands to the current target ptx version
285-
// in device code.
286-
// <nv/target> may provide an abstraction for this eventually. For now,
287-
// we have to keep this usage of __CUDA_ARCH__.
288-
# if _CCCL_CUDA_COMPILER(NVHPC)
289-
# define CUB_TEMP_GET_PTX __builtin_current_device_sm()
290-
# else // ^^^ _CCCL_CUDA_COMPILER(NVHPC) ^^^ / vvv !_CCCL_CUDA_COMPILER(NVHPC) vvv
291-
# define CUB_TEMP_GET_PTX _CCCL_PTX_ARCH()
292-
# endif // ^^^ !_CCCL_CUDA_COMPILER(NVHPC) ^^^
281+
[[maybe_unused]] const auto empty_kernel = detail::EmptyKernel<T>;
293282

294283
cudaError_t result = cudaSuccess;
295-
NV_IF_TARGET(
296-
NV_IS_HOST,
297-
(cudaFuncAttributes empty_kernel_attrs;
298-
299-
result = CubDebug(cudaFuncGetAttributes(&empty_kernel_attrs, reinterpret_cast<void*>(empty_kernel)));
300-
301-
ptx_version = empty_kernel_attrs.ptxVersion * 10;),
302-
// NV_IS_DEVICE
303-
(
304-
// This is necessary to ensure instantiation of EmptyKernel in device
305-
// code. The `reinterpret_cast` is necessary to suppress a
306-
// set-but-unused warnings. This is a meme now:
307-
// https://twitter.com/blelbach/status/1222391615576100864
308-
(void) reinterpret_cast<EmptyKernelPtr>(empty_kernel);
309-
310-
ptx_version = CUB_TEMP_GET_PTX;));
311-
312-
# undef CUB_TEMP_GET_PTX
313-
284+
NV_IF_ELSE_TARGET(NV_IS_HOST,
285+
({
286+
cudaFuncAttributes empty_kernel_attrs;
287+
result = CubDebug(cudaFuncGetAttributes(&empty_kernel_attrs, (const void*) empty_kernel));
288+
ptx_version = empty_kernel_attrs.ptxVersion * 10;
289+
}),
290+
({ ptx_version = ::cuda::device::current_compute_capability().get() * 10; }));
314291
return result;
315292
}
316293

@@ -320,7 +297,7 @@ CUB_RUNTIME_FUNCTION cudaError_t PtxVersionUncached(int& ptx_version)
320297
template <class T = void>
321298
_CCCL_HOST cudaError_t PtxVersionUncached(int& ptx_version, int device)
322299
{
323-
[[maybe_unused]] SwitchDevice sd(device);
300+
SwitchDevice sd(device);
324301
return PtxVersionUncached<T>(ptx_version);
325302
}
326303

0 commit comments

Comments
 (0)