Skip to content

Commit 480d4db

Browse files
authored
spv-in parsing Op::AtomicIIncrement (#5702)
Parse spirv::Op::AtomicIIncrement, add atomic_i_increment test.
1 parent 60a14c6 commit 480d4db

File tree

4 files changed

+150
-11
lines changed

4 files changed

+150
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ By @stefnotch in [#5410](https://github.com/gfx-rs/wgpu/pull/5410)
8989
#### Naga
9090

9191
- Implement `WGSL`'s `unpack4xI8`,`unpack4xU8`,`pack4xI8` and `pack4xU8`. By @VlaDexa in [#5424](https://github.com/gfx-rs/wgpu/pull/5424)
92+
- Began work adding support for atomics to the SPIR-V frontend. Tracking issue is [here](https://github.com/gfx-rs/wgpu/issues/4489). By @schell in [#5702](https://github.com/gfx-rs/wgpu/pull/5702).
9293

9394
### Changes
9495

naga/src/front/spv/mod.rs

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,20 @@ enum SignAnchor {
564564
Operand,
565565
}
566566

567+
enum AtomicOpInst {
568+
AtomicIIncrement,
569+
}
570+
571+
#[allow(dead_code)]
572+
struct AtomicOp {
573+
instruction: AtomicOpInst,
574+
result_type_id: spirv::Word,
575+
result_id: spirv::Word,
576+
pointer_id: spirv::Word,
577+
scope_id: spirv::Word,
578+
memory_semantics_id: spirv::Word,
579+
}
580+
567581
pub struct Frontend<I> {
568582
data: I,
569583
data_offset: usize,
@@ -575,6 +589,8 @@ pub struct Frontend<I> {
575589
future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
576590
lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>,
577591
handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>,
592+
// Used to upgrade types used in atomic ops to atomic types, keyed by pointer id
593+
lookup_atomic: FastHashMap<spirv::Word, AtomicOp>,
578594
lookup_type: FastHashMap<spirv::Word, LookupType>,
579595
lookup_void_type: Option<spirv::Word>,
580596
lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>,
@@ -630,6 +646,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
630646
future_member_decor: FastHashMap::default(),
631647
handle_sampling: FastHashMap::default(),
632648
lookup_member: FastHashMap::default(),
649+
lookup_atomic: FastHashMap::default(),
633650
lookup_type: FastHashMap::default(),
634651
lookup_void_type: None,
635652
lookup_storage_buffer_types: FastHashMap::default(),
@@ -3943,7 +3960,81 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
39433960
);
39443961
emitter.start(ctx.expressions);
39453962
}
3946-
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
3963+
Op::AtomicIIncrement => {
3964+
inst.expect(6)?;
3965+
let start = self.data_offset;
3966+
let span = self.span_from_with_op(start);
3967+
let result_type_id = self.next()?;
3968+
let result_id = self.next()?;
3969+
let pointer_id = self.next()?;
3970+
let scope_id = self.next()?;
3971+
let memory_semantics_id = self.next()?;
3972+
// Store the op for a later pass where we "upgrade" the pointer type
3973+
let atomic = AtomicOp {
3974+
instruction: AtomicOpInst::AtomicIIncrement,
3975+
result_type_id,
3976+
result_id,
3977+
pointer_id,
3978+
scope_id,
3979+
memory_semantics_id,
3980+
};
3981+
self.lookup_atomic.insert(pointer_id, atomic);
3982+
3983+
log::trace!("\t\t\tlooking up expr {:?}", pointer_id);
3984+
3985+
let (p_lexp_handle, p_lexp_ty_id) = {
3986+
let lexp = self.lookup_expression.lookup(pointer_id)?;
3987+
let handle = get_expr_handle!(pointer_id, &lexp);
3988+
(handle, lexp.type_id)
3989+
};
3990+
log::trace!("\t\t\tlooking up type {pointer_id:?}");
3991+
let p_ty = self.lookup_type.lookup(p_lexp_ty_id)?;
3992+
let p_ty_base_id =
3993+
p_ty.base_id.ok_or(Error::InvalidAccessType(p_lexp_ty_id))?;
3994+
log::trace!("\t\t\tlooking up base type {p_ty_base_id:?} of {p_ty:?}");
3995+
let p_base_ty = self.lookup_type.lookup(p_ty_base_id)?;
3996+
3997+
// Create an expression for our result
3998+
let r_lexp_handle = {
3999+
let expr = crate::Expression::AtomicResult {
4000+
ty: p_base_ty.handle,
4001+
comparison: false,
4002+
};
4003+
let handle = ctx.expressions.append(expr, span);
4004+
self.lookup_expression.insert(
4005+
result_id,
4006+
LookupExpression {
4007+
handle,
4008+
type_id: result_type_id,
4009+
block_id,
4010+
},
4011+
);
4012+
handle
4013+
};
4014+
4015+
// Create a literal "1" since WGSL lacks an increment operation
4016+
let one_lexp_handle = make_index_literal(
4017+
ctx,
4018+
1,
4019+
&mut block,
4020+
&mut emitter,
4021+
p_base_ty.handle,
4022+
p_lexp_ty_id,
4023+
span,
4024+
)?;
4025+
4026+
// Create a statement for the op itself
4027+
let stmt = crate::Statement::Atomic {
4028+
pointer: p_lexp_handle,
4029+
fun: crate::AtomicFunction::Add,
4030+
value: one_lexp_handle,
4031+
result: r_lexp_handle,
4032+
};
4033+
block.push(stmt, span);
4034+
}
4035+
_ => {
4036+
return Err(Error::UnsupportedInstruction(self.state, inst.op));
4037+
}
39474038
}
39484039
};
39494040

@@ -5593,4 +5684,38 @@ mod test {
55935684
];
55945685
let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
55955686
}
5687+
5688+
#[test]
5689+
fn atomic_i_inc() {
5690+
let _ = env_logger::builder()
5691+
.is_test(true)
5692+
.filter_level(log::LevelFilter::Trace)
5693+
.try_init();
5694+
let bytes = include_bytes!("../../../tests/in/spv/atomic_i_increment.spv");
5695+
let m = super::parse_u8_slice(bytes, &Default::default()).unwrap();
5696+
let mut validator = crate::valid::Validator::new(
5697+
crate::valid::ValidationFlags::empty(),
5698+
Default::default(),
5699+
);
5700+
let info = validator.validate(&m).unwrap();
5701+
let wgsl =
5702+
crate::back::wgsl::write_string(&m, &info, crate::back::wgsl::WriterFlags::empty())
5703+
.unwrap();
5704+
log::info!("atomic_i_increment:\n{wgsl}");
5705+
5706+
let m = match crate::front::wgsl::parse_str(&wgsl) {
5707+
Ok(m) => m,
5708+
Err(e) => {
5709+
log::error!("{}", e.emit_to_string(&wgsl));
5710+
// at this point we know atomics create invalid modules
5711+
// so simply bail
5712+
return;
5713+
}
5714+
};
5715+
let mut validator =
5716+
crate::valid::Validator::new(crate::valid::ValidationFlags::all(), Default::default());
5717+
if let Err(e) = validator.validate(&m) {
5718+
log::error!("{}", e.emit_to_string(&wgsl));
5719+
}
5720+
}
55965721
}
392 Bytes
Binary file not shown.

