Skip to content

Commit 4624182

Browse files
sagudevSelkoSays
andcommitted
[naga] Add LiteralVector and match_literal_vector!
Co-authored-by: SelkoSays <[email protected]> Signed-off-by: sagudev <[email protected]>
1 parent 04a3401 commit 4624182

File tree

1 file changed

+386
-0
lines changed

1 file changed

+386
-0
lines changed

naga/src/proc/constant_evaluator.rs

Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,392 @@ gen_component_wise_extractor! {
268268
],
269269
}
270270

271+
/// Vectors with a concrete element type.
272+
#[derive(Debug)]
273+
enum LiteralVector {
274+
F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
275+
F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
276+
F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
277+
U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
278+
I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
279+
U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
280+
I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
281+
Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
282+
AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
283+
AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
284+
}
285+
286+
impl LiteralVector {
287+
#[allow(clippy::missing_const_for_fn, reason = "MSRV")]
288+
fn len(&self) -> usize {
289+
match *self {
290+
LiteralVector::F64(ref v) => v.len(),
291+
LiteralVector::F32(ref v) => v.len(),
292+
LiteralVector::F16(ref v) => v.len(),
293+
LiteralVector::U32(ref v) => v.len(),
294+
LiteralVector::I32(ref v) => v.len(),
295+
LiteralVector::U64(ref v) => v.len(),
296+
LiteralVector::I64(ref v) => v.len(),
297+
LiteralVector::Bool(ref v) => v.len(),
298+
LiteralVector::AbstractInt(ref v) => v.len(),
299+
LiteralVector::AbstractFloat(ref v) => v.len(),
300+
}
301+
}
302+
303+
/// Creates [`LiteralVector`] of size 1 from single [`Literal`]
304+
fn from_literal(literal: Literal) -> Self {
305+
match literal {
306+
Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
307+
Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
308+
Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
309+
Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
310+
Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
311+
Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
312+
Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
313+
Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
314+
Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
315+
Literal::F16(e) => Self::F16(ArrayVec::from_iter(iter::once(e))),
316+
}
317+
}
318+
319+
/// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
320+
/// Returns error if components types do not match.
321+
/// # Panics
322+
/// Panics if vector is empty
323+
fn from_literal_vec(
324+
components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
325+
) -> Result<Self, ConstantEvaluatorError> {
326+
assert!(!components.is_empty());
327+
Ok(match components[0] {
328+
Literal::I32(_) => Self::I32(
329+
components
330+
.iter()
331+
.map(|l| match l {
332+
&Literal::I32(v) => Ok(v),
333+
// TODO: should we handle abstract int here?
334+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
335+
})
336+
.collect::<Result<_, _>>()?,
337+
),
338+
Literal::U32(_) => Self::U32(
339+
components
340+
.iter()
341+
.map(|l| match l {
342+
&Literal::U32(v) => Ok(v),
343+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
344+
})
345+
.collect::<Result<_, _>>()?,
346+
),
347+
Literal::I64(_) => Self::I64(
348+
components
349+
.iter()
350+
.map(|l| match l {
351+
&Literal::I64(v) => Ok(v),
352+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
353+
})
354+
.collect::<Result<_, _>>()?,
355+
),
356+
Literal::U64(_) => Self::U64(
357+
components
358+
.iter()
359+
.map(|l| match l {
360+
&Literal::U64(v) => Ok(v),
361+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
362+
})
363+
.collect::<Result<_, _>>()?,
364+
),
365+
Literal::F32(_) => Self::F32(
366+
components
367+
.iter()
368+
.map(|l| match l {
369+
&Literal::F32(v) => Ok(v),
370+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
371+
})
372+
.collect::<Result<_, _>>()?,
373+
),
374+
Literal::F64(_) => Self::F64(
375+
components
376+
.iter()
377+
.map(|l| match l {
378+
&Literal::F64(v) => Ok(v),
379+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
380+
})
381+
.collect::<Result<_, _>>()?,
382+
),
383+
Literal::Bool(_) => Self::Bool(
384+
components
385+
.iter()
386+
.map(|l| match l {
387+
&Literal::Bool(v) => Ok(v),
388+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
389+
})
390+
.collect::<Result<_, _>>()?,
391+
),
392+
Literal::AbstractInt(_) => Self::AbstractInt(
393+
components
394+
.iter()
395+
.map(|l| match l {
396+
&Literal::AbstractInt(v) => Ok(v),
397+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
398+
})
399+
.collect::<Result<_, _>>()?,
400+
),
401+
Literal::AbstractFloat(_) => Self::AbstractFloat(
402+
components
403+
.iter()
404+
.map(|l| match l {
405+
&Literal::AbstractFloat(v) => Ok(v),
406+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
407+
})
408+
.collect::<Result<_, _>>()?,
409+
),
410+
Literal::F16(_) => Self::F16(
411+
components
412+
.iter()
413+
.map(|l| match l {
414+
&Literal::F16(v) => Ok(v),
415+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
416+
})
417+
.collect::<Result<_, _>>()?,
418+
),
419+
})
420+
}
421+
422+
#[allow(dead_code)]
423+
/// Returns [`ArrayVec`] of [`Literal`]s
424+
fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
425+
match *self {
426+
LiteralVector::F64(ref v) => v.iter().map(|e| (Literal::F64(*e))).collect(),
427+
LiteralVector::F32(ref v) => v.iter().map(|e| (Literal::F32(*e))).collect(),
428+
LiteralVector::F16(ref v) => v.iter().map(|e| (Literal::F16(*e))).collect(),
429+
LiteralVector::U32(ref v) => v.iter().map(|e| (Literal::U32(*e))).collect(),
430+
LiteralVector::I32(ref v) => v.iter().map(|e| (Literal::I32(*e))).collect(),
431+
LiteralVector::U64(ref v) => v.iter().map(|e| (Literal::U64(*e))).collect(),
432+
LiteralVector::I64(ref v) => v.iter().map(|e| (Literal::I64(*e))).collect(),
433+
LiteralVector::Bool(ref v) => v.iter().map(|e| (Literal::Bool(*e))).collect(),
434+
LiteralVector::AbstractInt(ref v) => {
435+
v.iter().map(|e| (Literal::AbstractInt(*e))).collect()
436+
}
437+
LiteralVector::AbstractFloat(ref v) => {
438+
v.iter().map(|e| (Literal::AbstractFloat(*e))).collect()
439+
}
440+
}
441+
}
442+
443+
#[allow(dead_code)]
444+
/// Puts self into eval's expressions arena and returns handle to it
445+
fn register_as_evaluated_expr(
446+
&self,
447+
eval: &mut ConstantEvaluator<'_>,
448+
span: Span,
449+
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
450+
let lit_vec = self.to_literal_vec();
451+
assert!(!lit_vec.is_empty());
452+
let expr = if lit_vec.len() == 1 {
453+
Expression::Literal(lit_vec[0])
454+
} else {
455+
Expression::Compose {
456+
ty: eval.types.insert(
457+
Type {
458+
name: None,
459+
inner: TypeInner::Vector {
460+
size: match lit_vec.len() {
461+
2 => crate::VectorSize::Bi,
462+
3 => crate::VectorSize::Tri,
463+
4 => crate::VectorSize::Quad,
464+
_ => unreachable!(),
465+
},
466+
scalar: lit_vec[0].scalar(),
467+
},
468+
},
469+
Span::UNDEFINED,
470+
),
471+
components: lit_vec
472+
.iter()
473+
.map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
474+
.collect::<Result<_, _>>()?,
475+
}
476+
};
477+
eval.register_evaluated_expr(expr, span)
478+
}
479+
}
480+
481+
/// A macro for matching on [`LiteralVector`] variants.
482+
///
483+
/// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`.
484+
/// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
485+
///
486+
/// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported.
487+
///
488+
/// Example usage:
489+
///
490+
/// ```rust,ignore
491+
/// match_literal_vector!(match v => Literal {
492+
/// F16 => |v| {v.sum()},
493+
/// Integer => |v| {v.sum()},
494+
/// U32 => |v| -> I32 {v.sum()}, // optionally override return type
495+
/// })
496+
/// ```
497+
///
498+
/// ```rust,ignore
499+
/// match_literal_vector!(match (e1, e2) => LiteralVector {
500+
/// F16 => |e1, e2| {e1+e2},
501+
/// Integer => |e1, e2| {e1+e2},
502+
/// U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
503+
/// })
504+
/// ```
505+
macro_rules! match_literal_vector {
506+
(match $lit_vec:expr => $out:ident {
507+
$(
508+
$ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
509+
),+
510+
$(,)?
511+
}) => {
512+
match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
513+
};
514+
515+
(@inner_start
516+
$lit_vec:expr;
517+
$out:ident;
518+
[$($ty:ident),+];
519+
[$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
520+
) => {
521+
match_literal_vector!(@inner
522+
$lit_vec;
523+
$out;
524+
[$($ty),+];
525+
[] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
526+
)
527+
};
528+
529+
(@inner
530+
$lit_vec:expr;
531+
$out:ident;
532+
[$ty:ident $(, $ty1:ident)*];
533+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
534+
[$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
535+
) => {
536+
match_literal_vector!(@inner
537+
$ty;
538+
$lit_vec;
539+
$out;
540+
[$($ty1),*];
541+
[$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
542+
[$({ $($var),+ ; $($ret)? ; $body }),+]
543+
)
544+
};
545+
(@inner
546+
Integer;
547+
$lit_vec:expr;
548+
$out:ident;
549+
[$($ty:ident),*];
550+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
551+
[
552+
{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }
553+
$(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
554+
]
555+
) => {
556+
match_literal_vector!(@inner
557+
$lit_vec;
558+
$out;
559+
[U32, I32, U64, I64, AbstractInt $(, $ty)*];
560+
[$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
561+
[
562+
{ $($var),+ ; $($ret)? ; $body }, // U32
563+
{ $($var),+ ; $($ret)? ; $body }, // I32
564+
{ $($var),+ ; $($ret)? ; $body }, // U64
565+
{ $($var),+ ; $($ret)? ; $body }, // I64
566+
{ $($var),+ ; $($ret)? ; $body } // AbstractInt
567+
$(,{ $($var1),+ ; $($ret1)? ; $body1 })*
568+
]
569+
)
570+
};
571+
(@inner
572+
Float;
573+
$lit_vec:expr;
574+
$out:ident;
575+
[$($ty:ident),*];
576+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
577+
[
578+
{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }
579+
$(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
580+
]
581+
) => {
582+
match_literal_vector!(@inner
583+
$lit_vec;
584+
$out;
585+
[F16, F32, F64, AbstractFloat $(, $ty)*];
586+
[$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
587+
[
588+
{ $($var),+ ; $($ret)? ; $body }, // F16
589+
{ $($var),+ ; $($ret)? ; $body }, // F32
590+
{ $($var),+ ; $($ret)? ; $body }, // F64
591+
{ $($var),+ ; $($ret)? ; $body } // AbstractFloat
592+
$(,{ $($var1),+ ; $($ret1)? ; $body1 })*
593+
]
594+
)
595+
};
596+
(@inner
597+
$ty:ident;
598+
$lit_vec:expr;
599+
$out:ident;
600+
[$ty1:ident $(,$ty2:ident)*];
601+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
602+
{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }
603+
$(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
604+
]
605+
) => {
606+
match_literal_vector!(@inner
607+
$ty1;
608+
$lit_vec;
609+
$out;
610+
[$($ty2),*];
611+
[
612+
$({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
613+
{ $ty; $($var),+ ; $($ret)? ; $body }
614+
] <>
615+
[$({ $($var1),+ ; $($ret1)? ; $body1 }),*]
616+
617+
)
618+
};
619+
(@inner
620+
$ty:ident;
621+
$lit_vec:expr;
622+
$out:ident;
623+
[];
624+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
625+
[{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
626+
) => {
627+
match_literal_vector!(@inner_finish
628+
$lit_vec;
629+
$out;
630+
[
631+
$({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
632+
{ $ty; $($var),+ ; $($ret)? ; $body }
633+
]
634+
)
635+
};
636+
(@inner_finish
637+
$lit_vec:expr;
638+
$out:ident;
639+
[$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
640+
) => {
641+
match $lit_vec {
642+
$(
643+
#[allow(unused_parens)]
644+
($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
645+
)+
646+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
647+
}
648+
};
649+
(@expand_ret $out:ident; $ty:ident; $body:expr) => {
650+
$out::$ty($body)
651+
};
652+
(@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
653+
$out::$ret($body)
654+
};
655+
}
656+
271657
#[derive(Debug)]
272658
enum Behavior<'a> {
273659
Wgsl(WgslRestrictions<'a>),

0 commit comments

Comments
 (0)