diff --git a/Cargo.lock b/Cargo.lock index c472cdb86b5..ecc91f84bc0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1893,6 +1893,7 @@ dependencies = [ "indexmap", "itertools", "log", + "num-traits", "petgraph", "pp-rs", "ron", diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 0d54c964cc6..4fe7129fd28 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -81,7 +81,8 @@ serde = { version = "1.0.214", features = ["derive"], optional = true } petgraph = { version = "0.6", optional = true } pp-rs = { version = "0.2.1", optional = true } hexf-parse = { version = "0.2.1", optional = true } -unicode-xid = { version = "0.2.6", optional = true } +unicode-xid = { version = "0.2.5", optional = true } +num-traits = "0.2" [build-dependencies] cfg_aliases.workspace = true diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 5fdf4815164..34353dc18ce 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -254,6 +254,285 @@ gen_component_wise_extractor! { ], } +macro_rules! match_literal_vector { + ($($x:expr),* => $( $( ($($mat:ident($arg:ident)),*) -> $ret:path )|+ => $body:expr ),*,_ => $body2:expr) => { + match $($x),* { + $( + $( + ($( LiteralVector::$mat($arg) ),*) => $ret($body), + )* + )* + _ => $body2 + } + }; + ($($x:expr),* => $( $( ($($mat:ident($arg:ident)),*) )|+ => $body:expr ),*,_ => $body2:expr) => { + match $($x),* { + $( + $( + ($( LiteralVector::$mat($arg) ),*) => $body, + )* + )* + _ => $body2 + } + }; + ($($x:expr),* => $( $( $mat:ident$arg:tt -> $ret:path )|+ => $body:expr ),*,_ => $body2:expr) => { + match $($x),* { + $( + $( + LiteralVector::$mat$arg => $ret($body), + )* + )* + _ => $body2 + } + }; + ($($x:expr),* => $( $( $mat:ident$arg:tt )|+ => $body:expr ),*) => { + match $($x),* { + $( + $( + LiteralVector::$mat$arg => $body, + )* + )* + } + }; +} + +/// Vectors with a concrete element type. +#[derive(Debug)] +enum LiteralVector { + F64(ArrayVec), + F32(ArrayVec), + U32(ArrayVec), + I32(ArrayVec), + U64(ArrayVec), + I64(ArrayVec), + Bool(ArrayVec), + AbstractInt(ArrayVec), + AbstractFloat(ArrayVec), +} + +impl LiteralVector { + const fn len(&self) -> usize { + match_literal_vector!(*self => + F64(ref v) + | F32(ref v) + | U32(ref v) + | I32(ref v) + | U64(ref v) + | I64(ref v) + | Bool(ref v) + | AbstractInt(ref v) + | AbstractFloat(ref v) => v.len() + ) + } + + /// Creates [`LiteralVector`] of size 1 from single [`Literal`] + fn from_literal(literal: Literal) -> Self { + match literal { + Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))), + Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))), + Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))), + Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))), + Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))), + Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))), + Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))), + Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))), + Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))), + } + } + + /// # Panics + /// Panics if vector is empty, returns error if types do not match + fn from_literal_vec_with_scalar_type( + components: ArrayVec, + scalar: crate::Scalar, + ) -> Result { + assert!(!components.is_empty()); + Ok(match scalar { + crate::Scalar::I32 => Self::I32( + components + .iter() + .map(|l| match l { + &Literal::I32(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::U32 => Self::U32( + components + .iter() + .map(|l| match l { + &Literal::U32(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::I64 => Self::I64( + components + .iter() + .map(|l| match l { + &Literal::I64(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::U64 => Self::U64( + components + .iter() + .map(|l| match l { + &Literal::U64(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::F32 => Self::F32( + components + .iter() + .map(|l| match l { + &Literal::F32(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::F64 => Self::F64( + components + .iter() + .map(|l| match l { + &Literal::F64(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::BOOL => Self::Bool( + components + .iter() + .map(|l| match l { + &Literal::Bool(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::ABSTRACT_INT => Self::AbstractInt( + components + .iter() + .map(|l| match l { + &Literal::AbstractInt(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::ABSTRACT_FLOAT => Self::AbstractFloat( + components + .iter() + .map(|l| match l { + &Literal::AbstractFloat(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }) + } + + fn from_expr( + expr: Handle, + eval: &mut ConstantEvaluator<'_>, + span: Span, + allow_single: bool, + ) -> Result { + let expr = eval + .eval_zero_value_and_splat(expr, span) + .map(|expr| &eval.expressions[expr])?; + match *expr { + Expression::Literal(literal) => { + if allow_single { + Ok(Self::from_literal(literal)) + } else { + Err(ConstantEvaluatorError::InvalidMathArg) + } + } + Expression::Compose { ty, ref components } => match eval.types[ty].inner { + TypeInner::Vector { scalar, .. } => { + if components.len() > crate::VectorSize::MAX { + return Err(ConstantEvaluatorError::InvalidMathArg); + } + let components: ArrayVec = + crate::proc::flatten_compose(ty, components, eval.expressions, eval.types) + .map(|expr| match eval.expressions[expr] { + Expression::Literal(l) => Ok(l), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?; + Self::from_literal_vec_with_scalar_type(components, scalar) + } + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }, + _ => Err(ConstantEvaluatorError::InvalidMathArg), + } + } + + /// Returns [`ArrayVec`] of [`Literal`]s + fn to_literal_vec(&self) -> ArrayVec { + match *self { + LiteralVector::F64(ref v) => v.iter().map(|e| (Literal::F64(*e))).collect(), + LiteralVector::F32(ref v) => v.iter().map(|e| (Literal::F32(*e))).collect(), + LiteralVector::U32(ref v) => v.iter().map(|e| (Literal::U32(*e))).collect(), + LiteralVector::I32(ref v) => v.iter().map(|e| (Literal::I32(*e))).collect(), + LiteralVector::U64(ref v) => v.iter().map(|e| (Literal::U64(*e))).collect(), + LiteralVector::I64(ref v) => v.iter().map(|e| (Literal::I64(*e))).collect(), + LiteralVector::Bool(ref v) => v.iter().map(|e| (Literal::Bool(*e))).collect(), + LiteralVector::AbstractInt(ref v) => { + v.iter().map(|e| (Literal::AbstractInt(*e))).collect() + } + LiteralVector::AbstractFloat(ref v) => { + v.iter().map(|e| (Literal::AbstractFloat(*e))).collect() + } + } + } + + fn to_expr( + &self, + eval: &mut ConstantEvaluator<'_>, + ) -> Result { + let lit_vec = self.to_literal_vec(); + assert!(!lit_vec.is_empty()); + if lit_vec.len() == 1 { + Ok(Expression::Literal(lit_vec[0])) + } else { + Ok(Expression::Compose { + ty: eval.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: match lit_vec.len() { + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + 4 => crate::VectorSize::Quad, + _ => unreachable!(), + }, + scalar: lit_vec[0].scalar(), + }, + }, + Span::UNDEFINED, + ), + components: lit_vec + .iter() + .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), Span::UNDEFINED)) + .collect::>()?, + }) + } + } + + /// Puts self into eval's expressions arena and returns handle to it + fn handle( + &self, + eval: &mut ConstantEvaluator<'_>, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let expr = self.to_expr(eval)?; + eval.register_evaluated_expr(expr, span) + } +} + #[derive(Debug)] enum Behavior<'a> { Wgsl(WgslRestrictions<'a>), @@ -917,9 +1196,10 @@ impl<'a> ConstantEvaluator<'a> { Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented( "select built-in function".into(), )), - Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented( - format!("{fun:?} built-in function"), - )), + Expression::Relational { fun, argument } => { + let arg = self.check_and_get(argument)?; + self.relational_op(fun, arg, span) + } Expression::ArrayLength(expr) => match self.behavior { Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength), Behavior::Glsl(_) => { @@ -1230,6 +1510,136 @@ impl<'a> ConstantEvaluator<'a> { }) } + // geometry + crate::MathFunction::Dot => { + let e1 = LiteralVector::from_expr(arg, self, span, false)?; + let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?; + if e1.len() != e2.len() { + return Err(ConstantEvaluatorError::InvalidMathArg); + } + + fn float_dot(a: ArrayVec, b: ArrayVec) -> F + where + F: std::ops::Mul, + F: num_traits::Float + std::iter::Sum, + { + a.iter().zip(b.iter()).map(|(&aa, &bb)| aa * bb).sum() + } + + fn int_dot( + a: ArrayVec, + b: ArrayVec, + ) -> Result + where + P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul, + { + a.iter() + .zip(b.iter()) + .map(|(&aa, bb)| aa.checked_mul(bb)) + .try_fold(P::zero(), |acc, x| { + if let Some(x) = x { + acc.checked_add(&x) + } else { + None + } + }) + .ok_or(ConstantEvaluatorError::Overflow( + "in dot built-in".to_string(), + )) + } + + LiteralVector::from_literal(match_literal_vector! {(e1, e2) => + (AbstractFloat(e1), AbstractFloat(e2)) -> Literal::AbstractFloat + | (F32(e1), F32(e2)) -> Literal::F32 + => float_dot(e1, e2), + (AbstractInt(e1), AbstractInt(e2)) -> Literal::AbstractInt + | (I32(e1), I32(e2)) -> Literal::I32 + | (U32(e1), U32(e2)) -> Literal::U32 + => int_dot(e1, e2)?, + _ => return Err(ConstantEvaluatorError::InvalidMathArg) + }) + .handle(self, span) + } + crate::MathFunction::Cross => { + let e1 = LiteralVector::from_expr(arg, self, span, false)?; + let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?; + if e1.len() == 3 && e2.len() == 3 { + fn float_cross( + a: ArrayVec, + b: ArrayVec, + ) -> ArrayVec + where + F: std::ops::Mul, + F: num_traits::Float + std::iter::Sum, + { + [ + a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0], + ] + .into_iter() + .collect() + } + match_literal_vector! {(e1, e2) => + (AbstractFloat(e1), AbstractFloat(e2)) -> LiteralVector::AbstractFloat + | (F32(e1), F32(e2)) -> LiteralVector::F32 + => float_cross(e1, e2), + _ => return Err(ConstantEvaluatorError::InvalidMathArg) + } + .handle(self, span) + } else { + Err(ConstantEvaluatorError::InvalidMathArg) + } + } + crate::MathFunction::Length => { + let e1 = LiteralVector::from_expr(arg, self, span, true)?; + + fn float_length(e: ArrayVec) -> F + where + F: std::ops::Mul, + F: num_traits::Float + std::iter::Sum, + { + e.iter().map(|&ei| ei * ei).sum::().sqrt() + } + + LiteralVector::from_literal(match_literal_vector! {e1 => + AbstractFloat(e1) -> Literal::AbstractFloat + | F32(e1) -> Literal::F32 + => float_length(e1), + _ => return Err(ConstantEvaluatorError::InvalidMathArg) + }) + .handle(self, span) + } + crate::MathFunction::Distance => { + let e1 = LiteralVector::from_expr(arg, self, span, true)?; + let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, true)?; + if e1.len() != e2.len() { + return Err(ConstantEvaluatorError::InvalidMathArg); + } + + fn float_distance( + a: ArrayVec, + b: ArrayVec, + ) -> F + where + F: std::ops::Mul, + F: num_traits::Float + std::iter::Sum + std::ops::Sub, + { + a.iter() + .zip(b.iter()) + .map(|(&aa, &bb)| aa - bb) + .map(|ei| ei * ei) + .sum::() + .sqrt() + } + LiteralVector::from_literal(match_literal_vector! {(e1, e2) => + (AbstractFloat(e1), AbstractFloat(e2)) -> Literal::AbstractFloat + | (F32(e1), F32(e2)) -> Literal::F32 + => float_distance(e1, e2), + _ => return Err(ConstantEvaluatorError::InvalidMathArg) + }) + .handle(self, span) + } // computational crate::MathFunction::Sign => { component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) }) @@ -2059,6 +2469,38 @@ impl<'a> ConstantEvaluator<'a> { Ok(Expression::Compose { ty, components }) } + fn relational_op( + &mut self, + fun: crate::RelationalFunction, + arg: Handle, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let arg = LiteralVector::from_expr(arg, self, span, true)?; + let res = LiteralVector::Bool(match fun { + crate::RelationalFunction::IsNan => match arg { + LiteralVector::F64(f) => f.iter().map(|e| e.is_nan()).collect(), + LiteralVector::F32(f) => f.iter().map(|e| e.is_nan()).collect(), + LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_nan()).collect(), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }, + crate::RelationalFunction::IsInf => match arg { + LiteralVector::F64(f) => f.iter().map(|e| e.is_infinite()).collect(), + LiteralVector::F32(f) => f.iter().map(|e| e.is_infinite()).collect(), + LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_infinite()).collect(), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }, + crate::RelationalFunction::All => match arg { + LiteralVector::Bool(bools) => iter::once(bools.iter().all(|b| *b)).collect(), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }, + crate::RelationalFunction::Any => match arg { + LiteralVector::Bool(bools) => iter::once(bools.iter().any(|b| *b)).collect(), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }, + }); + res.handle(self, span) + } + /// Deep copy `expr` from `expressions` into `self.expressions`. /// /// Return the root of the new copy.