naga/tests/snapshots.rs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ const BASE_DIR_OUT: &str = "tests/out";
1414
bitflags::bitflags! {
1515
#[derive(Clone, Copy)]
1616
struct Targets: u32 {
17-
const IR = 0x1;
18-
const ANALYSIS = 0x2;
19-
const SPIRV = 0x4;
20-
const METAL = 0x8;
21-
const GLSL = 0x10;
22-
const DOT = 0x20;
23-
const HLSL = 0x40;
24-
const WGSL = 0x80;
17+
const IR = 1;
18+
const ANALYSIS = 1 << 1;
19+
const SPIRV = 1 << 2;
20+
const METAL = 1 << 3;
21+
const GLSL = 1 << 4;
22+
const DOT = 1 << 5;
23+
const HLSL = 1 << 6;
24+
const WGSL = 1 << 7;
25+
const NO_VALIDATION = 1 << 8;
2526
}
2627
}
2728

@@ -292,7 +293,13 @@ fn check_targets(
292293
}
293294
}
294295

295-
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
296+
let validation_flags = if targets.contains(Targets::NO_VALIDATION) {
297+
naga::valid::ValidationFlags::empty()
298+
} else {
299+
naga::valid::ValidationFlags::all()
300+
};
301+
302+
let info = naga::valid::Validator::new(validation_flags, capabilities)
296303
.subgroup_stages(subgroup_stages)
297304
.subgroup_operations(subgroup_operations)
298305
.validate(module)
@@ -317,7 +324,7 @@ fn check_targets(
317324
}
318325
}
319326

320-
naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
327+
naga::valid::Validator::new(validation_flags, capabilities)
321328
.subgroup_stages(subgroup_stages)
322329
.subgroup_operations(subgroup_operations)
323330
.validate(module)
@@ -979,6 +986,12 @@ fn convert_spv_all() {
979986
false,
980987
Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
981988
);
989+
convert_spv(
990+
"atomic_i_increment",
991+
false,
992+
// TODO(@schell): remove Targets::NO_VALIDATION when OpAtomicIIncrement lands
993+
Targets::IR | Targets::NO_VALIDATION,
994+
);
982995
}
983996

984997
#[cfg(feature = "glsl-in")]

0 commit comments

Comments
 (0)