Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 129 additions & 7 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1730,17 +1730,106 @@ impl<'a> ConstantEvaluator<'a> {
self.packed_dot_product(arg, arg1.unwrap(), span, false)
}
crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
crate::MathFunction::Dot => {
// https://www.w3.org/TR/WGSL/#dot-builtin
let e1 = self.extract_vec(arg, false)?;
let e2 = self.extract_vec(arg1.unwrap(), false)?;
if e1.len() != e2.len() {
return Err(ConstantEvaluatorError::InvalidMathArg);
}

fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
where
P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
{
a.iter()
.zip(b.iter())
.map(|(&aa, bb)| aa.checked_mul(bb))
.try_fold(P::zero(), |acc, x| {
if let Some(x) = x {
acc.checked_add(&x)
} else {
None
}
})
.ok_or(ConstantEvaluatorError::Overflow(
"in dot built-in".to_string(),
))
}

let result = match_literal_vector!(match (e1, e2) => Literal {
Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
Integer => |e1, e2| { int_dot(e1, e2)? },
})?;
self.register_evaluated_expr(Expression::Literal(result), span)
}
crate::MathFunction::Length => {
// https://www.w3.org/TR/WGSL/#length-builtin
let e1 = self.extract_vec(arg, true)?;

fn float_length<F>(e: &[F]) -> F
where
F: core::ops::Mul<F>,
F: num_traits::Float + iter::Sum,
{
e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
}

let result = match_literal_vector!(match e1 => Literal {
Float => |e1| { float_length(e1) },
})?;
self.register_evaluated_expr(Expression::Literal(result), span)
}
crate::MathFunction::Distance => {
// https://www.w3.org/TR/WGSL/#distance-builtin
let e1 = self.extract_vec(arg, true)?;
let e2 = self.extract_vec(arg1.unwrap(), true)?;
if e1.len() != e2.len() {
return Err(ConstantEvaluatorError::InvalidMathArg);
}

fn float_distance<F>(a: &[F], b: &[F]) -> F
where
F: core::ops::Mul<F>,
F: num_traits::Float + iter::Sum + core::ops::Sub,
{
a.iter()
.zip(b.iter())
.map(|(&aa, &bb)| aa - bb)
.map(|ei| ei * ei)
.sum::<F>()
.sqrt()
}
let result = match_literal_vector!(match (e1, e2) => Literal {
Float => |e1, e2| { float_distance(e1, e2) },
})?;
self.register_evaluated_expr(Expression::Literal(result), span)
}
crate::MathFunction::Normalize => {
// https://www.w3.org/TR/WGSL/#normalize-builtin
Comment on lines +1808 to +1809
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review TODO for myself: figure out what the "zero vector" domain exception listed in the standard is, and see if we need to validate it out.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relevant spec parts https://www.w3.org/TR/WGSL/#normalize-builtin:

The domain is all vectors except the zero vector.

and

When evaluated outside its domain, the default exception handling rules of IEEE-754 require an implementation to generate an exception and yield a NaN value. In contrast, WGSL does not mandate floating point exceptions, and may instead yield an indeterminate value. See § 15.7.2 Differences from IEEE-754.

So we can do anything we want, current code will return NaN.

let e1 = self.extract_vec(arg, true)?;

fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
where
F: core::ops::Mul<F>,
F: num_traits::Float + iter::Sum,
{
let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
e.iter().map(|&ei| ei / len).collect()
}

let result = match_literal_vector!(match e1 => LiteralVector {
Float => |e1| { float_normalize(e1) },
})?;
result.register_as_evaluated_expr(self, span)
}

// unimplemented
crate::MathFunction::Atan2
| crate::MathFunction::Modf
| crate::MathFunction::Frexp
| crate::MathFunction::Ldexp
| crate::MathFunction::Dot
| crate::MathFunction::Outer
| crate::MathFunction::Distance
| crate::MathFunction::Length
| crate::MathFunction::Normalize
| crate::MathFunction::FaceForward
| crate::MathFunction::Reflect
| crate::MathFunction::Refract
Expand Down Expand Up @@ -1816,8 +1905,8 @@ impl<'a> ConstantEvaluator<'a> {
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
use Literal as Li;

let (a, ty) = self.extract_vec::<3>(a)?;
let (b, _) = self.extract_vec::<3>(b)?;
let (a, ty) = self.extract_vec_with_size::<3>(a)?;
let (b, _) = self.extract_vec_with_size::<3>(b)?;
Comment on lines +1908 to +1909
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: IMO it'd have been better to have this refactor separated for the purpose of review, but 🤷🏻‍♂️ it's not too hard to follow.


let product = match (a, b) {
(
Expand Down Expand Up @@ -1884,7 +1973,7 @@ impl<'a> ConstantEvaluator<'a> {
/// values.
///
/// Also return the type handle from the `Compose` expression.
fn extract_vec<const N: usize>(
fn extract_vec_with_size<const N: usize>(
&mut self,
expr: Handle<Expression>,
) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
Expand All @@ -1908,6 +1997,39 @@ impl<'a> ConstantEvaluator<'a> {
Ok((value, ty))
}

/// Extract the values of a `vecN` from `expr`.
///
/// Return the value of `expr`, whose type is `vecN<S>` for some
/// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
/// values.
///
/// Also return the type handle from the `Compose` expression.
fn extract_vec(
&mut self,
expr: Handle<Expression>,
allow_single: bool,
) -> Result<LiteralVector, ConstantEvaluatorError> {
let span = self.expressions.get_span(expr);
let expr = self.eval_zero_value_and_splat(expr, span)?;

match self.expressions[expr] {
Expression::Literal(literal) if allow_single => {
Ok(LiteralVector::from_literal(literal))
}
Expression::Compose { ty, ref components } => {
let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
.map(|expr| match self.expressions[expr] {
Expression::Literal(l) => Ok(l),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, ConstantEvaluatorError>>()?;
LiteralVector::from_literal_vec(components)
}
_ => Err(ConstantEvaluatorError::InvalidMathArg),
}
}

fn array_length(
&mut self,
array: Handle<Expression>,
Expand Down