Skip to content

Commit 627f1f2

Browse files
committed
use num-traits
Signed-off-by: sagudev <[email protected]>
1 parent 4f055ab commit 627f1f2

File tree

3 files changed

+62
-51
lines changed

3 files changed

+62
-51
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

naga/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ petgraph = { version = "0.6", optional = true }
8282
pp-rs = { version = "0.2.1", optional = true }
8383
hexf-parse = { version = "0.2.1", optional = true }
8484
unicode-xid = { version = "0.2.5", optional = true }
85+
num-traits = "0.2"
8586

8687
[build-dependencies]
8788
cfg_aliases.workspace = true

naga/src/proc/constant_evaluator.rs

Lines changed: 60 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ impl LiteralVector {
299299
}
300300
}
301301

302+
#[allow(dead_code)]
302303
/// Creates [`LiteralVector`] from Array of [`Literal`]s
303304
///
304305
/// Panics if vector is empty
@@ -1486,33 +1487,53 @@ impl<'a> ConstantEvaluator<'a> {
14861487
if e1.len() != e2.len() {
14871488
return Err(ConstantEvaluatorError::InvalidMathArg);
14881489
}
1490+
1491+
fn float_dot<F, const CAP: usize>(a: ArrayVec<F, CAP>, b: ArrayVec<F, CAP>) -> F
1492+
where
1493+
F: std::ops::Mul<F>,
1494+
F: num_traits::Float + std::iter::Sum,
1495+
{
1496+
a.iter().zip(b.iter()).map(|(&aa, &bb)| aa * bb).sum()
1497+
}
1498+
1499+
fn int_dot<P, const CAP: usize>(
1500+
a: ArrayVec<P, CAP>,
1501+
b: ArrayVec<P, CAP>,
1502+
) -> Result<P, ConstantEvaluatorError>
1503+
where
1504+
P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1505+
{
1506+
a.iter()
1507+
.zip(b.iter())
1508+
.map(|(&aa, bb)| aa.checked_mul(bb))
1509+
.try_fold(P::zero(), |acc, x| {
1510+
if let Some(x) = x {
1511+
acc.checked_add(&x)
1512+
} else {
1513+
None
1514+
}
1515+
})
1516+
.ok_or(ConstantEvaluatorError::Overflow(
1517+
"in dot built-in".to_string(),
1518+
))
1519+
}
1520+
14891521
LiteralVector::from_literal(match (e1, e2) {
14901522
(LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => {
1491-
Literal::AbstractFloat(
1492-
e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum(),
1493-
)
1523+
Literal::AbstractFloat(float_dot(e1, e2))
14941524
}
14951525
(LiteralVector::F32(e1), LiteralVector::F32(e2)) => {
1496-
Literal::F32(e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum())
1526+
Literal::F32(float_dot(e1, e2))
14971527
}
14981528
(LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => {
1499-
Literal::AbstractInt(
1500-
e1.iter()
1501-
.zip(e2.iter())
1502-
.map(|(&e1, &e2)| e1.checked_mul(e2))
1503-
.try_fold(0_i64, |acc, x| {
1504-
if let Some(x) = x {
1505-
acc.checked_add(x)
1506-
} else {
1507-
None
1508-
}
1509-
})
1510-
.ok_or(ConstantEvaluatorError::Overflow(
1511-
"in dot built-in".to_string(),
1512-
))?,
1513-
)
1529+
Literal::AbstractInt(int_dot(e1, e2)?)
1530+
}
1531+
(LiteralVector::I32(e1), LiteralVector::I32(e2)) => {
1532+
Literal::I32(int_dot(e1, e2)?)
1533+
}
1534+
(LiteralVector::U32(e1), LiteralVector::U32(e2)) => {
1535+
Literal::U32(int_dot(e1, e2)?)
15141536
}
1515-
// TODO: more
15161537
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
15171538
})
15181539
.handle(self, span)
@@ -1521,41 +1542,29 @@ impl<'a> ConstantEvaluator<'a> {
15211542
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
15221543
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
15231544
if e1.len() == 3 && e2.len() == 3 {
1545+
fn float_cross<F, const CAP: usize>(
1546+
a: ArrayVec<F, CAP>,
1547+
b: ArrayVec<F, CAP>,
1548+
) -> ArrayVec<F, CAP>
1549+
where
1550+
F: std::ops::Mul<F>,
1551+
F: num_traits::Float + std::iter::Sum,
1552+
{
1553+
[
1554+
a[1] * b[2] - a[2] * b[1],
1555+
a[2] * b[0] - a[0] * b[2],
1556+
a[0] * b[1] - a[1] * b[0],
1557+
]
1558+
.into_iter()
1559+
.collect()
1560+
}
15241561
match (e1, e2) {
15251562
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => {
1526-
LiteralVector::AbstractFloat(
1527-
[
1528-
a[1] * b[2] - a[2] * b[1],
1529-
a[2] * b[0] - a[0] * b[2],
1530-
a[0] * b[1] - a[1] * b[0],
1531-
]
1532-
.into_iter()
1533-
.collect(),
1534-
)
1563+
LiteralVector::AbstractFloat(float_cross(a, b))
15351564
}
1536-
(LiteralVector::AbstractInt(a), LiteralVector::AbstractInt(b)) => {
1537-
LiteralVector::AbstractInt(
1538-
[
1539-
a[1].checked_mul(b[2])
1540-
.zip(a[2].checked_mul(b[1]))
1541-
.and_then(|(a, b)| a.checked_sub(b)),
1542-
a[2].checked_mul(b[0])
1543-
.zip(a[0].checked_mul(b[2]))
1544-
.and_then(|(a, b)| a.checked_sub(b)),
1545-
a[0].checked_mul(b[1])
1546-
.zip(a[1].checked_mul(b[0]))
1547-
.and_then(|(a, b)| a.checked_sub(b)),
1548-
]
1549-
.into_iter()
1550-
.collect::<Option<_>>()
1551-
.ok_or(
1552-
ConstantEvaluatorError::Overflow(
1553-
"in cross built-in".to_string(),
1554-
),
1555-
)?,
1556-
)
1565+
(LiteralVector::F32(a), LiteralVector::F32(b)) => {
1566+
LiteralVector::F32(float_cross(a, b))
15571567
}
1558-
// TODO: more
15591568
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
15601569
}
15611570
.handle(self, span)

0 commit comments

Comments
 (0)