Skip to content

Commit ac2674f

Browse files
committed
Fix warp shuffle codegen
This must have changed / broken in a rustc update.
1 parent 2dbfc5e commit ac2674f

File tree

2 files changed

+262
-3
lines changed

2 files changed

+262
-3
lines changed

crates/cuda_std/src/warp.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,16 @@ pub enum WarpShuffleMode {
733733
Xor,
734734
}
735735

736+
// C-compatible struct to match LLVM IR's {i32, i8} return type
737+
// This fixes an ABI mismatch where Rust would represent (u32, bool) as [2 x i32]
738+
// but the LLVM intrinsic returns {i32, i8} (a struct, not an array)
739+
#[doc(hidden)]
740+
#[repr(C)]
741+
pub struct WarpShuffleResult {
742+
value: u32,
743+
predicate: u8,
744+
}
745+
736746
#[gpu_only]
737747
unsafe fn warp_shuffle_32(
738748
mode: WarpShuffleMode,
@@ -743,8 +753,8 @@ unsafe fn warp_shuffle_32(
743753
) -> (u32, bool) {
744754
extern "C" {
745755
// see libintrinsics.ll
746-
#[allow(improper_ctypes)]
747-
fn __nvvm_warp_shuffle(mask: u32, mode: u32, a: u32, b: u32, c: u32) -> (u32, bool);
756+
// Returns {i32, i8} in LLVM IR, which maps to our WarpShuffleResult struct
757+
fn __nvvm_warp_shuffle(mask: u32, mode: u32, a: u32, b: u32, c: u32) -> WarpShuffleResult;
748758
}
749759

750760
assert!(
@@ -757,7 +767,8 @@ unsafe fn warp_shuffle_32(
757767
c |= 0b11111;
758768
c |= (32 - width) << 8;
759769

760-
__nvvm_warp_shuffle(mask, mode as u32, value, b, c)
770+
let result = __nvvm_warp_shuffle(mask, mode as u32, value, b, c);
771+
(result.value, result.predicate != 0)
761772
}
762773

763774
unsafe fn warp_shuffle_128(
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Test that the CUDA warp shuffle functions compile correctly
2+
// build-pass
3+
4+
use cuda_std::kernel;
5+
use cuda_std::warp;
6+
7+
#[kernel]
8+
pub unsafe fn test_warp_shuffle_functions() {
9+
let mask = 0xFFFFFFFF_u32; // Full warp mask
10+
let width = 32_u32; // Full warp width
11+
12+
// Test warp_shuffle_xor with various types
13+
{
14+
// 8-bit types
15+
let val_i8: i8 = 42;
16+
let (res_i8, pred_i8) = warp::warp_shuffle_xor(mask, val_i8, 1, width);
17+
18+
let val_u8: u8 = 42;
19+
let (res_u8, pred_u8) = warp::warp_shuffle_xor(mask, val_u8, 1, width);
20+
21+
// 16-bit types
22+
let val_i16: i16 = 1234;
23+
let (res_i16, pred_i16) = warp::warp_shuffle_xor(mask, val_i16, 2, width);
24+
25+
let val_u16: u16 = 1234;
26+
let (res_u16, pred_u16) = warp::warp_shuffle_xor(mask, val_u16, 2, width);
27+
28+
// 32-bit types
29+
let val_i32: i32 = 123456;
30+
let (res_i32, pred_i32) = warp::warp_shuffle_xor(mask, val_i32, 4, width);
31+
32+
let val_u32: u32 = 123456;
33+
let (res_u32, pred_u32) = warp::warp_shuffle_xor(mask, val_u32, 4, width);
34+
35+
let val_f32: f32 = 3.14159;
36+
let (res_f32, pred_f32) = warp::warp_shuffle_xor(mask, val_f32, 8, width);
37+
38+
// 64-bit types
39+
let val_i64: i64 = 1234567890;
40+
let (res_i64, pred_i64) = warp::warp_shuffle_xor(mask, val_i64, 16, width);
41+
42+
let val_u64: u64 = 1234567890;
43+
let (res_u64, pred_u64) = warp::warp_shuffle_xor(mask, val_u64, 16, width);
44+
45+
let val_f64: f64 = 2.718281828;
46+
let (res_f64, pred_f64) = warp::warp_shuffle_xor(mask, val_f64, 16, width);
47+
48+
// 128-bit types
49+
let val_i128: i128 = 12345678901234567890;
50+
let (res_i128, pred_i128) = warp::warp_shuffle_xor(mask, val_i128, 1, width);
51+
52+
let val_u128: u128 = 12345678901234567890;
53+
let (res_u128, pred_u128) = warp::warp_shuffle_xor(mask, val_u128, 1, width);
54+
}
55+
56+
// Test warp_shuffle_down with various types
57+
{
58+
let delta = 1_u32;
59+
60+
let val_i32: i32 = 42;
61+
let (res_i32, pred_i32) = warp::warp_shuffle_down(mask, val_i32, delta, width);
62+
63+
let val_u32: u32 = 42;
64+
let (res_u32, pred_u32) = warp::warp_shuffle_down(mask, val_u32, delta, width);
65+
66+
let val_f32: f32 = 1.0;
67+
let (res_f32, pred_f32) = warp::warp_shuffle_down(mask, val_f32, delta, width);
68+
69+
let val_i64: i64 = 100;
70+
let (res_i64, pred_i64) = warp::warp_shuffle_down(mask, val_i64, delta, width);
71+
72+
let val_u64: u64 = 100;
73+
let (res_u64, pred_u64) = warp::warp_shuffle_down(mask, val_u64, delta, width);
74+
75+
let val_f64: f64 = 1.0;
76+
let (res_f64, pred_f64) = warp::warp_shuffle_down(mask, val_f64, delta, width);
77+
}
78+
79+
// Test warp_shuffle_up with various types
80+
{
81+
let delta = 1_u32;
82+
83+
let val_i32: i32 = 42;
84+
let (res_i32, pred_i32) = warp::warp_shuffle_up(mask, val_i32, delta, width);
85+
86+
let val_u32: u32 = 42;
87+
let (res_u32, pred_u32) = warp::warp_shuffle_up(mask, val_u32, delta, width);
88+
89+
let val_f32: f32 = 1.0;
90+
let (res_f32, pred_f32) = warp::warp_shuffle_up(mask, val_f32, delta, width);
91+
92+
let val_i64: i64 = 100;
93+
let (res_i64, pred_i64) = warp::warp_shuffle_up(mask, val_i64, delta, width);
94+
95+
let val_u64: u64 = 100;
96+
let (res_u64, pred_u64) = warp::warp_shuffle_up(mask, val_u64, delta, width);
97+
98+
let val_f64: f64 = 1.0;
99+
let (res_f64, pred_f64) = warp::warp_shuffle_up(mask, val_f64, delta, width);
100+
}
101+
102+
// Test warp_shuffle_idx with various types
103+
{
104+
let idx = 5_u32;
105+
106+
let val_i32: i32 = 42;
107+
let (res_i32, pred_i32) = warp::warp_shuffle_idx(mask, val_i32, idx, width);
108+
109+
let val_u32: u32 = 42;
110+
let (res_u32, pred_u32) = warp::warp_shuffle_idx(mask, val_u32, idx, width);
111+
112+
let val_f32: f32 = 1.0;
113+
let (res_f32, pred_f32) = warp::warp_shuffle_idx(mask, val_f32, idx, width);
114+
115+
let val_i64: i64 = 100;
116+
let (res_i64, pred_i64) = warp::warp_shuffle_idx(mask, val_i64, idx, width);
117+
118+
let val_u64: u64 = 100;
119+
let (res_u64, pred_u64) = warp::warp_shuffle_idx(mask, val_u64, idx, width);
120+
121+
let val_f64: f64 = 1.0;
122+
let (res_f64, pred_f64) = warp::warp_shuffle_idx(mask, val_f64, idx, width);
123+
}
124+
125+
// Test with different mask values
126+
{
127+
let partial_mask = 0x0000FFFF_u32; // Lower 16 lanes
128+
let val: i32 = 123;
129+
let (res, pred) = warp::warp_shuffle_xor(partial_mask, val, 1, width);
130+
}
131+
132+
// Test with different width values (must be a power of 2 and <= 32)
133+
{
134+
let val: i32 = 456;
135+
let lane_mask = 1_u32;
136+
137+
// Width 16
138+
let (res16, pred16) = warp::warp_shuffle_xor(mask, val, lane_mask, 16);
139+
140+
// Width 8
141+
let (res8, pred8) = warp::warp_shuffle_xor(mask, val, lane_mask, 8);
142+
143+
// Width 4
144+
let (res4, pred4) = warp::warp_shuffle_xor(mask, val, lane_mask, 4);
145+
146+
// Width 2
147+
let (res2, pred2) = warp::warp_shuffle_xor(mask, val, lane_mask, 2);
148+
}
149+
150+
// Test with half-precision floating-point types (only when the `half` feature is enabled)
151+
#[cfg(feature = "half")]
152+
{
153+
use half::{bf16, f16};
154+
155+
let val_f16 = f16::from_f32(1.5);
156+
let (res_f16, pred_f16) = warp::warp_shuffle_xor(mask, val_f16, 1, width);
157+
158+
let val_bf16 = bf16::from_f32(2.5);
159+
let (res_bf16, pred_bf16) = warp::warp_shuffle_xor(mask, val_bf16, 1, width);
160+
}
161+
}
162+
163+
// Test edge cases and boundary conditions
164+
#[kernel]
165+
pub unsafe fn test_warp_shuffle_edge_cases() {
166+
let mask = 0xFFFFFFFF_u32;
167+
168+
// Test with lane_mask = 0 (each lane should shuffle with itself, since lane ^ 0 == lane)
169+
{
170+
let val: i32 = 999;
171+
let (res, pred) = warp::warp_shuffle_xor(mask, val, 0, 32);
172+
}
173+
174+
// Test with maximum lane_mask
175+
{
176+
let val: i32 = 888;
177+
let (res, pred) = warp::warp_shuffle_xor(mask, val, 31, 32);
178+
}
179+
180+
// Test shuffle_down with delta = 0
181+
{
182+
let val: i32 = 777;
183+
let (res, pred) = warp::warp_shuffle_down(mask, val, 0, 32);
184+
}
185+
186+
// Test shuffle_up with delta = 0
187+
{
188+
let val: i32 = 666;
189+
let (res, pred) = warp::warp_shuffle_up(mask, val, 0, 32);
190+
}
191+
192+
// Test shuffle_idx with idx = 0 and idx = 31
193+
{
194+
let val: i32 = 555;
195+
let (res0, pred0) = warp::warp_shuffle_idx(mask, val, 0, 32);
196+
let (res31, pred31) = warp::warp_shuffle_idx(mask, val, 31, 32);
197+
}
198+
}
199+
200+
// Test that the functions compile when used in practical patterns (build-pass only; nothing is executed)
201+
#[kernel]
202+
pub unsafe fn test_warp_shuffle_practical() {
203+
let lane_id = warp::lane_id();
204+
let mask = 0xFFFFFFFF_u32;
205+
206+
// Butterfly reduction pattern using XOR shuffle
207+
{
208+
let mut val = lane_id as i32;
209+
210+
// Stage 1: XOR with distance 16
211+
let (shuffled, _) = warp::warp_shuffle_xor(mask, val, 16, 32);
212+
val += shuffled;
213+
214+
// Stage 2: XOR with distance 8
215+
let (shuffled, _) = warp::warp_shuffle_xor(mask, val, 8, 32);
216+
val += shuffled;
217+
218+
// Stage 3: XOR with distance 4
219+
let (shuffled, _) = warp::warp_shuffle_xor(mask, val, 4, 32);
220+
val += shuffled;
221+
222+
// Stage 4: XOR with distance 2
223+
let (shuffled, _) = warp::warp_shuffle_xor(mask, val, 2, 32);
224+
val += shuffled;
225+
226+
// Stage 5: XOR with distance 1
227+
let (shuffled, _) = warp::warp_shuffle_xor(mask, val, 1, 32);
228+
val += shuffled;
229+
}
230+
231+
// Broadcast from lane 0 using shuffle_idx
232+
{
233+
let my_val = lane_id * 10;
234+
let (broadcast_val, is_valid) = warp::warp_shuffle_idx(mask, my_val, 0, 32);
235+
}
236+
237+
// Shift pattern using shuffle_down
238+
{
239+
let my_val = lane_id as f32;
240+
let (shifted_val, is_valid) = warp::warp_shuffle_down(mask, my_val, 1, 32);
241+
}
242+
243+
// Reverse shift using shuffle_up
244+
{
245+
let my_val = (31 - lane_id) as f32;
246+
let (shifted_val, is_valid) = warp::warp_shuffle_up(mask, my_val, 1, 32);
247+
}
248+
}

0 commit comments

Comments
 (0)