@@ -15,49 +15,175 @@ inline void gpuCheck(cudaError_t code, const char *file, int line) {
 }
 #define CUDA_CHECK(ans) { gpuCheck((ans), __FILE__, __LINE__); }
 
+// Optimized kernel with improved memory access patterns
 __global__ void float_round_kernel_inplace(float *input,
                                            int N,
                                            float max_exp,
                                            float min_exp,
                                            int mantissa_upper_bound,
                                            float mantissa_scale,
                                            float inv_mantissa_scale) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= N) return;
-
-    float x_val = input[idx];
-    if (x_val == 0.0f) return;
-
-    // 1. Use standard math functions with fast math optimizations
-    const float s = copysignf(1.0f, x_val);
-    const float x_abs = fabsf(x_val);
-    const float exponent_floor = floorf(log2f(x_abs));  // Will be optimized with --use_fast_math
-
-    float exponent = fmaxf(fminf(exponent_floor, max_exp), min_exp);
-    float exp2_val = exp2f(exponent);  // Compiler will optimize with --use_fast_math
-
-    float scaled = x_abs / exp2_val;
-    scaled = fmaxf(scaled, 1.0f);
-
-    // 2. Use CUDA's built-in rounding
-    const float mantissa_unrounded = (scaled - 1.0f) * mantissa_scale;
-    const int mantissa = __float2int_rn(mantissa_unrounded);
-
-    // 3. Branchless overflow handling
-    const bool overflow = mantissa >= mantissa_upper_bound;
-    const float exponent_overflow = fmaxf(fminf(exponent + 1.0f, max_exp), min_exp);
-    const float exp2_val_overflow = exp2f(exponent_overflow);
+    // Grid-stride indexing so each thread can process several elements
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int stride = blockDim.x * gridDim.x;
+
+    // Process multiple elements per thread to improve memory bandwidth utilization
+    for (int idx = tid; idx < N; idx += stride) {
+        float x_val = input[idx];
+
+        // Skip zeros: there is nothing to round
+        if (x_val == 0.0f) continue;
+
+        // Sign and magnitude
+        const float s = copysignf(1.0f, x_val);
+        const float x_abs = fabsf(x_val);
+
+        // floorf(log2f()) and exp2f() lower to fast intrinsics with --use_fast_math
+        const float exponent_floor = floorf(log2f(x_abs));
+        float exponent = fmaxf(fminf(exponent_floor, max_exp), min_exp);
+        float exp2_val = exp2f(exponent);
+
+        // Replace the division by exp2_val with a reciprocal multiplication
+        float scaled = x_abs * __frcp_rn(exp2_val);
+        scaled = fmaxf(scaled, 1.0f);
+
+        // Round the mantissa with CUDA's round-to-nearest conversion
+        const float mantissa_unrounded = (scaled - 1.0f) * mantissa_scale;
+        const int mantissa = __float2int_rn(mantissa_unrounded);
+
+        // Branchless overflow handling
+        const bool overflow = mantissa >= mantissa_upper_bound;
+        const float exponent_overflow = fmaxf(fminf(exponent + 1.0f, max_exp), min_exp);
+        const float exp2_val_overflow = exp2f(exponent_overflow);
+
+        // Select final values without branches
+        const float final_exp2 = overflow ? exp2_val_overflow : exp2_val;
+        const int final_mantissa = overflow ? 0 : mantissa;
+
+        // input[idx] = s * (1 + fraction) * exp2, with the multiply-add fused
+        const float fraction = static_cast<float>(final_mantissa) * inv_mantissa_scale;
+        input[idx] = s * fmaf(fraction, final_exp2, final_exp2);
+    }
+}
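Aside (editor's sketch, not part of this commit): the body above derives the exponent with floorf(log2f(|x|)). For normal, non-zero inputs the same floor value can be read straight from the IEEE-754 bit pattern, avoiding the transcendental entirely. The helper name is hypothetical and denormals are deliberately ignored here.

    __device__ __forceinline__ float exponent_floor_from_bits(float x_abs) {
        // Bits 23..30 of an IEEE-754 single hold the biased exponent.
        const int biased = (__float_as_int(x_abs) >> 23) & 0xFF;
        // Subtract the bias (127) to recover floor(log2(x_abs)) for normal inputs.
        return static_cast<float>(biased - 127);
    }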
 
-    // 4. Select final values without branches
-    const float final_exp2 = overflow ? exp2_val_overflow : exp2_val;
-    const int final_mantissa = overflow ? 0 : mantissa;
+// Vectorized kernel using float4 for maximum memory bandwidth
+__global__ void float_round_kernel_vectorized(float4 *input_vec,
+                                              int N_vec,
+                                              float max_exp,
+                                              float min_exp,
+                                              int mantissa_upper_bound,
+                                              float mantissa_scale,
+                                              float inv_mantissa_scale) {
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int stride = blockDim.x * gridDim.x;
+
+    // Process float4 elements (4 floats per thread per iteration)
+    for (int idx = tid; idx < N_vec; idx += stride) {
+        float4 vec = input_vec[idx];
+
+        // Process each component of the float4 vector
+        #pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            float *x_ptr = reinterpret_cast<float *>(&vec) + i;
+            float x_val = *x_ptr;
+
+            if (x_val == 0.0f) continue;
+
+            const float s = copysignf(1.0f, x_val);
+            const float x_abs = fabsf(x_val);
+            const float exponent_floor = floorf(log2f(x_abs));
+            float exponent = fmaxf(fminf(exponent_floor, max_exp), min_exp);
+            float exp2_val = exp2f(exponent);
+
+            // Reciprocal multiplication instead of division
+            float scaled = x_abs * __frcp_rn(exp2_val);
+            scaled = fmaxf(scaled, 1.0f);
+
+            const float mantissa_unrounded = (scaled - 1.0f) * mantissa_scale;
+            const int mantissa = __float2int_rn(mantissa_unrounded);
+
+            const bool overflow = mantissa >= mantissa_upper_bound;
+            const float exponent_overflow = fmaxf(fminf(exponent + 1.0f, max_exp), min_exp);
+            const float exp2_val_overflow = exp2f(exponent_overflow);
+
+            const float final_exp2 = overflow ? exp2_val_overflow : exp2_val;
+            const int final_mantissa = overflow ? 0 : mantissa;
+
+            const float fraction = static_cast<float>(final_mantissa) * inv_mantissa_scale;
+            *x_ptr = s * fmaf(fraction, final_exp2, final_exp2);
+        }
+
+        // Store the processed float4 vector
+        input_vec[idx] = vec;
+    }
+}
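Aside (editor's sketch, not part of this commit): reading the buffer as float4 is only valid when the pointer is 16-byte aligned, and the launcher further down casts the tensor's data pointer unconditionally. A host-side guard, with a hypothetical helper name, could look like this.

    #include <cstdint>
    #include <cuda_runtime.h>

    // Returns true when the buffer may safely be reinterpreted as float4.
    static bool is_float4_aligned(const void *ptr) {
        return reinterpret_cast<std::uintptr_t>(ptr) % sizeof(float4) == 0;  // float4 is 16 bytes
    }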
 
-    // 5. FMA is automatically used with --use_fast_math
-    const float fraction = static_cast<float>(final_mantissa) * inv_mantissa_scale;
-    input[idx] = s * (1.0f + fraction) * final_exp2;
+// Shared memory optimized kernel for better cache utilization
+__global__ void float_round_kernel_shared(float *input,
+                                          int N,
+                                          float max_exp,
+                                          float min_exp,
+                                          int mantissa_upper_bound,
+                                          float mantissa_scale,
+                                          float inv_mantissa_scale) {
+    __shared__ float shared_data[1024];  // Shared memory buffer, sized for the maximum block size
+
+    const int tid = threadIdx.x;
+
+    for (int base_idx = blockIdx.x * blockDim.x; base_idx < N; base_idx += blockDim.x * gridDim.x) {
+        int idx = base_idx + tid;
+
+        // Load data into shared memory with coalesced access
+        if (idx < N) {
+            shared_data[tid] = input[idx];
+        } else {
+            shared_data[tid] = 0.0f;
+        }
+
+        __syncthreads();
+
+        // Process data from shared memory
+        if (idx < N) {
+            float x_val = shared_data[tid];
+
+            if (x_val != 0.0f) {
+                const float s = copysignf(1.0f, x_val);
+                const float x_abs = fabsf(x_val);
+                const float exponent_floor = floorf(log2f(x_abs));
+                float exponent = fmaxf(fminf(exponent_floor, max_exp), min_exp);
+                float exp2_val = exp2f(exponent);
+
+                // Reciprocal multiplication instead of division
+                float scaled = x_abs * __frcp_rn(exp2_val);
+                scaled = fmaxf(scaled, 1.0f);
+
+                const float mantissa_unrounded = (scaled - 1.0f) * mantissa_scale;
+                const int mantissa = __float2int_rn(mantissa_unrounded);
+
+                const bool overflow = mantissa >= mantissa_upper_bound;
+                const float exponent_overflow = fmaxf(fminf(exponent + 1.0f, max_exp), min_exp);
+                const float exp2_val_overflow = exp2f(exponent_overflow);
+
+                const float final_exp2 = overflow ? exp2_val_overflow : exp2_val;
+                const int final_mantissa = overflow ? 0 : mantissa;
+
+                const float fraction = static_cast<float>(final_mantissa) * inv_mantissa_scale;
+                shared_data[tid] = s * fmaf(fraction, final_exp2, final_exp2);
+            }
+        }
+
+        __syncthreads();
+
+        // Store back to global memory with coalesced access
+        if (idx < N) {
+            input[idx] = shared_data[tid];
+        }
+    }
 }
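Aside (editor's sketch, not part of this commit): the launcher below hard-codes a 256-thread block. Under the assumption that occupancy is the limiting factor, the runtime occupancy API can instead suggest a block size for the scalar kernel; numel is the element count used by the launcher.

    int min_grid_size = 0, block_size = 0;
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size,
                                                  float_round_kernel_inplace, 0, 0));
    int threads = block_size;                        // occupancy-derived block size
    int blocks = (numel + threads - 1) / threads;    // grid sized as in the launcher below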
 
-// Function that launches the kernel
+// Function that launches the optimized kernel
 torch::Tensor float_round_cuda_inplace(torch::Tensor input, int exponent_bits, int mantissa_bits, int bias) {
     CHECK_CUDA(input);
 
@@ -73,14 +199,48 @@ torch::Tensor float_round_cuda_inplace(torch::Tensor input, int exponent_bits, i
     float inv_mantissa_scale = 1.0f / mantissa_scale;
 
     float *input_ptr = input.data_ptr<float>();
-    int threads = 1024;
+
+    // Query the device properties that bound the launch configuration
+    int device_id = input.device().index();
+    cudaDeviceProp prop;
+    CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
+
+    // Use a 256-thread block (down from 1024) to improve occupancy
+    int threads = 256;
     int blocks = (numel + threads - 1) / threads;
 
+    // Cap the grid at the number of blocks the device can keep resident
+    int max_blocks_per_sm = prop.maxBlocksPerMultiProcessor;
+    int max_blocks = prop.multiProcessorCount * max_blocks_per_sm;
+    blocks = std::min(blocks, max_blocks);
+
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    float_round_kernel_inplace<<<blocks, threads, 0, stream>>>(
-        input_ptr, numel, max_exp, min_exp,
-        mantissa_upper_bound, mantissa_scale, inv_mantissa_scale
-    );
+
+    // Choose a kernel based on input size and layout
+    if (numel >= 1000000) {
+        // For large inputs, use the float4 kernel when the element count allows it
+        if (numel % 4 == 0) {
+            float4 *input_vec = reinterpret_cast<float4 *>(input_ptr);
+            int N_vec = numel / 4;
+            float_round_kernel_vectorized<<<blocks, threads, 0, stream>>>(
+                input_vec, N_vec, max_exp, min_exp,
+                mantissa_upper_bound, mantissa_scale, inv_mantissa_scale
+            );
+        } else {
+            // Otherwise use the shared-memory kernel for better cache utilization
+            float_round_kernel_shared<<<blocks, threads, 0, stream>>>(
+                input_ptr, numel, max_exp, min_exp,
+                mantissa_upper_bound, mantissa_scale, inv_mantissa_scale
+            );
+        }
+    } else {
+        // For smaller inputs, use the scalar grid-stride kernel
+        float_round_kernel_inplace<<<blocks, threads, 0, stream>>>(
+            input_ptr, numel, max_exp, min_exp,
+            mantissa_upper_bound, mantissa_scale, inv_mantissa_scale
+        );
+    }
+
     CUDA_CHECK(cudaGetLastError());
 
     return input;