Skip to content

Commit e659ef3

Browse files
committed
LiteralVector and some demo
Signed-off-by: sagudev <[email protected]>
1 parent 4e9a2a5 commit e659ef3

File tree

1 file changed

+368
-3
lines changed

1 file changed

+368
-3
lines changed

naga/src/proc/constant_evaluator.rs

Lines changed: 368 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,254 @@ gen_component_wise_extractor! {
254254
],
255255
}
256256

257+
/// Vector for each [`Literal`] type
258+
///
259+
/// This type ensures that all elements have same type
260+
enum LiteralVector {
261+
F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
262+
F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
263+
U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
264+
I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
265+
U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
266+
I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
267+
Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
268+
AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
269+
AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
270+
}
271+
272+
impl LiteralVector {
273+
#[allow(clippy::pattern_type_mismatch)]
274+
const fn len(&self) -> usize {
275+
match self {
276+
LiteralVector::F64(v) => v.len(),
277+
LiteralVector::F32(v) => v.len(),
278+
LiteralVector::U32(v) => v.len(),
279+
LiteralVector::I32(v) => v.len(),
280+
LiteralVector::U64(v) => v.len(),
281+
LiteralVector::I64(v) => v.len(),
282+
LiteralVector::Bool(v) => v.len(),
283+
LiteralVector::AbstractInt(v) => v.len(),
284+
LiteralVector::AbstractFloat(v) => v.len(),
285+
}
286+
}
287+
/// Creates [`LiteralVector`] of size 1 from single [`Literal`]
288+
fn from_literal(literal: Literal) -> Self {
289+
match literal {
290+
Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
291+
Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
292+
Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
293+
Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
294+
Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
295+
Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
296+
Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
297+
Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
298+
Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
299+
}
300+
}
301+
302+
/// Creates [`LiteralVector`] from Array of [`Literal`]s
303+
///
304+
/// Panics if vector is empty
305+
fn from_literal_vec(
306+
components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
307+
) -> Result<Self, ConstantEvaluatorError> {
308+
let scalar = components[0].scalar();
309+
Self::from_literal_vec_with_scalar_type(components, scalar)
310+
}
311+
312+
/// Creates [`LiteralVector`] of type provided by scalar from Array of [`Literal`]s
313+
///
314+
/// Panics if vector is empty, returns error if types do not match
315+
fn from_literal_vec_with_scalar_type(
316+
components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
317+
scalar: crate::Scalar,
318+
) -> Result<Self, ConstantEvaluatorError> {
319+
assert!(!components.is_empty());
320+
Ok(match scalar {
321+
crate::Scalar::I32 => Self::I32(
322+
components
323+
.iter()
324+
.map(|l| match l {
325+
&Literal::I32(v) => Ok(v),
326+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
327+
})
328+
.collect::<Result<_, _>>()?,
329+
),
330+
crate::Scalar::U32 => Self::U32(
331+
components
332+
.iter()
333+
.map(|l| match l {
334+
&Literal::U32(v) => Ok(v),
335+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
336+
})
337+
.collect::<Result<_, _>>()?,
338+
),
339+
crate::Scalar::I64 => Self::I64(
340+
components
341+
.iter()
342+
.map(|l| match l {
343+
&Literal::I64(v) => Ok(v),
344+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
345+
})
346+
.collect::<Result<_, _>>()?,
347+
),
348+
crate::Scalar::U64 => Self::U64(
349+
components
350+
.iter()
351+
.map(|l| match l {
352+
&Literal::U64(v) => Ok(v),
353+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
354+
})
355+
.collect::<Result<_, _>>()?,
356+
),
357+
crate::Scalar::F32 => Self::F32(
358+
components
359+
.iter()
360+
.map(|l| match l {
361+
&Literal::F32(v) => Ok(v),
362+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
363+
})
364+
.collect::<Result<_, _>>()?,
365+
),
366+
crate::Scalar::F64 => Self::F64(
367+
components
368+
.iter()
369+
.map(|l| match l {
370+
&Literal::F64(v) => Ok(v),
371+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
372+
})
373+
.collect::<Result<_, _>>()?,
374+
),
375+
crate::Scalar::BOOL => Self::Bool(
376+
components
377+
.iter()
378+
.map(|l| match l {
379+
&Literal::Bool(v) => Ok(v),
380+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
381+
})
382+
.collect::<Result<_, _>>()?,
383+
),
384+
crate::Scalar::ABSTRACT_INT => Self::AbstractInt(
385+
components
386+
.iter()
387+
.map(|l| match l {
388+
&Literal::AbstractInt(v) => Ok(v),
389+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
390+
})
391+
.collect::<Result<_, _>>()?,
392+
),
393+
crate::Scalar::ABSTRACT_FLOAT => Self::AbstractFloat(
394+
components
395+
.iter()
396+
.map(|l| match l {
397+
&Literal::AbstractFloat(v) => Ok(v),
398+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
399+
})
400+
.collect::<Result<_, _>>()?,
401+
),
402+
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
403+
})
404+
}
405+
406+
fn from_expr(
407+
expr: Handle<Expression>,
408+
eval: &mut ConstantEvaluator<'_>,
409+
span: Span,
410+
allow_single: bool,
411+
) -> Result<Self, ConstantEvaluatorError> {
412+
let expr = eval
413+
.eval_zero_value_and_splat(expr, span)
414+
.map(|expr| &eval.expressions[expr])?;
415+
match *expr {
416+
Expression::Literal(literal) => {
417+
if allow_single {
418+
Ok(Self::from_literal(literal))
419+
} else {
420+
Err(ConstantEvaluatorError::InvalidMathArg)
421+
}
422+
}
423+
Expression::Compose { ty, ref components } => match eval.types[ty].inner {
424+
TypeInner::Vector { scalar, .. } => {
425+
if components.len() > crate::VectorSize::MAX {
426+
return Err(ConstantEvaluatorError::InvalidMathArg);
427+
}
428+
let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
429+
crate::proc::flatten_compose(ty, components, eval.expressions, eval.types)
430+
.map(|expr| match eval.expressions[expr] {
431+
Expression::Literal(l) => Ok(l),
432+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
433+
})
434+
.collect::<Result<_, ConstantEvaluatorError>>()?;
435+
Self::from_literal_vec_with_scalar_type(components, scalar)
436+
}
437+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
438+
},
439+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
440+
}
441+
}
442+
443+
/// Returns [`ArrayVec`] of [`Literals`]
444+
fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
445+
#[allow(clippy::pattern_type_mismatch)]
446+
match self {
447+
LiteralVector::F64(v) => v.iter().map(|e| (Literal::F64(*e))).collect(),
448+
LiteralVector::F32(v) => v.iter().map(|e| (Literal::F32(*e))).collect(),
449+
LiteralVector::U32(v) => v.iter().map(|e| (Literal::U32(*e))).collect(),
450+
LiteralVector::I32(v) => v.iter().map(|e| (Literal::I32(*e))).collect(),
451+
LiteralVector::U64(v) => v.iter().map(|e| (Literal::U64(*e))).collect(),
452+
LiteralVector::I64(v) => v.iter().map(|e| (Literal::I64(*e))).collect(),
453+
LiteralVector::Bool(v) => v.iter().map(|e| (Literal::Bool(*e))).collect(),
454+
LiteralVector::AbstractInt(v) => v.iter().map(|e| (Literal::AbstractInt(*e))).collect(),
455+
LiteralVector::AbstractFloat(v) => {
456+
v.iter().map(|e| (Literal::AbstractFloat(*e))).collect()
457+
}
458+
}
459+
}
460+
461+
fn to_expr(&self, eval: &mut ConstantEvaluator<'_>) -> Expression {
462+
let lit_vec = self.to_literal_vec();
463+
assert!(!lit_vec.is_empty());
464+
if lit_vec.len() == 1 {
465+
Expression::Literal(lit_vec[0])
466+
} else {
467+
Expression::Compose {
468+
ty: eval.types.insert(
469+
Type {
470+
name: None,
471+
inner: TypeInner::Vector {
472+
size: match lit_vec.len() {
473+
2 => crate::VectorSize::Bi,
474+
3 => crate::VectorSize::Tri,
475+
4 => crate::VectorSize::Quad,
476+
_ => unreachable!(),
477+
},
478+
scalar: lit_vec[0].scalar(),
479+
},
480+
},
481+
Span::UNDEFINED,
482+
),
483+
components: lit_vec
484+
.iter()
485+
.map(|&l| {
486+
eval.expressions
487+
.append(Expression::Literal(l), Span::UNDEFINED)
488+
})
489+
.collect(),
490+
}
491+
}
492+
}
493+
494+
/// Puts self into eval's expressions arena and returns handle to it
495+
fn handle(
496+
&self,
497+
eval: &mut ConstantEvaluator<'_>,
498+
span: Span,
499+
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
500+
let expr = self.to_expr(eval);
501+
eval.register_evaluated_expr(expr, span)
502+
}
503+
}
504+
257505
#[derive(Debug)]
258506
enum Behavior<'a> {
259507
Wgsl(WgslRestrictions<'a>),
@@ -917,9 +1165,10 @@ impl<'a> ConstantEvaluator<'a> {
9171165
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
9181166
"select built-in function".into(),
9191167
)),
920-
Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
921-
format!("{fun:?} built-in function"),
922-
)),
1168+
Expression::Relational { fun, argument } => {
1169+
let arg = self.check_and_get(argument)?;
1170+
self.relational_op(fun, arg, span)
1171+
}
9231172
Expression::ArrayLength(expr) => match self.behavior {
9241173
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
9251174
Behavior::Glsl(_) => {
@@ -1230,6 +1479,90 @@ impl<'a> ConstantEvaluator<'a> {
12301479
})
12311480
}
12321481

