Skip to content

Commit d625d08

Browse files
authored
[naga wgsl-in] Implement any() and all() during const evaluation (#7166)
1 parent c03176f commit d625d08

File tree

7 files changed

+180
-52
lines changed

7 files changed

+180
-52
lines changed

naga/src/proc/constant_evaluator.rs

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use arrayvec::ArrayVec;
44

55
use crate::{
66
arena::{Arena, Handle, HandleVec, UniqueArena},
7-
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
8-
TypeInner, UnaryOperator,
7+
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
8+
ScalarKind, Span, Type, TypeInner, UnaryOperator,
99
};
1010

1111
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
@@ -547,6 +547,8 @@ pub enum ConstantEvaluatorError {
547547
InvalidMathArg,
548548
#[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
549549
InvalidMathArgCount(crate::MathFunction, usize, usize),
550+
#[error("Cannot apply relational function to type")]
551+
InvalidRelationalArg(RelationalFunction),
550552
#[error("value of `low` is greater than `high` for clamp built-in function")]
551553
InvalidClamp,
552554
#[error("Splat is defined only on scalar values")]
@@ -931,9 +933,10 @@ impl<'a> ConstantEvaluator<'a> {
931933
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
932934
"select built-in function".into(),
933935
)),
934-
Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
935-
format!("{fun:?} built-in function"),
936-
)),
936+
Expression::Relational { fun, argument } => {
937+
let argument = self.check_and_get(argument)?;
938+
self.relational(fun, argument, span)
939+
}
937940
Expression::ArrayLength(expr) => match self.behavior {
938941
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
939942
Behavior::Glsl(_) => {
@@ -2103,6 +2106,41 @@ impl<'a> ConstantEvaluator<'a> {
21032106
Ok(Expression::Compose { ty, components })
21042107
}
21052108

2109+
fn relational(
2110+
&mut self,
2111+
fun: RelationalFunction,
2112+
arg: Handle<Expression>,
2113+
span: Span,
2114+
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2115+
let arg = self.eval_zero_value_and_splat(arg, span)?;
2116+
match fun {
2117+
RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2118+
Expression::Literal(Literal::Bool(_)) => Ok(arg),
2119+
Expression::Compose { ty, ref components }
2120+
if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2121+
{
2122+
let components =
2123+
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2124+
.map(|component| match self.expressions[component] {
2125+
Expression::Literal(Literal::Bool(val)) => Ok(val),
2126+
_ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2127+
})
2128+
.collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2129+
let result = match fun {
2130+
RelationalFunction::All => components.iter().all(|c| *c),
2131+
RelationalFunction::Any => components.iter().any(|c| *c),
2132+
_ => unreachable!(),
2133+
};
2134+
self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2135+
}
2136+
_ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2137+
},
2138+
_ => Err(ConstantEvaluatorError::NotImplemented(format!(
2139+
"{fun:?} built-in function"
2140+
))),
2141+
}
2142+
}
2143+
21062144
/// Deep copy `expr` from `expressions` into `self.expressions`.
21072145
///
21082146
/// Return the root of the new copy.

naga/tests/in/const-exprs.wgsl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
const TWO: u32 = 2u;
22
const THREE: i32 = 3i;
3+
const TRUE = true;
4+
const FALSE = false;
35

46
@compute @workgroup_size(TWO, THREE, TWO - 1u)
57
fn main() {
@@ -94,3 +96,16 @@ fn compose_vector_zero_val_binop() {
9496
var b = vec3(vec2i(), 0) + vec3(0, 1, 2);
9597
var c = vec3(vec2i(), 2) + vec3(1, vec2i());
9698
}
99+
100+
fn relational() {
101+
// Test scalar and vector forms of any() and all(), with a mixture of
102+
// consts, literals, zero-values, composes, and splats.
103+
var scalar_any_false = any(false);
104+
var scalar_any_true = any(true);
105+
var scalar_all_false = all(false);
106+
var scalar_all_true = all(true);
107+
var vec_any_false = any(vec4<bool>());
108+
var vec_any_true = any(vec4(bool(), true, vec2(FALSE)));
109+
var vec_all_false = all(vec4(vec3(vec2<bool>(), TRUE), false));
110+
var vec_all_true = all(vec4(true));
111+
}

naga/tests/out/glsl/const-exprs.main.Compute.glsl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ layout(local_size_x = 2, local_size_y = 3, local_size_z = 1) in;
77

88
const uint TWO = 2u;
99
const int THREE = 3;
10+
const bool TRUE = true;
11+
const bool FALSE = false;
1012
const int FOUR = 4;
1113
const int FOUR_ALIAS = 4;
1214
const int TEST_CONSTANT_ADDITION = 8;
@@ -93,6 +95,18 @@ void compose_vector_zero_val_binop() {
9395
return;
9496
}
9597

98+
void relational() {
99+
bool scalar_any_false = false;
100+
bool scalar_any_true = true;
101+
bool scalar_all_false = false;
102+
bool scalar_all_true = true;
103+
bool vec_any_false = false;
104+
bool vec_any_true = true;
105+
bool vec_all_false = false;
106+
bool vec_all_true = true;
107+
return;
108+
}
109+
96110
void main() {
97111
swizzle_of_compose();
98112
index_of_compose();

naga/tests/out/hlsl/const-exprs.hlsl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
static const uint TWO = 2u;
22
static const int THREE = int(3);
3+
static const bool TRUE = true;
4+
static const bool FALSE = false;
35
static const int FOUR = int(4);
46
static const int FOUR_ALIAS = int(4);
57
static const int TEST_CONSTANT_ADDITION = int(8);
@@ -102,6 +104,20 @@ void compose_vector_zero_val_binop()
102104
return;
103105
}
104106

107+
void relational()
108+
{
109+
bool scalar_any_false = false;
110+
bool scalar_any_true = true;
111+
bool scalar_all_false = false;
112+
bool scalar_all_true = true;
113+
bool vec_any_false = false;
114+
bool vec_any_true = true;
115+
bool vec_all_false = false;
116+
bool vec_all_true = true;
117+
118+
return;
119+
}
120+
105121
[numthreads(2, 3, 1)]
106122
void main()
107123
{

naga/tests/out/msl/const-exprs.msl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using metal::uint;
66

77
constant uint TWO = 2u;
88
constant int THREE = 3;
9+
constant bool TRUE = true;
10+
constant bool FALSE = false;
911
constant int FOUR = 4;
1012
constant int FOUR_ALIAS = 4;
1113
constant int TEST_CONSTANT_ADDITION = 8;
@@ -101,6 +103,19 @@ void compose_vector_zero_val_binop(
101103
return;
102104
}
103105

106+
void relational(
107+
) {
108+
bool scalar_any_false = false;
109+
bool scalar_any_true = true;
110+
bool scalar_all_false = false;
111+
bool scalar_all_true = true;
112+
bool vec_any_false = false;
113+
bool vec_any_true = true;
114+
bool vec_all_false = false;
115+
bool vec_all_true = true;
116+
return;
117+
}
118+
104119
kernel void main_(
105120
) {
106121
swizzle_of_compose();

naga/tests/out/spv/const-exprs.spvasm

Lines changed: 62 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,67 @@
11
; SPIR-V
22
; Version: 1.1
33
; Generator: rspirv
4-
; Bound: 120
4+
; Bound: 132
55
OpCapability Shader
66
%1 = OpExtInstImport "GLSL.std.450"
77
OpMemoryModel Logical GLSL450
8-
OpEntryPoint GLCompute %111 "main"
9-
OpExecutionMode %111 LocalSize 2 3 1
8+
OpEntryPoint GLCompute %123 "main"
9+
OpExecutionMode %123 LocalSize 2 3 1
1010
%2 = OpTypeVoid
1111
%3 = OpTypeInt 32 0
1212
%4 = OpTypeInt 32 1
13-
%5 = OpTypeVector %4 4
14-
%6 = OpTypeFloat 32
15-
%7 = OpTypeVector %6 4
16-
%8 = OpTypeVector %6 2
17-
%10 = OpTypeBool
18-
%9 = OpTypeVector %10 2
13+
%5 = OpTypeBool
14+
%6 = OpTypeVector %4 4
15+
%7 = OpTypeFloat 32
16+
%8 = OpTypeVector %7 4
17+
%9 = OpTypeVector %7 2
18+
%10 = OpTypeVector %5 2
1919
%11 = OpTypeVector %4 3
2020
%12 = OpConstant %3 2
2121
%13 = OpConstant %4 3
22-
%14 = OpConstant %4 4
23-
%15 = OpConstant %4 8
24-
%16 = OpConstant %6 3.141
25-
%17 = OpConstant %6 6.282
26-
%18 = OpConstant %6 0.44444445
27-
%19 = OpConstant %6 0.0
28-
%20 = OpConstantComposite %7 %18 %19 %19 %19
29-
%21 = OpConstant %4 0
30-
%22 = OpConstant %4 1
31-
%23 = OpConstant %4 2
32-
%24 = OpConstant %6 4.0
33-
%25 = OpConstant %6 5.0
34-
%26 = OpConstantComposite %8 %24 %25
35-
%27 = OpConstantTrue %10
36-
%28 = OpConstantFalse %10
37-
%29 = OpConstantComposite %9 %27 %28
22+
%14 = OpConstantTrue %5
23+
%15 = OpConstantFalse %5
24+
%16 = OpConstant %4 4
25+
%17 = OpConstant %4 8
26+
%18 = OpConstant %7 3.141
27+
%19 = OpConstant %7 6.282
28+
%20 = OpConstant %7 0.44444445
29+
%21 = OpConstant %7 0.0
30+
%22 = OpConstantComposite %8 %20 %21 %21 %21
31+
%23 = OpConstant %4 0
32+
%24 = OpConstant %4 1
33+
%25 = OpConstant %4 2
34+
%26 = OpConstant %7 4.0
35+
%27 = OpConstant %7 5.0
36+
%28 = OpConstantComposite %9 %26 %27
37+
%29 = OpConstantComposite %10 %14 %15
3838
%32 = OpTypeFunction %2
39-
%33 = OpConstantComposite %5 %14 %13 %23 %22
40-
%35 = OpTypePointer Function %5
39+
%33 = OpConstantComposite %6 %16 %13 %25 %24
40+
%35 = OpTypePointer Function %6
4141
%40 = OpTypePointer Function %4
4242
%44 = OpConstant %4 6
4343
%49 = OpConstant %4 30
4444
%50 = OpConstant %4 70
4545
%53 = OpConstantNull %4
4646
%55 = OpConstantNull %4
47-
%58 = OpConstantNull %5
47+
%58 = OpConstantNull %6
4848
%69 = OpConstant %4 -4
49-
%70 = OpConstantComposite %5 %69 %69 %69 %69
50-
%79 = OpConstant %6 1.0
51-
%80 = OpConstant %6 2.0
52-
%81 = OpConstantComposite %7 %80 %79 %79 %79
53-
%83 = OpTypePointer Function %7
49+
%70 = OpConstantComposite %6 %69 %69 %69 %69
50+
%79 = OpConstant %7 1.0
51+
%80 = OpConstant %7 2.0
52+
%81 = OpConstantComposite %8 %80 %79 %79 %79
53+
%83 = OpTypePointer Function %8
5454
%88 = OpTypeFunction %3 %4
5555
%89 = OpConstant %3 10
5656
%90 = OpConstant %3 20
5757
%91 = OpConstant %3 30
5858
%92 = OpConstant %3 0
5959
%99 = OpConstantNull %3
60-
%102 = OpConstantComposite %11 %22 %22 %22
61-
%103 = OpConstantComposite %11 %21 %22 %23
62-
%104 = OpConstantComposite %11 %22 %21 %23
60+
%102 = OpConstantComposite %11 %24 %24 %24
61+
%103 = OpConstantComposite %11 %23 %24 %25
62+
%104 = OpConstantComposite %11 %24 %23 %25
6363
%106 = OpTypePointer Function %11
64+
%113 = OpTypePointer Function %5
6465
%31 = OpFunction %2 None %32
6566
%30 = OpLabel
6667
%34 = OpVariable %35 Function %33
@@ -70,7 +71,7 @@ OpReturn
7071
OpFunctionEnd
7172
%38 = OpFunction %2 None %32
7273
%37 = OpLabel
73-
%39 = OpVariable %40 Function %23
74+
%39 = OpVariable %40 Function %25
7475
OpBranch %41
7576
%41 = OpLabel
7677
OpReturn
@@ -99,7 +100,7 @@ OpStore %54 %61
99100
%63 = OpLoad %4 %52
100101
%64 = OpLoad %4 %54
101102
%65 = OpLoad %4 %56
102-
%66 = OpCompositeConstruct %5 %62 %63 %64 %65
103+
%66 = OpCompositeConstruct %6 %62 %63 %64 %65
103104
OpStore %57 %66
104105
OpReturn
105106
OpFunctionEnd
@@ -153,14 +154,28 @@ OpReturn
153154
OpFunctionEnd
154155
%111 = OpFunction %2 None %32
155156
%110 = OpLabel
156-
OpBranch %112
157-
%112 = OpLabel
158-
%113 = OpFunctionCall %2 %31
159-
%114 = OpFunctionCall %2 %38
160-
%115 = OpFunctionCall %2 %43
161-
%116 = OpFunctionCall %2 %48
162-
%117 = OpFunctionCall %2 %68
163-
%118 = OpFunctionCall %2 %74
164-
%119 = OpFunctionCall %2 %78
157+
%119 = OpVariable %113 Function %15
158+
%116 = OpVariable %113 Function %14
159+
%112 = OpVariable %113 Function %15
160+
%120 = OpVariable %113 Function %14
161+
%117 = OpVariable %113 Function %15
162+
%114 = OpVariable %113 Function %14
163+
%118 = OpVariable %113 Function %14
164+
%115 = OpVariable %113 Function %15
165+
OpBranch %121
166+
%121 = OpLabel
167+
OpReturn
168+
OpFunctionEnd
169+
%123 = OpFunction %2 None %32
170+
%122 = OpLabel
171+
OpBranch %124
172+
%124 = OpLabel
173+
%125 = OpFunctionCall %2 %31
174+
%126 = OpFunctionCall %2 %38
175+
%127 = OpFunctionCall %2 %43
176+
%128 = OpFunctionCall %2 %48
177+
%129 = OpFunctionCall %2 %68
178+
%130 = OpFunctionCall %2 %74
179+
%131 = OpFunctionCall %2 %78
165180
OpReturn
166181
OpFunctionEnd

naga/tests/out/wgsl/const-exprs.wgsl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
const TWO: u32 = 2u;
22
const THREE: i32 = 3i;
3+
const TRUE: bool = true;
4+
const FALSE: bool = false;
35
const FOUR: i32 = 4i;
46
const FOUR_ALIAS: i32 = 4i;
57
const TEST_CONSTANT_ADDITION: i32 = 8i;
@@ -93,6 +95,19 @@ fn compose_vector_zero_val_binop() {
9395
return;
9496
}
9597

98+
fn relational() {
99+
var scalar_any_false: bool = false;
100+
var scalar_any_true: bool = true;
101+
var scalar_all_false: bool = false;
102+
var scalar_all_true: bool = true;
103+
var vec_any_false: bool = false;
104+
var vec_any_true: bool = true;
105+
var vec_all_false: bool = false;
106+
var vec_all_true: bool = true;
107+
108+
return;
109+
}
110+
96111
@compute @workgroup_size(2, 3, 1)
97112
fn main() {
98113
swizzle_of_compose();

0 commit comments

Comments
 (0)