Skip to content

Commit 4bf5d79

Browse files
committed
(Naga) Implement OpSpecConstantOp for the SPIR-V frontend
1 parent 7902357 commit 4bf5d79

12 files changed

+860
-0
lines changed

naga/src/front/spv/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ pub enum Error {
2626
UnknownCapability(spirv::Word),
2727
#[error("unsupported instruction {1:?} at {0:?}")]
2828
UnsupportedInstruction(ModuleState, spirv::Op),
29+
#[error("unsupported opcode in specialization constant operation {0:?}")]
30+
UnsupportedSpecConstantOp(spirv::Op),
31+
#[error("invalid opcode in specialization constant operation {0:?}")]
32+
InvalidSpecConstantOp(spirv::Op),
2933
#[error("unsupported capability {0:?}")]
3034
UnsupportedCapability(spirv::Capability),
3135
#[error("unsupported extension {0}")]

naga/src/front/spv/mod.rs

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4803,6 +4803,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
48034803
Op::ConstantFalse | Op::SpecConstantFalse => {
48044804
self.parse_bool_constant(inst, false, &mut module)
48054805
}
4806+
Op::SpecConstantOp => self.parse_spec_constant_op(inst, &mut module),
48064807
Op::Variable => self.parse_global_variable(inst, &mut module),
48074808
Op::Function => {
48084809
self.switch(ModuleState::Function, inst.op)?;
@@ -5897,6 +5898,276 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
58975898
self.insert_parsed_constant(module, id, type_id, ty, init, span)
58985899
}
58995900

5901+
fn parse_spec_constant_op(
5902+
&mut self,
5903+
inst: Instruction,
5904+
module: &mut crate::Module,
5905+
) -> Result<(), Error> {
5906+
use spirv::Op;
5907+
5908+
let start = self.data_offset;
5909+
self.switch(ModuleState::Type, inst.op)?;
5910+
inst.expect_at_least(4)?;
5911+
5912+
let result_type_id = self.next()?;
5913+
let result_id = self.next()?;
5914+
let opcode_word = self.next()?;
5915+
5916+
let type_lookup = self.lookup_type.lookup(result_type_id)?;
5917+
let ty = type_lookup.handle;
5918+
let span = self.span_from_with_op(start);
5919+
5920+
let opcode = Op::from_u32(opcode_word).ok_or(Error::UnsupportedInstruction(
5921+
self.state,
5922+
Op::SpecConstantOp,
5923+
))?;
5924+
5925+
let mut get_const_expr =
5926+
|frontend: &Self, const_id: spirv::Word| -> Result<Handle<crate::Expression>, Error> {
5927+
let lookup = frontend.lookup_constant.lookup(const_id)?;
5928+
Ok(module
5929+
.global_expressions
5930+
.append(lookup.inner.to_expr(), span))
5931+
};
5932+
5933+
let init = match opcode {
5934+
Op::SConvert | Op::UConvert | Op::FConvert => {
5935+
let value_id = self.next()?;
5936+
let value_expr = get_const_expr(self, value_id)?;
5937+
5938+
let scalar = match module.types[ty].inner {
5939+
crate::TypeInner::Scalar(scalar)
5940+
| crate::TypeInner::Vector { scalar, .. }
5941+
| crate::TypeInner::Matrix { scalar, .. } => scalar,
5942+
_ => return Err(Error::InvalidAsType(ty)),
5943+
};
5944+
5945+
module.global_expressions.append(
5946+
crate::Expression::As {
5947+
expr: value_expr,
5948+
kind: scalar.kind,
5949+
convert: Some(scalar.width),
5950+
},
5951+
span,
5952+
)
5953+
}
5954+
5955+
Op::SNegate | Op::Not | Op::LogicalNot => {
5956+
let value_id = self.next()?;
5957+
let value_expr = get_const_expr(self, value_id)?;
5958+
5959+
let op = match opcode {
5960+
Op::SNegate => crate::UnaryOperator::Negate,
5961+
Op::Not => crate::UnaryOperator::BitwiseNot,
5962+
Op::LogicalNot => crate::UnaryOperator::LogicalNot,
5963+
_ => unreachable!(),
5964+
};
5965+
5966+
module.global_expressions.append(
5967+
crate::Expression::Unary {
5968+
op,
5969+
expr: value_expr,
5970+
},
5971+
span,
5972+
)
5973+
}
5974+
5975+
Op::IAdd
5976+
| Op::ISub
5977+
| Op::IMul
5978+
| Op::UDiv
5979+
| Op::SDiv
5980+
| Op::SRem
5981+
| Op::UMod
5982+
| Op::BitwiseOr
5983+
| Op::BitwiseXor
5984+
| Op::BitwiseAnd
5985+
| Op::ShiftLeftLogical
5986+
| Op::ShiftRightLogical
5987+
| Op::ShiftRightArithmetic
5988+
| Op::LogicalOr
5989+
| Op::LogicalAnd
5990+
| Op::LogicalEqual
5991+
| Op::LogicalNotEqual
5992+
| Op::IEqual
5993+
| Op::INotEqual
5994+
| Op::ULessThan
5995+
| Op::SLessThan
5996+
| Op::UGreaterThan
5997+
| Op::SGreaterThan
5998+
| Op::ULessThanEqual
5999+
| Op::SLessThanEqual
6000+
| Op::UGreaterThanEqual
6001+
| Op::SGreaterThanEqual => {
6002+
let left_id = self.next()?;
6003+
let right_id = self.next()?;
6004+
let left_expr = get_const_expr(self, left_id)?;
6005+
let right_expr = get_const_expr(self, right_id)?;
6006+
6007+
let op = match opcode {
6008+
Op::IAdd => crate::BinaryOperator::Add,
6009+
Op::ISub => crate::BinaryOperator::Subtract,
6010+
Op::IMul => crate::BinaryOperator::Multiply,
6011+
Op::UDiv | Op::SDiv => crate::BinaryOperator::Divide,
6012+
Op::SRem | Op::UMod => crate::BinaryOperator::Modulo,
6013+
Op::BitwiseOr => crate::BinaryOperator::InclusiveOr,
6014+
Op::BitwiseXor => crate::BinaryOperator::ExclusiveOr,
6015+
Op::BitwiseAnd => crate::BinaryOperator::And,
6016+
Op::ShiftLeftLogical => crate::BinaryOperator::ShiftLeft,
6017+
Op::ShiftRightLogical | Op::ShiftRightArithmetic => {
6018+
crate::BinaryOperator::ShiftRight
6019+
}
6020+
Op::LogicalOr => crate::BinaryOperator::LogicalOr,
6021+
Op::LogicalAnd => crate::BinaryOperator::LogicalAnd,
6022+
Op::LogicalEqual => crate::BinaryOperator::Equal,
6023+
Op::LogicalNotEqual => crate::BinaryOperator::NotEqual,
6024+
Op::IEqual => crate::BinaryOperator::Equal,
6025+
Op::INotEqual => crate::BinaryOperator::NotEqual,
6026+
Op::ULessThan | Op::SLessThan => crate::BinaryOperator::Less,
6027+
Op::UGreaterThan | Op::SGreaterThan => crate::BinaryOperator::Greater,
6028+
Op::ULessThanEqual | Op::SLessThanEqual => crate::BinaryOperator::LessEqual,
6029+
Op::UGreaterThanEqual | Op::SGreaterThanEqual => {
6030+
crate::BinaryOperator::GreaterEqual
6031+
}
6032+
_ => unreachable!(),
6033+
};
6034+
6035+
module.global_expressions.append(
6036+
crate::Expression::Binary {
6037+
op,
6038+
left: left_expr,
6039+
right: right_expr,
6040+
},
6041+
span,
6042+
)
6043+
}
6044+
6045+
Op::SMod => {
6046+
// x - y * int(floor(float(x) / float(y)))
6047+
6048+
let left_id = self.next()?;
6049+
let right_id = self.next()?;
6050+
let left = get_const_expr(self, left_id)?;
6051+
let right = get_const_expr(self, right_id)?;
6052+
6053+
let scalar = match module.types[ty].inner {
6054+
crate::TypeInner::Scalar(scalar) => scalar,
6055+
crate::TypeInner::Vector { scalar, .. } => scalar,
6056+
_ => return Err(Error::InvalidAsType(ty)),
6057+
};
6058+
6059+
let left_cast = module.global_expressions.append(
6060+
crate::Expression::As {
6061+
expr: left,
6062+
kind: crate::ScalarKind::Float,
6063+
convert: Some(scalar.width),
6064+
},
6065+
span,
6066+
);
6067+
let right_cast = module.global_expressions.append(
6068+
crate::Expression::As {
6069+
expr: right,
6070+
kind: crate::ScalarKind::Float,
6071+
convert: Some(scalar.width),
6072+
},
6073+
span,
6074+
);
6075+
let div = module.global_expressions.append(
6076+
crate::Expression::Binary {
6077+
op: crate::BinaryOperator::Divide,
6078+
left: left_cast,
6079+
right: right_cast,
6080+
},
6081+
span,
6082+
);
6083+
let floor = module.global_expressions.append(
6084+
crate::Expression::Math {
6085+
fun: crate::MathFunction::Floor,
6086+
arg: div,
6087+
arg1: None,
6088+
arg2: None,
6089+
arg3: None,
6090+
},
6091+
span,
6092+
);
6093+
let cast = module.global_expressions.append(
6094+
crate::Expression::As {
6095+
expr: floor,
6096+
kind: scalar.kind,
6097+
convert: Some(scalar.width),
6098+
},
6099+
span,
6100+
);
6101+
let mult = module.global_expressions.append(
6102+
crate::Expression::Binary {
6103+
op: crate::BinaryOperator::Multiply,
6104+
left: cast,
6105+
right,
6106+
},
6107+
span,
6108+
);
6109+
module.global_expressions.append(
6110+
crate::Expression::Binary {
6111+
op: crate::BinaryOperator::Subtract,
6112+
left,
6113+
right: mult,
6114+
},
6115+
span,
6116+
)
6117+
}
6118+
6119+
Op::Select => {
6120+
let condition_id = self.next()?;
6121+
let o1_id = self.next()?;
6122+
let o2_id = self.next()?;
6123+
6124+
let cond = get_const_expr(self, condition_id)?;
6125+
let o1 = get_const_expr(self, o1_id)?;
6126+
let o2 = get_const_expr(self, o2_id)?;
6127+
6128+
module.global_expressions.append(
6129+
crate::Expression::Select {
6130+
condition: cond,
6131+
accept: o1,
6132+
reject: o2,
6133+
},
6134+
span,
6135+
)
6136+
}
6137+
6138+
Op::VectorShuffle
6139+
| Op::CompositeExtract
6140+
| Op::CompositeInsert
6141+
| Op::QuantizeToF16 => {
6142+
// Nothing stops us from implementing these cases in general.
6143+
// I just couldn't get them to work properly.
6144+
return Err(Error::UnsupportedSpecConstantOp(opcode))
6145+
}
6146+
6147+
_ => return Err(Error::InvalidSpecConstantOp(opcode)),
6148+
};
6149+
6150+
// IMPORTANT: Overrides must have either a name or an id to be processed correctly
6151+
// by process_overrides(). OpSpecConstantOp results don't have a SpecId (they're
6152+
// not user-overridable), so we assign them a name based on the result_id.
6153+
let op_override = crate::Override {
6154+
name: Some(format!("_spec_const_op_{result_id}")),
6155+
id: None,
6156+
ty,
6157+
init: Some(init),
6158+
};
6159+
6160+
self.lookup_constant.insert(
6161+
result_id,
6162+
LookupConstant {
6163+
inner: Constant::Override(module.overrides.append(op_override, span)),
6164+
type_id: result_type_id,
6165+
},
6166+
);
6167+
6168+
Ok(())
6169+
}
6170+
59006171
fn insert_parsed_constant(
59016172
&mut self,
59026173
module: &mut crate::Module,
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
; SPIR-V
2+
; Version: 1.4
3+
; Generator: Manual
4+
; Bound: 50
5+
; Schema: 0
6+
OpCapability Shader
7+
OpMemoryModel Logical GLSL450
8+
OpEntryPoint GLCompute %main "main" %output_buffer
9+
OpExecutionMode %main LocalSize 1 1 1
10+
11+
; Decorations for spec constants
12+
OpDecorate %spec_a SpecId 0
13+
OpDecorate %spec_b SpecId 1
14+
15+
; Decorations for storage buffer
16+
OpDecorate %OutputStruct Block
17+
OpMemberDecorate %OutputStruct 0 Offset 0
18+
OpMemberDecorate %OutputStruct 1 Offset 4
19+
OpMemberDecorate %OutputStruct 2 Offset 8
20+
OpDecorate %output_buffer DescriptorSet 0
21+
OpDecorate %output_buffer Binding 0
22+
23+
; Types
24+
%void = OpTypeVoid
25+
%3 = OpTypeFunction %void
26+
%int = OpTypeInt 32 1
27+
%uint = OpTypeInt 32 0
28+
29+
; Storage buffer types
30+
%OutputStruct = OpTypeStruct %uint %uint %uint
31+
%_ptr_StorageBuffer_OutputStruct = OpTypePointer StorageBuffer %OutputStruct
32+
%output_buffer = OpVariable %_ptr_StorageBuffer_OutputStruct StorageBuffer
33+
%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
34+
35+
; Base spec constants
36+
%spec_a = OpSpecConstant %int 0
37+
%spec_b = OpSpecConstant %int 0
38+
39+
; Basic chaining
40+
%add_result = OpSpecConstantOp %int IAdd %spec_a %spec_b ; 10 + 3 = 13
41+
%mul_result = OpSpecConstantOp %int IMul %add_result %spec_b ; 13 * 3 = 39
42+
%sub_result = OpSpecConstantOp %int ISub %mul_result %add_result ; 39 - 13 = 26
43+
44+
; Deeper chaining
45+
%step1 = OpSpecConstantOp %int IAdd %spec_a %spec_b ; 10 + 3 = 13
46+
%step2 = OpSpecConstantOp %int IMul %step1 %spec_a ; 13 * 10 = 130
47+
%const_one = OpConstant %int 1
48+
%step3 = OpSpecConstantOp %int ISub %step2 %const_one ; 130 - 1 = 129
49+
%step4 = OpSpecConstantOp %int IMul %step3 %spec_b ; 129 * 3 = 387
50+
51+
; Constants for indices
52+
%idx_0 = OpConstant %uint 0
53+
%idx_1 = OpConstant %uint 1
54+
%idx_2 = OpConstant %uint 2
55+
56+
; Main function
57+
%main = OpFunction %void None %3
58+
%5 = OpLabel
59+
60+
; Store results
61+
%mul_uint = OpBitcast %uint %mul_result
62+
%ptr_0 = OpAccessChain %_ptr_StorageBuffer_uint %output_buffer %idx_0
63+
OpStore %ptr_0 %mul_uint
64+
65+
%sub_uint = OpBitcast %uint %sub_result
66+
%ptr_1 = OpAccessChain %_ptr_StorageBuffer_uint %output_buffer %idx_1
67+
OpStore %ptr_1 %sub_uint
68+
69+
%step4_uint = OpBitcast %uint %step4
70+
%ptr_2 = OpAccessChain %_ptr_StorageBuffer_uint %output_buffer %idx_2
71+
OpStore %ptr_2 %step4_uint
72+
73+
OpReturn
74+
OpFunctionEnd
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
pipeline_constants = { 0 = 10, 1 = 3 }
2+
targets = "HLSL"
3+
4+
[spv-in]
5+
adjust_coordinate_space = false
6+
7+
[spv]
8+
separate_entry_points = true

naga/tests/in/spv/spec-constant-op-stress.glsl

Whitespace-only changes.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
pipeline_constants = { 0 = 10, 1 = 3, 2 = 20, 3 = 4, 4 = 1.0, 5 = 0.0, 6 = 10.5 }
2+
targets = "HLSL"
3+
4+
[spv-in]
5+
adjust_coordinate_space = false
6+
7+
[spv]
8+
separate_entry_points = true

0 commit comments

Comments
 (0)