1482+
// geometry
1483+
crate::MathFunction::Dot => {
1484+
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
1485+
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
1486+
if e1.len() != e2.len() {
1487+
return Err(ConstantEvaluatorError::InvalidMathArg);
1488+
}
1489+
LiteralVector::from_literal(match (e1, e2) {
1490+
(LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => {
1491+
Literal::AbstractFloat(
1492+
e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum(),
1493+
)
1494+
}
1495+
(LiteralVector::F32(e1), LiteralVector::F32(e2)) => {
1496+
Literal::F32(e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum())
1497+
}
1498+
(LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => {
1499+
Literal::AbstractInt(
1500+
e1.iter()
1501+
.zip(e2.iter())
1502+
.map(|(&e1, &e2)| e1.checked_mul(e2))
1503+
.try_fold(0_i64, |acc, x| {
1504+
if let Some(x) = x {
1505+
acc.checked_add(x)
1506+
} else {
1507+
None
1508+
}
1509+
})
1510+
.ok_or(ConstantEvaluatorError::Overflow(
1511+
"in dot built-in".to_string(),
1512+
))?,
1513+
)
1514+
}
1515+
// TODO: more
1516+
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
1517+
})
1518+
.handle(self, span)
1519+
}
1520+
crate::MathFunction::Cross => {
1521+
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
1522+
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
1523+
if e1.len() == 3 && e2.len() == 3 {
1524+
match (e1, e2) {
1525+
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => {
1526+
LiteralVector::AbstractFloat(
1527+
[
1528+
a[1] * b[2] - a[2] * b[1],
1529+
a[2] * b[0] - a[0] * b[2],
1530+
a[0] * b[1] - a[1] * b[0],
1531+
]
1532+
.into_iter()
1533+
.collect(),
1534+
)
1535+
}
1536+
(LiteralVector::AbstractInt(a), LiteralVector::AbstractInt(b)) => {
1537+
LiteralVector::AbstractInt(
1538+
[
1539+
a[1].checked_mul(b[2])
1540+
.zip(a[2].checked_mul(b[1]))
1541+
.and_then(|(a, b)| a.checked_sub(b)),
1542+
a[2].checked_mul(b[0])
1543+
.zip(a[0].checked_mul(b[2]))
1544+
.and_then(|(a, b)| a.checked_sub(b)),
1545+
a[0].checked_mul(b[1])
1546+
.zip(a[1].checked_mul(b[0]))
1547+
.and_then(|(a, b)| a.checked_sub(b)),
1548+
]
1549+
.into_iter()
1550+
.collect::<Option<_>>()
1551+
.ok_or(
1552+
ConstantEvaluatorError::Overflow(
1553+
"in cross built-in".to_string(),
1554+
),
1555+
)?,
1556+
)
1557+
}
1558+
// TODO: more
1559+
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
1560+
}
1561+
.handle(self, span)
1562+
} else {
1563+
Err(ConstantEvaluatorError::InvalidMathArg)
1564+
}
1565+
}
12331566
// computational
12341567
crate::MathFunction::Sign => {
12351568
component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
@@ -2059,6 +2392,38 @@ impl<'a> ConstantEvaluator<'a> {
20592392
Ok(Expression::Compose { ty, components })
20602393
}
20612394

2395+
fn relational_op(
2396+
&mut self,
2397+
fun: crate::RelationalFunction,
2398+
arg: Handle<Expression>,
2399+
span: Span,
2400+
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2401+
let arg = LiteralVector::from_expr(arg, self, span, true)?;
2402+
let res = LiteralVector::Bool(match fun {
2403+
crate::RelationalFunction::IsNan => match arg {
2404+
LiteralVector::F64(f) => f.iter().map(|e| e.is_nan()).collect(),
2405+
LiteralVector::F32(f) => f.iter().map(|e| e.is_nan()).collect(),
2406+
LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_nan()).collect(),
2407+
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
2408+
},
2409+
crate::RelationalFunction::IsInf => match arg {
2410+
LiteralVector::F64(f) => f.iter().map(|e| e.is_infinite()).collect(),
2411+
LiteralVector::F32(f) => f.iter().map(|e| e.is_infinite()).collect(),
2412+
LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_infinite()).collect(),
2413+
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
2414+
},
2415+
crate::RelationalFunction::All => match arg {
2416+
LiteralVector::Bool(bools) => iter::once(bools.iter().all(|b| *b)).collect(),
2417+
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
2418+
},
2419+
crate::RelationalFunction::Any => match arg {
2420+
LiteralVector::Bool(bools) => iter::once(bools.iter().any(|b| *b)).collect(),
2421+
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
2422+
},
2423+
});
2424+
res.handle(self, span)
2425+
}
2426+
20622427
/// Deep copy `expr` from `expressions` into `self.expressions`.
20632428
///
20642429
/// Return the root of the new copy.

0 commit comments

Comments
 (0)