Skip to content

Commit b9bad54

Browse files
committed
[naga] Implement builtins dot, length, distance, normalize in const using LiteralVector
Signed-off-by: sagudev <[email protected]>
1 parent 2368eef commit b9bad54

File tree

1 file changed

+129
-7
lines changed

1 file changed

+129
-7
lines changed

naga/src/proc/constant_evaluator.rs

Lines changed: 129 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,17 +1725,106 @@ impl<'a> ConstantEvaluator<'a> {
17251725
self.packed_dot_product(arg, arg1.unwrap(), span, false)
17261726
}
17271727
crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1728+
crate::MathFunction::Dot => {
1729+
// https://www.w3.org/TR/WGSL/#dot-builtin
1730+
let e1 = self.extract_vec(arg, false)?;
1731+
let e2 = self.extract_vec(arg1.unwrap(), false)?;
1732+
if e1.len() != e2.len() {
1733+
return Err(ConstantEvaluatorError::InvalidMathArg);
1734+
}
1735+
1736+
fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1737+
where
1738+
P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1739+
{
1740+
a.iter()
1741+
.zip(b.iter())
1742+
.map(|(&aa, bb)| aa.checked_mul(bb))
1743+
.try_fold(P::zero(), |acc, x| {
1744+
if let Some(x) = x {
1745+
acc.checked_add(&x)
1746+
} else {
1747+
None
1748+
}
1749+
})
1750+
.ok_or(ConstantEvaluatorError::Overflow(
1751+
"in dot built-in".to_string(),
1752+
))
1753+
}
1754+
1755+
let result = match_literal_vector!(match (e1, e2) => Literal {
1756+
Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
1757+
Integer => |e1, e2| { int_dot(e1, e2)? },
1758+
})?;
1759+
self.register_evaluated_expr(Expression::Literal(result), span)
1760+
}
1761+
crate::MathFunction::Length => {
1762+
// https://www.w3.org/TR/WGSL/#length-builtin
1763+
let e1 = self.extract_vec(arg, true)?;
1764+
1765+
fn float_length<F>(e: &[F]) -> F
1766+
where
1767+
F: core::ops::Mul<F>,
1768+
F: num_traits::Float + iter::Sum,
1769+
{
1770+
e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1771+
}
1772+
1773+
let result = match_literal_vector!(match e1 => Literal {
1774+
Float => |e1| { float_length(e1) },
1775+
})?;
1776+
self.register_evaluated_expr(Expression::Literal(result), span)
1777+
}
1778+
crate::MathFunction::Distance => {
1779+
// https://www.w3.org/TR/WGSL/#distance-builtin
1780+
let e1 = self.extract_vec(arg, true)?;
1781+
let e2 = self.extract_vec(arg1.unwrap(), true)?;
1782+
if e1.len() != e2.len() {
1783+
return Err(ConstantEvaluatorError::InvalidMathArg);
1784+
}
1785+
1786+
fn float_distance<F>(a: &[F], b: &[F]) -> F
1787+
where
1788+
F: core::ops::Mul<F>,
1789+
F: num_traits::Float + iter::Sum + core::ops::Sub,
1790+
{
1791+
a.iter()
1792+
.zip(b.iter())
1793+
.map(|(&aa, &bb)| aa - bb)
1794+
.map(|ei| ei * ei)
1795+
.sum::<F>()
1796+
.sqrt()
1797+
}
1798+
let result = match_literal_vector!(match (e1, e2) => Literal {
1799+
Float => |e1, e2| { float_distance(e1, e2) },
1800+
})?;
1801+
self.register_evaluated_expr(Expression::Literal(result), span)
1802+
}
1803+
crate::MathFunction::Normalize => {
1804+
// https://www.w3.org/TR/WGSL/#normalize-builtin
1805+
let e1 = self.extract_vec(arg, true)?;
1806+
1807+
fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1808+
where
1809+
F: core::ops::Mul<F>,
1810+
F: num_traits::Float + iter::Sum,
1811+
{
1812+
let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1813+
e.iter().map(|&ei| ei / len).collect()
1814+
}
1815+
1816+
let result = match_literal_vector!(match e1 => LiteralVector {
1817+
Float => |e1| { float_normalize(e1) },
1818+
})?;
1819+
result.register_as_evaluated_expr(self, span)
1820+
}
17281821

17291822
// unimplemented
17301823
crate::MathFunction::Atan2
17311824
| crate::MathFunction::Modf
17321825
| crate::MathFunction::Frexp
17331826
| crate::MathFunction::Ldexp
1734-
| crate::MathFunction::Dot
17351827
| crate::MathFunction::Outer
1736-
| crate::MathFunction::Distance
1737-
| crate::MathFunction::Length
1738-
| crate::MathFunction::Normalize
17391828
| crate::MathFunction::FaceForward
17401829
| crate::MathFunction::Reflect
17411830
| crate::MathFunction::Refract
@@ -1811,8 +1900,8 @@ impl<'a> ConstantEvaluator<'a> {
18111900
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
18121901
use Literal as Li;
18131902

1814-
let (a, ty) = self.extract_vec::<3>(a)?;
1815-
let (b, _) = self.extract_vec::<3>(b)?;
1903+
let (a, ty) = self.extract_vec_with_size::<3>(a)?;
1904+
let (b, _) = self.extract_vec_with_size::<3>(b)?;
18161905

18171906
let product = match (a, b) {
18181907
(
@@ -1879,7 +1968,7 @@ impl<'a> ConstantEvaluator<'a> {
18791968
/// values.
18801969
///
18811970
/// Also return the type handle from the `Compose` expression.
1882-
fn extract_vec<const N: usize>(
1971+
fn extract_vec_with_size<const N: usize>(
18831972
&mut self,
18841973
expr: Handle<Expression>,
18851974
) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
@@ -1903,6 +1992,39 @@ impl<'a> ConstantEvaluator<'a> {
19031992
Ok((value, ty))
19041993
}
19051994

1995+
/// Extract the values of a `vecN` from `expr`.
1996+
///
1997+
/// Return the value of `expr`, whose type is `vecN<S>` for some
1998+
/// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
1999+
/// values.
2000+
///
2001+
/// Also return the type handle from the `Compose` expression.
2002+
fn extract_vec(
2003+
&mut self,
2004+
expr: Handle<Expression>,
2005+
allow_single: bool,
2006+
) -> Result<LiteralVector, ConstantEvaluatorError> {
2007+
let span = self.expressions.get_span(expr);
2008+
let expr = self.eval_zero_value_and_splat(expr, span)?;
2009+
2010+
match self.expressions[expr] {
2011+
Expression::Literal(literal) if allow_single => {
2012+
Ok(LiteralVector::from_literal(literal))
2013+
}
2014+
Expression::Compose { ty, ref components } => {
2015+
let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
2016+
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2017+
.map(|expr| match self.expressions[expr] {
2018+
Expression::Literal(l) => Ok(l),
2019+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
2020+
})
2021+
.collect::<Result<_, ConstantEvaluatorError>>()?;
2022+
LiteralVector::from_literal_vec(components)
2023+
}
2024+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
2025+
}
2026+
}
2027+
19062028
fn array_length(
19072029
&mut self,
19082030
array: Handle<Expression>,

0 commit comments

Comments
 (0)