Skip to content

Commit 4728fe0

Browse files
committed
Feat: hint about scalar range in thread intrinsics
1 parent f207e4e commit 4728fe0

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

crates/cuda_std/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Notable changes to this project will be documented in this file.
44

55
## Unreleased
66

7+
- Thread/Block/Grid index/dim intrinsics now hint to LLVM that their values lie within the bounds documented by CUDA, hopefully allowing for more optimizations.
8+
79
## 0.2.1 - 12/8/21
810

911
- Fixed `shared_array!` not using fully qualified MaybeUninit.

crates/cuda_std/src/thread.rs

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,76 +49,104 @@ extern "C" {
4949
fn __nvvm_system_fence();
5050
}
5151

52+
#[cfg(target_os = "cuda")]
53+
macro_rules! inbounds {
54+
// The bounds were taken mostly from the CUDA C++ Programming Guide; I also
55+
// double-checked against what CUDA clang does by inspecting the scalar range metadata in its emitted LLVM IR.
56+
($func_name:ident, $bound:expr) => {{
57+
let val = unsafe { $func_name() };
58+
if val > $bound {
59+
// SAFETY: this branch is unreachable because the value cannot exceed the compute capability's maximum bound
60+
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
61+
// we do this to potentially allow for better optimizations by LLVM
62+
unsafe { core::hint::unreachable_unchecked() }
63+
} else {
64+
val
65+
}
66+
}};
67+
($func_name:ident, $lower_bound:expr, $upper_bound:expr) => {{
68+
let val = unsafe { $func_name() };
69+
if val < $lower_bound || val > $upper_bound {
70+
// SAFETY: this branch is unreachable because the value cannot fall outside the compute capability's bounds
71+
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
72+
// we do this to potentially allow for better optimizations by LLVM
73+
unsafe { core::hint::unreachable_unchecked() }
74+
} else {
75+
val
76+
}
77+
}};
78+
}
79+
5280
#[gpu_only]
5381
#[inline(always)]
5482
pub fn thread_idx_x() -> u32 {
55-
unsafe { __nvvm_thread_idx_x() }
83+
inbounds!(__nvvm_thread_idx_x, 1024)
5684
}
5785

5886
#[gpu_only]
5987
#[inline(always)]
6088
pub fn thread_idx_y() -> u32 {
61-
unsafe { __nvvm_thread_idx_y() }
89+
inbounds!(__nvvm_thread_idx_y, 1024)
6290
}
6391

6492
#[gpu_only]
6593
#[inline(always)]
6694
pub fn thread_idx_z() -> u32 {
67-
unsafe { __nvvm_thread_idx_z() }
95+
inbounds!(__nvvm_thread_idx_z, 64)
6896
}
6997

7098
#[gpu_only]
7199
#[inline(always)]
72100
pub fn block_idx_x() -> u32 {
73-
unsafe { __nvvm_block_idx_x() }
101+
inbounds!(__nvvm_block_idx_x, 2147483647)
74102
}
75103

76104
#[gpu_only]
77105
#[inline(always)]
78106
pub fn block_idx_y() -> u32 {
79-
unsafe { __nvvm_block_idx_y() }
107+
inbounds!(__nvvm_block_idx_y, 65535)
80108
}
81109

82110
#[gpu_only]
83111
#[inline(always)]
84112
pub fn block_idx_z() -> u32 {
85-
unsafe { __nvvm_block_idx_z() }
113+
inbounds!(__nvvm_block_idx_z, 65535)
86114
}
87115

88116
#[gpu_only]
89117
#[inline(always)]
90118
pub fn block_dim_x() -> u32 {
91-
unsafe { __nvvm_block_dim_x() }
119+
inbounds!(__nvvm_block_dim_x, 1, 1025)
92120
}
93121

94122
#[gpu_only]
95123
#[inline(always)]
96124
pub fn block_dim_y() -> u32 {
97-
unsafe { __nvvm_block_dim_y() }
125+
inbounds!(__nvvm_block_dim_y, 1, 1025)
98126
}
99127

100128
#[gpu_only]
101129
#[inline(always)]
102130
pub fn block_dim_z() -> u32 {
103-
unsafe { __nvvm_block_dim_z() }
131+
inbounds!(__nvvm_block_dim_z, 1, 65)
104132
}
105133

106134
#[gpu_only]
107135
#[inline(always)]
108136
pub fn grid_dim_x() -> u32 {
109-
unsafe { __nvvm_grid_dim_x() }
137+
inbounds!(__nvvm_grid_dim_x, 1, 2147483648)
110138
}
111139

112140
#[gpu_only]
113141
#[inline(always)]
114142
pub fn grid_dim_y() -> u32 {
115-
unsafe { __nvvm_grid_dim_y() }
143+
inbounds!(__nvvm_grid_dim_y, 1, 65536)
116144
}
117145

118146
#[gpu_only]
119147
#[inline(always)]
120148
pub fn grid_dim_z() -> u32 {
121-
unsafe { __nvvm_grid_dim_z() }
149+
inbounds!(__nvvm_grid_dim_z, 1, 65536)
122150
}
123151

124152
/// Gets the 3d index of the thread currently executing the kernel.

0 commit comments

Comments
 (0)