Skip to content

Commit 2368eef

Browse files
sagudevSelkoSays
andcommitted
[naga] Add LiteralVector and match_literal_vector!
Co-authored-by: SelkoSays <[email protected]> Signed-off-by: sagudev <[email protected]>
1 parent 04a3401 commit 2368eef

File tree

1 file changed

+381
-0
lines changed

1 file changed

+381
-0
lines changed

naga/src/proc/constant_evaluator.rs

Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,387 @@ gen_component_wise_extractor! {
268268
],
269269
}
270270

271+
/// Vectors with a concrete element type.
272+
#[derive(Debug)]
273+
enum LiteralVector {
274+
F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
275+
F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
276+
F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
277+
U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
278+
I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
279+
U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
280+
I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
281+
Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
282+
AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
283+
AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
284+
}
285+
286+
impl LiteralVector {
287+
#[allow(clippy::missing_const_for_fn, reason = "MSRV")]
288+
fn len(&self) -> usize {
289+
match *self {
290+
LiteralVector::F64(ref v) => v.len(),
291+
LiteralVector::F32(ref v) => v.len(),
292+
LiteralVector::F16(ref v) => v.len(),
293+
LiteralVector::U32(ref v) => v.len(),
294+
LiteralVector::I32(ref v) => v.len(),
295+
LiteralVector::U64(ref v) => v.len(),
296+
LiteralVector::I64(ref v) => v.len(),
297+
LiteralVector::Bool(ref v) => v.len(),
298+
LiteralVector::AbstractInt(ref v) => v.len(),
299+
LiteralVector::AbstractFloat(ref v) => v.len(),
300+
}
301+
}
302+
303+
/// Creates [`LiteralVector`] of size 1 from single [`Literal`]
304+
fn from_literal(literal: Literal) -> Self {
305+
match literal {
306+
Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
307+
Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
308+
Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
309+
Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
310+
Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
311+
Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
312+
Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
313+
Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
314+
Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
315+
Literal::F16(e) => Self::F16(ArrayVec::from_iter(iter::once(e))),
316+
}
317+
}
318+
319+
/// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
320+
/// Returns error if components types do not match.
321+
/// # Panics
322+
/// Panics if vector is empty
323+
fn from_literal_vec(
324+
components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
325+
) -> Result<Self, ConstantEvaluatorError> {
326+
assert!(!components.is_empty());
327+
Ok(match components[0] {
328+
Literal::I32(_) => Self::I32(
329+
components
330+
.iter()
331+
.map(|l| match l {
332+
&Literal::I32(v) => Ok(v),
333+
// TODO: should we handle abstract int here?
334+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
335+
})
336+
.collect::<Result<_, _>>()?,
337+
),
338+
Literal::U32(_) => Self::U32(
339+
components
340+
.iter()
341+
.map(|l| match l {
342+
&Literal::U32(v) => Ok(v),
343+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
344+
})
345+
.collect::<Result<_, _>>()?,
346+
),
347+
Literal::I64(_) => Self::I64(
348+
components
349+
.iter()
350+
.map(|l| match l {
351+
&Literal::I64(v) => Ok(v),
352+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
353+
})
354+
.collect::<Result<_, _>>()?,
355+
),
356+
Literal::U64(_) => Self::U64(
357+
components
358+
.iter()
359+
.map(|l| match l {
360+
&Literal::U64(v) => Ok(v),
361+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
362+
})
363+
.collect::<Result<_, _>>()?,
364+
),
365+
Literal::F32(_) => Self::F32(
366+
components
367+
.iter()
368+
.map(|l| match l {
369+
&Literal::F32(v) => Ok(v),
370+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
371+
})
372+
.collect::<Result<_, _>>()?,
373+
),
374+
Literal::F64(_) => Self::F64(
375+
components
376+
.iter()
377+
.map(|l| match l {
378+
&Literal::F64(v) => Ok(v),
379+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
380+
})
381+
.collect::<Result<_, _>>()?,
382+
),
383+
Literal::Bool(_) => Self::Bool(
384+
components
385+
.iter()
386+
.map(|l| match l {
387+
&Literal::Bool(v) => Ok(v),
388+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
389+
})
390+
.collect::<Result<_, _>>()?,
391+
),
392+
Literal::AbstractInt(_) => Self::AbstractInt(
393+
components
394+
.iter()
395+
.map(|l| match l {
396+
&Literal::AbstractInt(v) => Ok(v),
397+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
398+
})
399+
.collect::<Result<_, _>>()?,
400+
),
401+
Literal::AbstractFloat(_) => Self::AbstractFloat(
402+
components
403+
.iter()
404+
.map(|l| match l {
405+
&Literal::AbstractFloat(v) => Ok(v),
406+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
407+
})
408+
.collect::<Result<_, _>>()?,
409+
),
410+
Literal::F16(_) => Self::F16(
411+
components
412+
.iter()
413+
.map(|l| match l {
414+
&Literal::F16(v) => Ok(v),
415+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
416+
})
417+
.collect::<Result<_, _>>()?,
418+
),
419+
})
420+
}
421+
422+
#[allow(dead_code)]
423+
/// Returns [`ArrayVec`] of [`Literal`]s
424+
fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
425+
match *self {
426+
LiteralVector::F64(ref v) => v.iter().map(|e| (Literal::F64(*e))).collect(),
427+
LiteralVector::F32(ref v) => v.iter().map(|e| (Literal::F32(*e))).collect(),
428+
LiteralVector::F16(ref v) => v.iter().map(|e| (Literal::F16(*e))).collect(),
429+
LiteralVector::U32(ref v) => v.iter().map(|e| (Literal::U32(*e))).collect(),
430+
LiteralVector::I32(ref v) => v.iter().map(|e| (Literal::I32(*e))).collect(),
431+
LiteralVector::U64(ref v) => v.iter().map(|e| (Literal::U64(*e))).collect(),
432+
LiteralVector::I64(ref v) => v.iter().map(|e| (Literal::I64(*e))).collect(),
433+
LiteralVector::Bool(ref v) => v.iter().map(|e| (Literal::Bool(*e))).collect(),
434+
LiteralVector::AbstractInt(ref v) => {
435+
v.iter().map(|e| (Literal::AbstractInt(*e))).collect()
436+
}
437+
LiteralVector::AbstractFloat(ref v) => {
438+
v.iter().map(|e| (Literal::AbstractFloat(*e))).collect()
439+
}
440+
}
441+
}
442+
443+
#[allow(dead_code)]
444+
/// Puts self into eval's expressions arena and returns handle to it
445+
fn register_as_evaluated_expr(
446+
&self,
447+
eval: &mut ConstantEvaluator<'_>,
448+
span: Span,
449+
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
450+
let lit_vec = self.to_literal_vec();
451+
assert!(!lit_vec.is_empty());
452+
let expr = if lit_vec.len() == 1 {
453+
Expression::Literal(lit_vec[0])
454+
} else {
455+
Expression::Compose {
456+
ty: eval.types.insert(
457+
Type {
458+
name: None,
459+
inner: TypeInner::Vector {
460+
size: match lit_vec.len() {
461+
2 => crate::VectorSize::Bi,
462+
3 => crate::VectorSize::Tri,
463+
4 => crate::VectorSize::Quad,
464+
_ => unreachable!(),
465+
},
466+
scalar: lit_vec[0].scalar(),
467+
},
468+
},
469+
Span::UNDEFINED,
470+
),
471+
components: lit_vec
472+
.iter()
473+
.map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
474+
.collect::<Result<_, _>>()?,
475+
}
476+
};
477+
eval.register_evaluated_expr(expr, span)
478+
}
479+
}
480+
481+
/// ```rust
482+
/// match_literal_vector!(match v => Literal {
483+
/// F16 => |v| {v.sum()},
484+
/// Integer => |v| {v.sum()},
485+
/// U32 => |v| -> I32 {v.sum()}, // optionally override return type
486+
/// })
487+
/// ```
488+
///
489+
/// ```rust
490+
/// match_literal_vector!(match (e1, e2) => LiteralVector {
491+
/// F16 => |e1, e2| {e1+e2},
492+
/// Integer => |e1, e2| {e1+e2},
493+
/// U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
494+
/// })
495+
/// ```
496+
///
497+
/// `Float` expands to `F16`, `F32`, `F64` and `AbstractFloat`.
498+
/// `Integer` expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
499+
///
500+
macro_rules! match_literal_vector {
501+
(match $lit_vec:expr => $out:ident {
502+
$(
503+
$ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
504+
),+
505+
$(,)?
506+
}) => {
507+
match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
508+
};
509+
510+
(@inner_start
511+
$lit_vec:expr;
512+
$out:ident;
513+
[$($ty:ident),+];
514+
[$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
515+
) => {
516+
match_literal_vector!(@inner
517+
$lit_vec;
518+
$out;
519+
[$($ty),+];
520+
[] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
521+
)
522+
};
523+
524+
(@inner
525+
$lit_vec:expr;
526+
$out:ident;
527+
[$ty:ident $(, $ty1:ident)*];
528+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
529+
[$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
530+
) => {
531+
match_literal_vector!(@inner
532+
$ty;
533+
$lit_vec;
534+
$out;
535+
[$($ty1),*];
536+
[$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
537+
[$({ $($var),+ ; $($ret)? ; $body }),+]
538+
)
539+
};
540+
(@inner
541+
Integer;
542+
$lit_vec:expr;
543+
$out:ident;
544+
[$($ty:ident),*];
545+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
546+
[
547+
{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }
548+
$(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
549+
]
550+
) => {
551+
match_literal_vector!(@inner
552+
$lit_vec;
553+
$out;
554+
[U32, I32, U64, I64, AbstractInt $(, $ty)*];
555+
[$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
556+
[
557+
{ $($var),+ ; $($ret)? ; $body }, // U32
558+
{ $($var),+ ; $($ret)? ; $body }, // I32
559+
{ $($var),+ ; $($ret)? ; $body }, // U64
560+
{ $($var),+ ; $($ret)? ; $body }, // I64
561+
{ $($var),+ ; $($ret)? ; $body } // AbstractInt
562+
$(,{ $($var1),+ ; $($ret1)? ; $body1 })*
563+
]
564+
)
565+
};
566+
(@inner
567+
Float;
568+
$lit_vec:expr;
569+
$out:ident;
570+
[$($ty:ident),*];
571+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
572+
[
573+
{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }
574+
$(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
575+
]
576+
) => {
577+
match_literal_vector!(@inner
578+
$lit_vec;
579+
$out;
580+
[F16, F32, F64, AbstractFloat $(, $ty)*];
581+
[$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
582+
[
583+
{ $($var),+ ; $($ret)? ; $body }, // F16
584+
{ $($var),+ ; $($ret)? ; $body }, // F32
585+
{ $($var),+ ; $($ret)? ; $body }, // F64
586+
{ $($var),+ ; $($ret)? ; $body } // AbstractFloat
587+
$(,{ $($var1),+ ; $($ret1)? ; $body1 })*
588+
]
589+
)
590+
};
591+
(@inner
592+
$ty:ident;
593+
$lit_vec:expr;
594+
$out:ident;
595+
[$ty1:ident $(,$ty2:ident)*];
596+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
597+
{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }
598+
$(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
599+
]
600+
) => {
601+
match_literal_vector!(@inner
602+
$ty1;
603+
$lit_vec;
604+
$out;
605+
[$($ty2),*];
606+
[
607+
$({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
608+
{ $ty; $($var),+ ; $($ret)? ; $body }
609+
] <>
610+
[$({ $($var1),+ ; $($ret1)? ; $body1 }),*]
611+
612+
)
613+
};
614+
(@inner
615+
$ty:ident;
616+
$lit_vec:expr;
617+
$out:ident;
618+
[];
619+
[$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
620+
[{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
621+
) => {
622+
match_literal_vector!(@inner_finish
623+
$lit_vec;
624+
$out;
625+
[
626+
$({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
627+
{ $ty; $($var),+ ; $($ret)? ; $body }
628+
]
629+
)
630+
};
631+
(@inner_finish
632+
$lit_vec:expr;
633+
$out:ident;
634+
[$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
635+
) => {
636+
match $lit_vec {
637+
$(
638+
#[allow(unused_parens)]
639+
($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
640+
)+
641+
_ => Err(ConstantEvaluatorError::InvalidMathArg),
642+
}
643+
};
644+
(@expand_ret $out:ident; $ty:ident; $body:expr) => {
645+
$out::$ty($body)
646+
};
647+
(@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
648+
$out::$ret($body)
649+
};
650+
}
651+
271652
#[derive(Debug)]
272653
enum Behavior<'a> {
273654
Wgsl(WgslRestrictions<'a>),

0 commit comments

Comments
 (0)