Skip to content

Commit 0bd0bf0

Browse files
khyperiaeddyb
authored andcommitted
Fix issues with some raytracing functions and add tests
1 parent 2d5b8e6 commit 0bd0bf0

10 files changed

+146
-13
lines changed

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::Builder;
2-
use crate::builder_spirv::{BuilderCursor, SpirvValue};
2+
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
33
use crate::codegen_cx::CodegenCx;
44
use crate::spirv_type::SpirvType;
55
use rspirv::dr;
@@ -304,8 +304,26 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
304304
}
305305
.def(self.span(), self),
306306
Op::TypeArray => {
307-
self.err("OpTypeArray in asm! is not supported yet");
308-
return;
307+
let count = inst.operands[1].unwrap_id_ref();
308+
let get_count_ty = || -> Option<Word> {
309+
let emit = self.emit();
310+
let func = &emit.module_ref().functions[emit.selected_function()?];
311+
let insts = &func.blocks[emit.selected_block()?].instructions;
312+
let inst = insts.iter().find(|i| i.result_id == Some(count))?;
313+
inst.result_type
314+
};
315+
let count_ty = match get_count_ty() {
316+
Some(ty) => ty,
317+
None => {
318+
self.err("Unable to find constant for OpTypeArray count");
319+
return;
320+
}
321+
};
322+
SpirvType::Array {
323+
element: inst.operands[0].unwrap_id_ref(),
324+
count: count.with_type(count_ty),
325+
}
326+
.def(self.span(), self)
309327
}
310328
Op::TypeRuntimeArray => SpirvType::RuntimeArray {
311329
element: inst.operands[0].unwrap_id_ref(),

crates/spirv-std/src/ray_tracing.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -579,18 +579,16 @@ impl RayQuery {
579579
#[spirv_std_macros::gpu_only]
580580
#[doc(alias = "OpRayQueryGetWorldRayDirectionKHR")]
581581
#[inline]
582-
pub unsafe fn get_world_ray_direction<V: Vector<f32, 3>, const INTERSECTION: u32>(&self) -> V {
582+
pub unsafe fn get_world_ray_direction<V: Vector<f32, 3>>(&self) -> V {
583583
let mut result = Default::default();
584584

585585
asm! {
586586
"%u32 = OpTypeInt 32 0",
587587
"%f32 = OpTypeFloat 32",
588588
"%f32x3 = OpTypeVector %f32 3",
589-
"%intersection = OpConstant %u32 {intersection}",
590-
"%result = OpRayQueryGetWorldRayDirectionKHR %f32x3 {ray_query} %intersection",
589+
"%result = OpRayQueryGetWorldRayDirectionKHR %f32x3 {ray_query}",
591590
"OpStore {result} %result",
592591
ray_query = in(reg) self,
593-
intersection = const INTERSECTION,
594592
result = in(reg) &mut result,
595593
}
596594

@@ -601,18 +599,16 @@ impl RayQuery {
601599
#[spirv_std_macros::gpu_only]
602600
#[doc(alias = "OpRayQueryGetWorldRayOriginKHR")]
603601
#[inline]
604-
pub unsafe fn get_world_ray_origin<V: Vector<f32, 3>, const INTERSECTION: u32>(&self) -> V {
602+
pub unsafe fn get_world_ray_origin<V: Vector<f32, 3>>(&self) -> V {
605603
let mut result = Default::default();
606604

607605
asm! {
608606
"%u32 = OpTypeInt 32 0",
609607
"%f32 = OpTypeFloat 32",
610608
"%f32x3 = OpTypeVector %f32 3",
611-
"%intersection = OpConstant %u32 {intersection}",
612-
"%result = OpRayQueryGetWorldRayOriginKHR %f32x3 {ray_query} %intersection",
609+
"%result = OpRayQueryGetWorldRayOriginKHR %f32x3 {ray_query}",
613610
"OpStore {result} %result",
614611
ray_query = in(reg) self,
615-
intersection = const INTERSECTION,
616612
result = in(reg) &mut result,
617613
}
618614

@@ -626,15 +622,22 @@ impl RayQuery {
626622
#[inline]
627623
pub unsafe fn get_intersection_object_to_world<V: Vector<f32, 3>, const INTERSECTION: u32>(
628624
&self,
629-
) -> V {
625+
) -> [V; 4] {
630626
let mut result = Default::default();
631627

632628
asm! {
633629
"%u32 = OpTypeInt 32 0",
634630
"%f32 = OpTypeFloat 32",
631+
"%four = OpConstant %u32 4",
635632
"%f32x3 = OpTypeVector %f32 3",
633+
"%f32x3x4 = OpTypeMatrix %f32x3 4",
636634
"%intersection = OpConstant %u32 {intersection}",
637-
"%result = OpRayQueryGetWorldRayOriginKHR %f32x3 {ray_query} %intersection",
635+
"%matrix = OpRayQueryGetIntersectionObjectToWorldKHR %f32x3x4 {ray_query} %intersection",
636+
"%col0 = OpCompositeExtract %f32x3 %matrix 0",
637+
"%col1 = OpCompositeExtract %f32x3 %matrix 1",
638+
"%col2 = OpCompositeExtract %f32x3 %matrix 2",
639+
"%col3 = OpCompositeExtract %f32x3 %matrix 3",
640+
"%result = OpCompositeConstruct typeof*{result} %col0 %col1 %col2 %col3",
638641
"OpStore {result} %result",
639642
ray_query = in(reg) self,
640643
intersection = const INTERSECTION,
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
3+
4+
use glam::Vec3;
5+
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
6+
7+
#[spirv(fragment)]
8+
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
9+
unsafe {
10+
spirv_std::ray_query!(let mut handle);
11+
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
12+
handle.get_intersection_candidate_aabb_opaque();
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
3+
4+
use glam::Vec3;
5+
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
6+
7+
#[spirv(fragment)]
8+
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
9+
unsafe {
10+
spirv_std::ray_query!(let mut handle);
11+
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
12+
let direction: glam::Vec3 = handle.get_intersection_object_ray_direction::<_, 5>();
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
3+
4+
use glam::Vec3;
5+
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
6+
7+
#[spirv(fragment)]
8+
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
9+
unsafe {
10+
spirv_std::ray_query!(let mut handle);
11+
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
12+
let origin: glam::Vec3 = handle.get_intersection_object_ray_origin::<_, 5>();
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
3+
4+
use glam::Vec3;
5+
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
6+
7+
#[spirv(fragment)]
8+
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
9+
unsafe {
10+
spirv_std::ray_query!(let mut handle);
11+
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
12+
let matrix: [glam::Vec3; 4] = handle.get_intersection_object_to_world::<_, 5>();
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
3+
4+
use glam::Vec3;
5+
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
6+
7+
#[spirv(fragment)]
8+
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
9+
unsafe {
10+
spirv_std::ray_query!(let mut handle);
11+
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
12+
handle.get_intersection_primitive_index::<5>();
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
3+
4+
use glam::Vec3;
5+
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
6+
7+
#[spirv(fragment)]
8+
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
9+
unsafe {
10+
spirv_std::ray_query!(let mut handle);
11+
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
12+
let flags = handle.get_ray_flags();
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
3+
4+
use glam::Vec3;
5+
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
6+
7+
#[spirv(fragment)]
8+
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
9+
unsafe {
10+
spirv_std::ray_query!(let mut handle);
11+
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
12+
let direction: glam::Vec3 = handle.get_world_ray_direction();
13+
}
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
3+
4+
use glam::Vec3;
5+
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
6+
7+
#[spirv(fragment)]
8+
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
9+
unsafe {
10+
spirv_std::ray_query!(let mut handle);
11+
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
12+
let origin: glam::Vec3 = handle.get_world_ray_origin();
13+
}
14+
}

0 commit comments

Comments
 (0)