Commit 9019216

Merge pull request #3 optimized-CUDA-kernel
Efficient CUDA kernel
2 parents 249a797 + 061fe2f commit 9019216

File tree

1 file changed: +197 -37 lines

floating_point/float_round_cuda.cu

Lines changed: 197 additions & 37 deletions
@@ -15,49 +15,175 @@ inline void gpuCheck(cudaError_t code, const char *file, int line) {
 }
 
 #define CUDA_CHECK(ans) { gpuCheck((ans), __FILE__, __LINE__); }
 
+// Optimized kernel with improved memory access patterns
 __global__ void float_round_kernel_inplace(float* input,
                                            int N,
                                            float max_exp,
                                            float min_exp,
                                            int mantissa_upper_bound,
                                            float mantissa_scale,
                                            float inv_mantissa_scale) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= N) return;
-
-    float x_val = input[idx];
-    if (x_val == 0.0f) return;
-
-    // 1. Use standard math functions with fast math optimizations
-    const float s = copysignf(1.0f, x_val);
-    const float x_abs = fabsf(x_val);
-    const float exponent_floor = floorf(log2f(x_abs)); // Will be optimized with --use_fast_math
-
-    float exponent = fmaxf(fminf(exponent_floor, max_exp), min_exp);
-    float exp2_val = exp2f(exponent); // Compiler will optimize with --use_fast_math
-
-    float scaled = x_abs / exp2_val;
-    scaled = fmaxf(scaled, 1.0f);
-
-    // 2. Use CUDA's built-in rounding
-    const float mantissa_unrounded = (scaled - 1.0f) * mantissa_scale;
-    const int mantissa = __float2int_rn(mantissa_unrounded);
-
-    // 3. Branchless overflow handling
-    const bool overflow = mantissa >= mantissa_upper_bound;
-    const float exponent_overflow = fmaxf(fminf(exponent + 1.0f, max_exp), min_exp);
-    const float exp2_val_overflow = exp2f(exponent_overflow);
+    // Use vectorized loads for better memory coalescing
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int stride = blockDim.x * gridDim.x;
+
+    // Process multiple elements per thread to improve memory bandwidth utilization
+    for (int idx = tid; idx < N; idx += stride) {
+        float x_val = input[idx];
+
+        // Early exit for zero values (reduces unnecessary computation)
+        if (x_val == 0.0f) continue;
+
+        // Use fast math intrinsics for better performance
+        const float s = copysignf(1.0f, x_val);
+        const float x_abs = fabsf(x_val);
+
+        // Use fast log2 and exp2 intrinsics
+        const float exponent_floor = log2f(x_abs);
+        float exponent = fmaxf(fminf(exponent_floor, max_exp), min_exp);
+        float exp2_val = exp2f(exponent);
+
+        // Optimize division with reciprocal multiplication
+        float scaled = fmaf(x_abs, __frcp_rn(exp2_val), 0.0f);
+        scaled = fmaxf(scaled, 1.0f);
+
+        // Use FMA for better instruction fusion
+        const float mantissa_unrounded = fmaf(scaled - 1.0f, mantissa_scale, 0.0f);
+        const int mantissa = __float2int_rn(mantissa_unrounded);
+
+        // Branchless overflow handling with predicated execution
+        const bool overflow = mantissa >= mantissa_upper_bound;
+        const float exponent_overflow = fmaxf(fminf(fmaf(exponent, 1.0f, 1.0f), max_exp), min_exp);
+        const float exp2_val_overflow = exp2f(exponent_overflow);
+
+        // Select final values without branches using predication
+        const float final_exp2 = overflow ? exp2_val_overflow : exp2_val;
+        const int final_mantissa = overflow ? 0 : mantissa;
+
+        // Use FMA for final computation
+        const float fraction = static_cast<float>(final_mantissa) * inv_mantissa_scale;
+        input[idx] = fmaf(fmaf(fraction, final_exp2, final_exp2), s, 0.0f);
+    }
+}
 
-    // 4. Select final values without branches
-    const float final_exp2 = overflow ? exp2_val_overflow : exp2_val;
-    const int final_mantissa = overflow ? 0 : mantissa;
+// Vectorized kernel using float4 for maximum memory bandwidth
+__global__ void float_round_kernel_vectorized(float4* input_vec,
+                                              int N_vec,
+                                              float max_exp,
+                                              float min_exp,
+                                              int mantissa_upper_bound,
+                                              float mantissa_scale,
+                                              float inv_mantissa_scale) {
+    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    const int stride = blockDim.x * gridDim.x;
+
+    // Process float4 elements (4 floats per thread)
+    for (int idx = tid; idx < N_vec; idx += stride) {
+        float4 vec = input_vec[idx];
+
+        // Process each component of the float4 vector
+        #pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            float* x_ptr = reinterpret_cast<float*>(&vec) + i;
+            float x_val = *x_ptr;
+
+            if (x_val == 0.0f) continue;
+
+            // Use fast math intrinsics
+            const float s = copysignf(1.0f, x_val);
+            const float x_abs = fabsf(x_val);
+            const float exponent_floor = log2f(x_abs);
+            float exponent = fmaxf(fminf(exponent_floor, max_exp), min_exp);
+            float exp2_val = exp2f(exponent);
+
+            // Optimized computation with FMA
+            float scaled = fmaf(x_abs, __frcp_rn(exp2_val), 0.0f);
+            scaled = fmaxf(scaled, 1.0f);
+
+            const float mantissa_unrounded = fmaf(scaled - 1.0f, mantissa_scale, 0.0f);
+            const int mantissa = __float2int_rn(mantissa_unrounded);
+
+            const bool overflow = mantissa >= mantissa_upper_bound;
+            const float exponent_overflow = fmaxf(fminf(fmaf(exponent, 1.0f, 1.0f), max_exp), min_exp);
+            const float exp2_val_overflow = exp2f(exponent_overflow);
+
+            const float final_exp2 = overflow ? exp2_val_overflow : exp2_val;
+            const int final_mantissa = overflow ? 0 : mantissa;
+
+            const float fraction = static_cast<float>(final_mantissa) * inv_mantissa_scale;
+            *x_ptr = fmaf(fmaf(fraction, final_exp2, final_exp2), s, 0.0f);
+        }
+
+        // Store the processed float4 vector
+        input_vec[idx] = vec;
+    }
+}
 
-    // 5. FMA is automatically used with --use_fast_math
-    const float fraction = static_cast<float>(final_mantissa) * inv_mantissa_scale;
-    input[idx] = s * (1.0f + fraction) * final_exp2;
+// Shared memory optimized kernel for better cache utilization
+__global__ void float_round_kernel_shared(float* input,
+                                          int N,
+                                          float max_exp,
+                                          float min_exp,
+                                          int mantissa_upper_bound,
+                                          float mantissa_scale,
+                                          float inv_mantissa_scale) {
+    __shared__ float shared_data[1024]; // Shared memory buffer
+
+    const int tid = threadIdx.x;
+
+    for (int base_idx = blockIdx.x * blockDim.x; base_idx < N; base_idx += blockDim.x * gridDim.x) {
+        int idx = base_idx + tid;
+
+        // Load data into shared memory with coalesced access
+        if (idx < N) {
+            shared_data[tid] = input[idx];
+        } else {
+            shared_data[tid] = 0.0f;
+        }
+
+        __syncthreads();
+
+        // Process data from shared memory
+        if (idx < N) {
+            float x_val = shared_data[tid];
+
+            if (x_val != 0.0f) {
+                // Use fast math intrinsics
+                const float s = copysignf(1.0f, x_val);
+                const float x_abs = fabsf(x_val);
+                const float exponent_floor = log2f(x_abs);
+                float exponent = fmaxf(fminf(exponent_floor, max_exp), min_exp);
+                float exp2_val = exp2f(exponent);
+
+                // Optimized computation
+                float scaled = fmaf(x_abs, __frcp_rn(exp2_val), 0.0f);
+                scaled = fmaxf(scaled, 1.0f);
+
+                const float mantissa_unrounded = fmaf(scaled - 1.0f, mantissa_scale, 0.0f);
+                const int mantissa = __float2int_rn(mantissa_unrounded);
+
+                const bool overflow = mantissa >= mantissa_upper_bound;
+                const float exponent_overflow = fmaxf(fminf(fmaf(exponent, 1.0f, 1.0f), max_exp), min_exp);
+                const float exp2_val_overflow = exp2f(exponent_overflow);
+
+                const float final_exp2 = overflow ? exp2_val_overflow : exp2_val;
+                const int final_mantissa = overflow ? 0 : mantissa;
+
+                const float fraction = static_cast<float>(final_mantissa) * inv_mantissa_scale;
+                shared_data[tid] = fmaf(fmaf(fraction, final_exp2, final_exp2), s, 0.0f);
+            }
+        }
+
+        __syncthreads();
+
+        // Store back to global memory with coalesced access
+        if (idx < N) {
+            input[idx] = shared_data[tid];
+        }
+    }
 }
 
-// Function that launches the kernel
+// Function that launches the optimized kernel
 torch::Tensor float_round_cuda_inplace(torch::Tensor input, int exponent_bits, int mantissa_bits, int bias) {
     CHECK_CUDA(input);
 
@@ -73,14 +199,48 @@ torch::Tensor float_round_cuda_inplace(torch::Tensor input, int exponent_bits, i
     float inv_mantissa_scale = 1.0f / mantissa_scale;
 
     float* input_ptr = input.data_ptr<float>();
-    int threads = 1024;
+
+    // Optimize block and grid size for better occupancy
+    int device_id = input.device().index();
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, device_id);
+
+    // Calculate optimal block size based on register usage and shared memory
+    int threads = 256; // Reduced from 1024 to improve occupancy
     int blocks = (numel + threads - 1) / threads;
 
+    // Ensure we don't exceed maximum blocks per SM
+    int max_blocks_per_sm = prop.maxBlocksPerMultiProcessor;
+    int max_blocks = prop.multiProcessorCount * max_blocks_per_sm;
+    blocks = min(blocks, max_blocks);
+
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    float_round_kernel_inplace<<<blocks, threads, 0, stream>>>(
-        input_ptr, numel, max_exp, min_exp,
-        mantissa_upper_bound, mantissa_scale, inv_mantissa_scale
-    );
+
+    // Choose kernel based on input size and optimization strategy
+    if (numel >= 1000000) {
+        // For large inputs, use vectorized kernel if possible
+        if (numel % 4 == 0) {
+            float4* input_vec = reinterpret_cast<float4*>(input_ptr);
+            int N_vec = numel / 4;
+            float_round_kernel_vectorized<<<blocks, threads, 0, stream>>>(
+                input_vec, N_vec, max_exp, min_exp,
+                mantissa_upper_bound, mantissa_scale, inv_mantissa_scale
+            );
+        } else {
+            // Use shared memory kernel for better cache utilization
+            float_round_kernel_shared<<<blocks, threads, 0, stream>>>(
+                input_ptr, numel, max_exp, min_exp,
+                mantissa_upper_bound, mantissa_scale, inv_mantissa_scale
+            );
+        }
+    } else {
+        // For smaller inputs, use optimized kernel
+        float_round_kernel_inplace<<<blocks, threads, 0, stream>>>(
+            input_ptr, numel, max_exp, min_exp,
+            mantissa_upper_bound, mantissa_scale, inv_mantissa_scale
+        );
+    }
+
     CUDA_CHECK(cudaGetLastError());
 
     return input;
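
Note on the per-element math: all three kernels apply the same rounding step. The scalar sketch below restates it for reference, following the original kernel body (including the floorf around log2f that the pre-existing code used); the name float_round_ref and the standalone form are illustrative, not part of the file.

#include <cmath>

// Illustrative scalar restatement of the rounding applied per element
// (reconstructed from float_round_kernel_inplace as it was before this commit).
static float float_round_ref(float x, float max_exp, float min_exp,
                             int mantissa_upper_bound,
                             float mantissa_scale, float inv_mantissa_scale) {
    if (x == 0.0f) return x;

    const float s     = copysignf(1.0f, x);   // sign of the input
    const float x_abs = fabsf(x);             // magnitude

    // Clamp floor(log2|x|) to the representable exponent range.
    float exponent = fmaxf(fminf(floorf(log2f(x_abs)), max_exp), min_exp);
    float exp2_val = exp2f(exponent);

    // Normalize into [1, 2) and round the fraction onto the mantissa grid.
    const float scaled = fmaxf(x_abs / exp2_val, 1.0f);
    int mantissa = (int)lrintf((scaled - 1.0f) * mantissa_scale);

    // Mantissa overflow: bump the (clamped) exponent and reset the mantissa.
    if (mantissa >= mantissa_upper_bound) {
        exponent = fmaxf(fminf(exponent + 1.0f, max_exp), min_exp);
        exp2_val = exp2f(exponent);
        mantissa = 0;
    }

    // Reassemble: sign * (1 + mantissa / mantissa_scale) * 2^exponent.
    return s * (1.0f + static_cast<float>(mantissa) * inv_mantissa_scale) * exp2_val;
}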

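A minimal host-side usage sketch, assuming the usual PyTorch C++ extension setup; the binding code is not part of this diff, and the bit-width arguments below (4 exponent bits, 3 mantissa bits, bias 7, an FP8 E4M3-style layout) are example values rather than anything taken from the repository.

#include <torch/extension.h>

// Declared in floating_point/float_round_cuda.cu (see the diff above).
torch::Tensor float_round_cuda_inplace(torch::Tensor input, int exponent_bits,
                                       int mantissa_bits, int bias);

void round_example() {
    // Hypothetical call: round a CUDA float32 tensor in place onto an
    // FP8 E4M3-like grid (example bit widths, not repository defaults).
    torch::Tensor x = torch::randn(
        {1 << 20}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
    float_round_cuda_inplace(x, /*exponent_bits=*/4, /*mantissa_bits=*/3, /*bias=*/7);
    // x now holds the rounded values; the same tensor is also returned.
}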