Skip to content

Commit 7902357

Browse files
authored
Make include_spirv actually work statically (#8250)
1 parent 334170b commit 7902357

File tree

9 files changed

+125
-24
lines changed

9 files changed

+125
-24
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ By @wumpf in [#8282](https://github.com/gfx-rs/wgpu/pull/8282), [#8285](https://
241241
- Require new `F16_IN_F32` downlevel flag for `quantizeToF16`, `pack2x16float`, and `unpack2x16float` in WGSL input. By @aleiserson in [#8130](https://github.com/gfx-rs/wgpu/pull/8130).
242242
- The error message for non-copyable depth/stencil formats no longer mentions the aspect when it is not relevant. By @reima in [#8156](https://github.com/gfx-rs/wgpu/pull/8156).
243243
- Track the initialization status of buffer memory correctly when `copy_texture_to_buffer` skips over padding space between rows or layers, or when the start/end of a texture-buffer transfer is not 4B aligned. By @andyleiserson in [#8099](https://github.com/gfx-rs/wgpu/pull/8099).
244+
- Allow `include_spirv!` and `include_spirv_raw!` macros to be used in constants and statics. By @clarfonthey in [#8250](https://github.com/gfx-rs/wgpu/pull/8250).
244245

245246
#### naga
246247

wgpu-types/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7835,7 +7835,7 @@ pub struct ShaderRuntimeChecks {
78357835
impl ShaderRuntimeChecks {
78367836
/// Creates a new configuration where the shader is fully checked.
78377837
#[must_use]
7838-
pub fn checked() -> Self {
7838+
pub const fn checked() -> Self {
78397839
unsafe { Self::all(true) }
78407840
}
78417841

@@ -7846,7 +7846,7 @@ impl ShaderRuntimeChecks {
78467846
/// See the documentation for the `set_*` methods for the safety requirements
78477847
/// of each sub-configuration.
78487848
#[must_use]
7849-
pub fn unchecked() -> Self {
7849+
pub const fn unchecked() -> Self {
78507850
unsafe { Self::all(false) }
78517851
}
78527852

@@ -7858,7 +7858,7 @@ impl ShaderRuntimeChecks {
78587858
/// See the documentation for the `set_*` methods for the safety requirements
78597859
/// of each sub-configuration.
78607860
#[must_use]
7861-
pub unsafe fn all(all_checks: bool) -> Self {
7861+
pub const unsafe fn all(all_checks: bool) -> Self {
78627862
Self {
78637863
bounds_checks: all_checks,
78647864
force_loop_bounding: all_checks,

wgpu/src/macros/be-aligned.spv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#"3D

wgpu/src/macros/le-aligned.spv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#D3"

wgpu/src/macros.rs renamed to wgpu/src/macros/mod.rs

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,35 @@ fn test_vertex_attr_array() {
7777
assert_eq!(attrs[1].shader_location, 3);
7878
}
7979

80+
#[macro_export]
81+
#[doc(hidden)]
82+
macro_rules! include_spirv_source {
83+
($($token:tt)*) => {
84+
{
85+
// FIXME(MSRV): when bumping to 1.89, use [u8; _] here
86+
const SPIRV_SOURCE: [
87+
u8;
88+
$crate::__macro_helpers::include_bytes!($($token)*).len()
89+
] = *$crate::__macro_helpers::include_bytes!($($token)*);
90+
const SPIRV_LEN: usize = SPIRV_SOURCE.len() / 4;
91+
const SPIRV_WORDS: [u32; SPIRV_LEN] = $crate::util::make_spirv_const(SPIRV_SOURCE);
92+
&SPIRV_WORDS
93+
}
94+
}
95+
}
96+
97+
#[test]
98+
fn make_spirv_le_pass() {
99+
static SPIRV: &[u32] = include_spirv_source!("le-aligned.spv");
100+
assert_eq!(SPIRV, &[0x07230203, 0x11223344]);
101+
}
102+
103+
#[test]
104+
fn make_spirv_be_pass() {
105+
static SPIRV: &[u32] = include_spirv_source!("be-aligned.spv");
106+
assert_eq!(SPIRV, &[0x07230203, 0x11223344]);
107+
}
108+
80109
/// Macro to load a SPIR-V module statically.
81110
///
82111
/// It ensures the word alignment as well as the magic number.
@@ -90,12 +119,18 @@ macro_rules! include_spirv {
90119
//log::info!("including '{}'", $($token)*);
91120
$crate::ShaderModuleDescriptor {
92121
label: Some($($token)*),
93-
source: $crate::util::make_spirv(include_bytes!($($token)*)),
122+
source: $crate::ShaderSource::SpirV(
123+
$crate::__macro_helpers::Cow::Borrowed($crate::include_spirv_source!($($token)*))
124+
),
94125
}
95126
}
96127
};
97128
}
98129

130+
#[cfg(all(feature = "spirv", test))]
131+
#[expect(dead_code)]
132+
static SPIRV: crate::ShaderModuleDescriptor<'_> = include_spirv!("le-aligned.spv");
133+
99134
/// Macro to load raw SPIR-V data statically, for use with [`Features::EXPERIMENTAL_PASSTHROUGH_SHADERS`].
100135
///
101136
/// It ensures the word alignment as well as the magic number.
@@ -108,13 +143,11 @@ macro_rules! include_spirv_raw {
108143
//log::info!("including '{}'", $($token)*);
109144
$crate::ShaderModuleDescriptorPassthrough {
110145
label: $crate::__macro_helpers::Some($($token)*),
111-
spirv: Some($crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*))),
112-
113-
entry_point: "".to_owned(),
146+
spirv: Some($crate::__macro_helpers::Cow::Borrowed($crate::include_spirv_source!($($token)*))),
147+
entry_point: $crate::__macro_helpers::String::new(),
114148
// This is unused for SPIR-V
115149
num_workgroups: (0, 0, 0),
116-
reflection: None,
117-
shader_runtime_checks: $crate::ShaderRuntimeChecks::unchecked(),
150+
runtime_checks: $crate::ShaderRuntimeChecks::unchecked(),
118151
dxil: None,
119152
msl: None,
120153
hlsl: None,
@@ -125,6 +158,11 @@ macro_rules! include_spirv_raw {
125158
};
126159
}
127160

161+
#[cfg(test)]
162+
#[expect(dead_code)]
163+
static SPIRV_RAW: crate::ShaderModuleDescriptorPassthrough<'_> =
164+
include_spirv_raw!("le-aligned.spv");
165+
128166
/// Load WGSL source code from a file at compile time.
129167
///
130168
/// The loaded path is relative to the path of the file containing the macro call, in the same way
@@ -232,7 +270,7 @@ macro_rules! hal_type_gles {
232270

233271
#[doc(hidden)]
234272
pub mod helpers {
235-
pub use alloc::borrow::Cow;
273+
pub use alloc::{borrow::Cow, string::String};
236274
pub use core::{include_bytes, include_str};
237275
pub use Some;
238276
}

wgpu/src/util/be-unaligned.spv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#"

wgpu/src/util/empty.spv

Whitespace-only changes.

wgpu/src/util/le-unaligned.spv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#"

wgpu/src/util/mod.rs

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ mod mutex;
1414
mod texture_blitter;
1515

1616
use alloc::{borrow::Cow, format, string::String, vec};
17-
use core::ptr::copy_nonoverlapping;
17+
use core::{mem, ptr::copy_nonoverlapping};
1818

1919
#[cfg(std)]
2020
pub use belt::StagingBelt;
@@ -46,18 +46,29 @@ pub fn make_spirv(data: &[u8]) -> super::ShaderSource<'_> {
4646
super::ShaderSource::SpirV(make_spirv_raw(data))
4747
}
4848

49+
const SPIRV_MAGIC_NUMBER: u32 = 0x0723_0203;
50+
51+
const fn check_spirv_len(data: &[u8]) {
52+
assert!(
53+
data.len() % size_of::<u32>() == 0,
54+
"SPIRV data size must be a multiple of 4."
55+
);
56+
assert!(!data.is_empty(), "SPIRV data must not be empty.");
57+
}
58+
59+
const fn verify_spirv_magic(words: &[u32]) {
60+
assert!(
61+
words[0] == SPIRV_MAGIC_NUMBER,
62+
"Wrong magic word in data. Make sure you are using a binary SPIRV file.",
63+
);
64+
}
65+
4966
/// Version of `make_spirv` intended for use with [`Device::create_shader_module_passthrough`].
5067
/// Returns a raw slice instead of [`ShaderSource`](super::ShaderSource).
5168
///
5269
/// [`Device::create_shader_module_passthrough`]: crate::Device::create_shader_module_passthrough
5370
pub fn make_spirv_raw(data: &[u8]) -> Cow<'_, [u32]> {
54-
const MAGIC_NUMBER: u32 = 0x0723_0203;
55-
assert_eq!(
56-
data.len() % size_of::<u32>(),
57-
0,
58-
"data size is not a multiple of 4"
59-
);
60-
assert_ne!(data.len(), 0, "data size must be larger than zero");
71+
check_spirv_len(data);
6172

6273
// If the data happens to be aligned, directly use the byte array,
6374
// otherwise copy the byte array in an owned vector and use that instead.
@@ -76,21 +87,68 @@ pub fn make_spirv_raw(data: &[u8]) -> Cow<'_, [u32]> {
7687

7788
// Before checking if the data starts with the magic, check if it starts
7889
// with the magic in non-native endianness, own & swap the data if so.
79-
if words[0] == MAGIC_NUMBER.swap_bytes() {
90+
if words[0] == SPIRV_MAGIC_NUMBER.swap_bytes() {
8091
for word in Cow::to_mut(&mut words) {
8192
*word = word.swap_bytes();
8293
}
8394
}
8495

85-
assert_eq!(
86-
words[0], MAGIC_NUMBER,
87-
"wrong magic word {:x}. Make sure you are using a binary SPIRV file.",
88-
words[0]
89-
);
96+
verify_spirv_magic(&words);
97+
98+
words
99+
}
100+
101+
/// Version of `make_spirv_raw` used for implementing [`include_spirv!`] and [`include_spirv_raw!`] macros.
102+
///
103+
/// Not public API. Also, don't even try calling at runtime; you'll get a stack overflow.
104+
///
105+
/// [`include_spirv!`]: crate::include_spirv
106+
#[doc(hidden)]
107+
pub const fn make_spirv_const<const IN: usize, const OUT: usize>(data: [u8; IN]) -> [u32; OUT] {
108+
#[repr(align(4))]
109+
struct Aligned<T: ?Sized>(T);
110+
111+
check_spirv_len(&data);
112+
113+
// NOTE: to get around lack of generic const expressions
114+
assert!(IN / 4 == OUT);
115+
116+
let aligned = Aligned(data);
117+
let mut words: [u32; OUT] = unsafe { mem::transmute_copy(&aligned) };
118+
119+
// Before checking if the data starts with the magic, check if it starts
120+
// with the magic in non-native endianness, own & swap the data if so.
121+
if words[0] == SPIRV_MAGIC_NUMBER.swap_bytes() {
122+
let mut idx = 0;
123+
while idx < words.len() {
124+
words[idx] = words[idx].swap_bytes();
125+
idx += 1;
126+
}
127+
}
128+
129+
verify_spirv_magic(&words);
90130

91131
words
92132
}
93133

134+
#[should_panic = "multiple of 4"]
135+
#[test]
136+
fn make_spirv_le_fail() {
137+
let _: [u32; 1] = make_spirv_const([0x03, 0x02, 0x23, 0x07, 0x44, 0x33]);
138+
}
139+
140+
#[should_panic = "multiple of 4"]
141+
#[test]
142+
fn make_spirv_be_fail() {
143+
let _: [u32; 1] = make_spirv_const([0x07, 0x23, 0x02, 0x03, 0x11, 0x22]);
144+
}
145+
146+
#[should_panic = "empty"]
147+
#[test]
148+
fn make_spirv_empty() {
149+
let _: [u32; 0] = make_spirv_const([]);
150+
}
151+
94152
/// CPU accessible buffer used to download data back from the GPU.
95153
pub struct DownloadBuffer {
96154
_gpu_buffer: super::Buffer,

0 commit comments

Comments
 (0)