Skip to content

Commit bc5bce4

Browse files
committed
[naga] Implement builtins dot, length, distance in const using LiteralVector
Signed-off-by: sagudev <[email protected]>
1 parent 4779b36 commit bc5bce4

File tree

1 file changed

+111
-6
lines changed

1 file changed

+111
-6
lines changed

naga/src/proc/constant_evaluator.rs

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,16 +1725,88 @@ impl<'a> ConstantEvaluator<'a> {
17251725
self.packed_dot_product(arg, arg1.unwrap(), span, false)
17261726
}
17271727
crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1728+
crate::MathFunction::Dot => {
1729+
// https://www.w3.org/TR/WGSL/#dot-builtin
1730+
let e1 = self.extract_vec(arg, false)?;
1731+
let e2 = self.extract_vec(arg1.unwrap(), false)?;
1732+
if e1.len() != e2.len() {
1733+
return Err(ConstantEvaluatorError::InvalidMathArg);
1734+
}
1735+
1736+
fn int_dot<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1737+
where
1738+
P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1739+
{
1740+
a.iter()
1741+
.zip(b.iter())
1742+
.map(|(&aa, bb)| aa.checked_mul(bb))
1743+
.try_fold(P::zero(), |acc, x| {
1744+
if let Some(x) = x {
1745+
acc.checked_add(&x)
1746+
} else {
1747+
None
1748+
}
1749+
})
1750+
.ok_or(ConstantEvaluatorError::Overflow(
1751+
"in dot built-in".to_string(),
1752+
))
1753+
}
1754+
1755+
let result = match_literal_vector!(match (e1, e2) => Literal {
1756+
Float => |e1, e2| { e1.iter().zip(e2.iter()).map(|(&aa, &bb)| aa * bb).sum() },
1757+
Integer => |e1, e2| { int_dot(e1, e2)? },
1758+
})?;
1759+
self.register_evaluated_expr(Expression::Literal(result), span)
1760+
}
1761+
crate::MathFunction::Length => {
1762+
// https://www.w3.org/TR/WGSL/#length-builtin
1763+
let e1 = self.extract_vec(arg, true)?;
1764+
1765+
fn float_length<F>(e: &[F]) -> F
1766+
where
1767+
F: core::ops::Mul<F>,
1768+
F: num_traits::Float + iter::Sum,
1769+
{
1770+
e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1771+
}
1772+
1773+
let result = match_literal_vector!(match e1 => Literal {
1774+
Float => |e1| { float_length(e1) },
1775+
})?;
1776+
self.register_evaluated_expr(Expression::Literal(result), span)
1777+
}
1778+
crate::MathFunction::Distance => {
1779+
// https://www.w3.org/TR/WGSL/#distance-builtin
1780+
let e1 = self.extract_vec(arg, true)?;
1781+
let e2 = self.extract_vec(arg1.unwrap(), true)?;
1782+
if e1.len() != e2.len() {
1783+
return Err(ConstantEvaluatorError::InvalidMathArg);
1784+
}
1785+
1786+
fn float_distance<F>(a: &[F], b: &[F]) -> F
1787+
where
1788+
F: core::ops::Mul<F>,
1789+
F: num_traits::Float + iter::Sum + core::ops::Sub,
1790+
{
1791+
a.iter()
1792+
.zip(b.iter())
1793+
.map(|(&aa, &bb)| aa - bb)
1794+
.map(|ei| ei * ei)
1795+
.sum::<F>()
1796+
.sqrt()
1797+
}
1798+
let result = match_literal_vector!(match (e1, e2) => Literal {
1799+
Float => |e1, e2| { float_distance(e1, e2) },
1800+
})?;
1801+
self.register_evaluated_expr(Expression::Literal(result), span)
1802+
}
17281803

17291804
// unimplemented
17301805
crate::MathFunction::Atan2
17311806
| crate::MathFunction::Modf
17321807
| crate::MathFunction::Frexp
17331808
| crate::MathFunction::Ldexp
1734-
| crate::MathFunction::Dot
17351809
| crate::MathFunction::Outer
1736-
| crate::MathFunction::Distance
1737-
| crate::MathFunction::Length
17381810
| crate::MathFunction::Normalize
17391811
| crate::MathFunction::FaceForward
17401812
| crate::MathFunction::Reflect
@@ -1811,8 +1883,8 @@ impl<'a> ConstantEvaluator<'a> {
18111883
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
18121884
use Literal as Li;
18131885

1814-
let (a, ty) = self.extract_vec::<3>(a)?;
1815-
let (b, _) = self.extract_vec::<3>(b)?;
1886+
let (a, ty) = self.extract_vec_with_size::<3>(a)?;
1887+
let (b, _) = self.extract_vec_with_size::<3>(b)?;
18161888

18171889
let product = match (a, b) {
18181890
(
@@ -1879,7 +1951,7 @@ impl<'a> ConstantEvaluator<'a> {
18791951
/// values.
18801952
///
18811953
/// Also return the type handle from the `Compose` expression.
1882-
fn extract_vec<const N: usize>(
1954+
fn extract_vec_with_size<const N: usize>(
18831955
&mut self,
18841956
expr: Handle<Expression>,
18851957
) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
@@ -1903,6 +1975,39 @@ impl<'a> ConstantEvaluator<'a> {
19031975
Ok((value, ty))
19041976
}
19051977

1978+
/// Extract the values of a `vecN` from `expr`.
1979+
///
1980+
/// Return the value of `expr`, whose type is `vecN<S>` for some
1981+
/// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
1982+
/// values.
1983+
///
1984+
/// Also return the type handle from the `Compose` expression.
1985+
fn extract_vec(
1986+
&mut self,
1987+
expr: Handle<Expression>,
1988+
allow_single: bool,
1989+
) -> Result<LiteralVector, ConstantEvaluatorError> {
1990+
let span = self.expressions.get_span(expr);
1991+
let expr = self.eval_zero_value_and_splat(expr, span)?;
1992+
1993+
match self.expressions[expr] {
1994+
Expression::Literal(literal) if allow_single => {
1995+
Ok(LiteralVector::from_literal(literal))
1996+
}
1997+
Expression::Compose { ty, ref components } => {
1998+
let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
1999+
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2000+
.map(|expr| match self.expressions[expr] {
2001+
Expression::Literal(l) => Ok(l),
2002+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
2003+
})
2004+
.collect::<Result<_, ConstantEvaluatorError>>()?;
2005+
LiteralVector::from_literal_vec(components)
2006+
}
2007+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
2008+
}
2009+
}
2010+
19062011
fn array_length(
19072012
&mut self,
19082013
array: Handle<Expression>,

0 commit comments

Comments
 (0)