@@ -268,6 +268,392 @@ gen_component_wise_extractor! {
268
268
] ,
269
269
}
270
270
271
+ /// Vectors with a concrete element type.
272
+ #[ derive( Debug ) ]
273
+ enum LiteralVector {
274
+ F64 ( ArrayVec < f64 , { crate :: VectorSize :: MAX } > ) ,
275
+ F32 ( ArrayVec < f32 , { crate :: VectorSize :: MAX } > ) ,
276
+ F16 ( ArrayVec < f16 , { crate :: VectorSize :: MAX } > ) ,
277
+ U32 ( ArrayVec < u32 , { crate :: VectorSize :: MAX } > ) ,
278
+ I32 ( ArrayVec < i32 , { crate :: VectorSize :: MAX } > ) ,
279
+ U64 ( ArrayVec < u64 , { crate :: VectorSize :: MAX } > ) ,
280
+ I64 ( ArrayVec < i64 , { crate :: VectorSize :: MAX } > ) ,
281
+ Bool ( ArrayVec < bool , { crate :: VectorSize :: MAX } > ) ,
282
+ AbstractInt ( ArrayVec < i64 , { crate :: VectorSize :: MAX } > ) ,
283
+ AbstractFloat ( ArrayVec < f64 , { crate :: VectorSize :: MAX } > ) ,
284
+ }
285
+
286
+ impl LiteralVector {
287
+ #[ allow( clippy:: missing_const_for_fn, reason = "MSRV" ) ]
288
+ fn len ( & self ) -> usize {
289
+ match * self {
290
+ LiteralVector :: F64 ( ref v) => v. len ( ) ,
291
+ LiteralVector :: F32 ( ref v) => v. len ( ) ,
292
+ LiteralVector :: F16 ( ref v) => v. len ( ) ,
293
+ LiteralVector :: U32 ( ref v) => v. len ( ) ,
294
+ LiteralVector :: I32 ( ref v) => v. len ( ) ,
295
+ LiteralVector :: U64 ( ref v) => v. len ( ) ,
296
+ LiteralVector :: I64 ( ref v) => v. len ( ) ,
297
+ LiteralVector :: Bool ( ref v) => v. len ( ) ,
298
+ LiteralVector :: AbstractInt ( ref v) => v. len ( ) ,
299
+ LiteralVector :: AbstractFloat ( ref v) => v. len ( ) ,
300
+ }
301
+ }
302
+
303
+ /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
304
+ fn from_literal ( literal : Literal ) -> Self {
305
+ match literal {
306
+ Literal :: F64 ( e) => Self :: F64 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
307
+ Literal :: F32 ( e) => Self :: F32 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
308
+ Literal :: U32 ( e) => Self :: U32 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
309
+ Literal :: I32 ( e) => Self :: I32 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
310
+ Literal :: U64 ( e) => Self :: U64 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
311
+ Literal :: I64 ( e) => Self :: I64 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
312
+ Literal :: Bool ( e) => Self :: Bool ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
313
+ Literal :: AbstractInt ( e) => Self :: AbstractInt ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
314
+ Literal :: AbstractFloat ( e) => Self :: AbstractFloat ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
315
+ Literal :: F16 ( e) => Self :: F16 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
316
+ }
317
+ }
318
+
319
+ /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
320
+ /// Returns error if components types do not match.
321
+ /// # Panics
322
+ /// Panics if vector is empty
323
+ fn from_literal_vec (
324
+ components : ArrayVec < Literal , { crate :: VectorSize :: MAX } > ,
325
+ ) -> Result < Self , ConstantEvaluatorError > {
326
+ assert ! ( !components. is_empty( ) ) ;
327
+ Ok ( match components[ 0 ] {
328
+ Literal :: I32 ( _) => Self :: I32 (
329
+ components
330
+ . iter ( )
331
+ . map ( |l| match l {
332
+ & Literal :: I32 ( v) => Ok ( v) ,
333
+ // TODO: should we handle abstract int here?
334
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
335
+ } )
336
+ . collect :: < Result < _ , _ > > ( ) ?,
337
+ ) ,
338
+ Literal :: U32 ( _) => Self :: U32 (
339
+ components
340
+ . iter ( )
341
+ . map ( |l| match l {
342
+ & Literal :: U32 ( v) => Ok ( v) ,
343
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
344
+ } )
345
+ . collect :: < Result < _ , _ > > ( ) ?,
346
+ ) ,
347
+ Literal :: I64 ( _) => Self :: I64 (
348
+ components
349
+ . iter ( )
350
+ . map ( |l| match l {
351
+ & Literal :: I64 ( v) => Ok ( v) ,
352
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
353
+ } )
354
+ . collect :: < Result < _ , _ > > ( ) ?,
355
+ ) ,
356
+ Literal :: U64 ( _) => Self :: U64 (
357
+ components
358
+ . iter ( )
359
+ . map ( |l| match l {
360
+ & Literal :: U64 ( v) => Ok ( v) ,
361
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
362
+ } )
363
+ . collect :: < Result < _ , _ > > ( ) ?,
364
+ ) ,
365
+ Literal :: F32 ( _) => Self :: F32 (
366
+ components
367
+ . iter ( )
368
+ . map ( |l| match l {
369
+ & Literal :: F32 ( v) => Ok ( v) ,
370
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
371
+ } )
372
+ . collect :: < Result < _ , _ > > ( ) ?,
373
+ ) ,
374
+ Literal :: F64 ( _) => Self :: F64 (
375
+ components
376
+ . iter ( )
377
+ . map ( |l| match l {
378
+ & Literal :: F64 ( v) => Ok ( v) ,
379
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
380
+ } )
381
+ . collect :: < Result < _ , _ > > ( ) ?,
382
+ ) ,
383
+ Literal :: Bool ( _) => Self :: Bool (
384
+ components
385
+ . iter ( )
386
+ . map ( |l| match l {
387
+ & Literal :: Bool ( v) => Ok ( v) ,
388
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
389
+ } )
390
+ . collect :: < Result < _ , _ > > ( ) ?,
391
+ ) ,
392
+ Literal :: AbstractInt ( _) => Self :: AbstractInt (
393
+ components
394
+ . iter ( )
395
+ . map ( |l| match l {
396
+ & Literal :: AbstractInt ( v) => Ok ( v) ,
397
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
398
+ } )
399
+ . collect :: < Result < _ , _ > > ( ) ?,
400
+ ) ,
401
+ Literal :: AbstractFloat ( _) => Self :: AbstractFloat (
402
+ components
403
+ . iter ( )
404
+ . map ( |l| match l {
405
+ & Literal :: AbstractFloat ( v) => Ok ( v) ,
406
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
407
+ } )
408
+ . collect :: < Result < _ , _ > > ( ) ?,
409
+ ) ,
410
+ Literal :: F16 ( _) => Self :: F16 (
411
+ components
412
+ . iter ( )
413
+ . map ( |l| match l {
414
+ & Literal :: F16 ( v) => Ok ( v) ,
415
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
416
+ } )
417
+ . collect :: < Result < _ , _ > > ( ) ?,
418
+ ) ,
419
+ } )
420
+ }
421
+
422
+ #[ allow( dead_code) ]
423
+ /// Returns [`ArrayVec`] of [`Literal`]s
424
+ fn to_literal_vec ( & self ) -> ArrayVec < Literal , { crate :: VectorSize :: MAX } > {
425
+ match * self {
426
+ LiteralVector :: F64 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: F64 ( * e) ) ) . collect ( ) ,
427
+ LiteralVector :: F32 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: F32 ( * e) ) ) . collect ( ) ,
428
+ LiteralVector :: F16 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: F16 ( * e) ) ) . collect ( ) ,
429
+ LiteralVector :: U32 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: U32 ( * e) ) ) . collect ( ) ,
430
+ LiteralVector :: I32 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: I32 ( * e) ) ) . collect ( ) ,
431
+ LiteralVector :: U64 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: U64 ( * e) ) ) . collect ( ) ,
432
+ LiteralVector :: I64 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: I64 ( * e) ) ) . collect ( ) ,
433
+ LiteralVector :: Bool ( ref v) => v. iter ( ) . map ( |e| ( Literal :: Bool ( * e) ) ) . collect ( ) ,
434
+ LiteralVector :: AbstractInt ( ref v) => {
435
+ v. iter ( ) . map ( |e| ( Literal :: AbstractInt ( * e) ) ) . collect ( )
436
+ }
437
+ LiteralVector :: AbstractFloat ( ref v) => {
438
+ v. iter ( ) . map ( |e| ( Literal :: AbstractFloat ( * e) ) ) . collect ( )
439
+ }
440
+ }
441
+ }
442
+
443
+ #[ allow( dead_code) ]
444
+ /// Puts self into eval's expressions arena and returns handle to it
445
+ fn register_as_evaluated_expr (
446
+ & self ,
447
+ eval : & mut ConstantEvaluator < ' _ > ,
448
+ span : Span ,
449
+ ) -> Result < Handle < Expression > , ConstantEvaluatorError > {
450
+ let lit_vec = self . to_literal_vec ( ) ;
451
+ assert ! ( !lit_vec. is_empty( ) ) ;
452
+ let expr = if lit_vec. len ( ) == 1 {
453
+ Expression :: Literal ( lit_vec[ 0 ] )
454
+ } else {
455
+ Expression :: Compose {
456
+ ty : eval. types . insert (
457
+ Type {
458
+ name : None ,
459
+ inner : TypeInner :: Vector {
460
+ size : match lit_vec. len ( ) {
461
+ 2 => crate :: VectorSize :: Bi ,
462
+ 3 => crate :: VectorSize :: Tri ,
463
+ 4 => crate :: VectorSize :: Quad ,
464
+ _ => unreachable ! ( ) ,
465
+ } ,
466
+ scalar : lit_vec[ 0 ] . scalar ( ) ,
467
+ } ,
468
+ } ,
469
+ Span :: UNDEFINED ,
470
+ ) ,
471
+ components : lit_vec
472
+ . iter ( )
473
+ . map ( |& l| eval. register_evaluated_expr ( Expression :: Literal ( l) , span) )
474
+ . collect :: < Result < _ , _ > > ( ) ?,
475
+ }
476
+ } ;
477
+ eval. register_evaluated_expr ( expr, span)
478
+ }
479
+ }
480
+
481
+ /// A macro for matching on [`LiteralVector`] variants.
482
+ ///
483
+ /// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`.
484
+ /// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
485
+ ///
486
+ /// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported.
487
+ ///
488
+ /// Example usage:
489
+ ///
490
+ /// ```rust,ignore
491
+ /// match_literal_vector!(match v => Literal {
492
+ /// F16 => |v| {v.sum()},
493
+ /// Integer => |v| {v.sum()},
494
+ /// U32 => |v| -> I32 {v.sum()}, // optionally override return type
495
+ /// })
496
+ /// ```
497
+ ///
498
+ /// ```rust,ignore
499
+ /// match_literal_vector!(match (e1, e2) => LiteralVector {
500
+ /// F16 => |e1, e2| {e1+e2},
501
+ /// Integer => |e1, e2| {e1+e2},
502
+ /// U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
503
+ /// })
504
+ /// ```
505
+ macro_rules! match_literal_vector {
506
+ ( match $lit_vec: expr => $out: ident {
507
+ $(
508
+ $ty: ident => |$( $var: ident) ,+| $( -> $ret: ident) ? { $body: expr }
509
+ ) ,+
510
+ $( , ) ?
511
+ } ) => {
512
+ match_literal_vector!( @inner_start $lit_vec; $out; [ $( $ty) ,+] ; [ $( { $( $var) ,+ ; $( $ret) ? ; $body } ) ,+] )
513
+ } ;
514
+
515
+ ( @inner_start
516
+ $lit_vec: expr;
517
+ $out: ident;
518
+ [ $( $ty: ident) ,+] ;
519
+ [ $( { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr } ) ,+]
520
+ ) => {
521
+ match_literal_vector!( @inner
522
+ $lit_vec;
523
+ $out;
524
+ [ $( $ty) ,+] ;
525
+ [ ] <> [ $( { $( $var) ,+ ; $( $ret) ? ; $body } ) ,+]
526
+ )
527
+ } ;
528
+
529
+ ( @inner
530
+ $lit_vec: expr;
531
+ $out: ident;
532
+ [ $ty: ident $( , $ty1: ident) * ] ;
533
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <>
534
+ [ $( { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr } ) ,+]
535
+ ) => {
536
+ match_literal_vector!( @inner
537
+ $ty;
538
+ $lit_vec;
539
+ $out;
540
+ [ $( $ty1) ,* ] ;
541
+ [ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body} ) ,* ] <>
542
+ [ $( { $( $var) ,+ ; $( $ret) ? ; $body } ) ,+]
543
+ )
544
+ } ;
545
+ ( @inner
546
+ Integer ;
547
+ $lit_vec: expr;
548
+ $out: ident;
549
+ [ $( $ty: ident) ,* ] ;
550
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <>
551
+ [
552
+ { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr }
553
+ $( , { $( $var1: ident) ,+ ; $( $ret1: ident) ? ; $body1: expr } ) *
554
+ ]
555
+ ) => {
556
+ match_literal_vector!( @inner
557
+ $lit_vec;
558
+ $out;
559
+ [ U32 , I32 , U64 , I64 , AbstractInt $( , $ty) * ] ;
560
+ [ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body} ) ,* ] <>
561
+ [
562
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // U32
563
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // I32
564
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // U64
565
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // I64
566
+ { $( $var) ,+ ; $( $ret) ? ; $body } // AbstractInt
567
+ $( , { $( $var1) ,+ ; $( $ret1) ? ; $body1 } ) *
568
+ ]
569
+ )
570
+ } ;
571
+ ( @inner
572
+ Float ;
573
+ $lit_vec: expr;
574
+ $out: ident;
575
+ [ $( $ty: ident) ,* ] ;
576
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <>
577
+ [
578
+ { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr }
579
+ $( , { $( $var1: ident) ,+ ; $( $ret1: ident) ? ; $body1: expr } ) *
580
+ ]
581
+ ) => {
582
+ match_literal_vector!( @inner
583
+ $lit_vec;
584
+ $out;
585
+ [ F16 , F32 , F64 , AbstractFloat $( , $ty) * ] ;
586
+ [ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body} ) ,* ] <>
587
+ [
588
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // F16
589
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // F32
590
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // F64
591
+ { $( $var) ,+ ; $( $ret) ? ; $body } // AbstractFloat
592
+ $( , { $( $var1) ,+ ; $( $ret1) ? ; $body1 } ) *
593
+ ]
594
+ )
595
+ } ;
596
+ ( @inner
597
+ $ty: ident;
598
+ $lit_vec: expr;
599
+ $out: ident;
600
+ [ $ty1: ident $( , $ty2: ident) * ] ;
601
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <> [
602
+ { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr }
603
+ $( , { $( $var1: ident) ,+ ; $( $ret1: ident) ? ; $body1: expr } ) *
604
+ ]
605
+ ) => {
606
+ match_literal_vector!( @inner
607
+ $ty1;
608
+ $lit_vec;
609
+ $out;
610
+ [ $( $ty2) ,* ] ;
611
+ [
612
+ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body} , ) *
613
+ { $ty; $( $var) ,+ ; $( $ret) ? ; $body }
614
+ ] <>
615
+ [ $( { $( $var1) ,+ ; $( $ret1) ? ; $body1 } ) ,* ]
616
+
617
+ )
618
+ } ;
619
+ ( @inner
620
+ $ty: ident;
621
+ $lit_vec: expr;
622
+ $out: ident;
623
+ [ ] ;
624
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <>
625
+ [ { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr } ]
626
+ ) => {
627
+ match_literal_vector!( @inner_finish
628
+ $lit_vec;
629
+ $out;
630
+ [
631
+ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body } , ) *
632
+ { $ty; $( $var) ,+ ; $( $ret) ? ; $body }
633
+ ]
634
+ )
635
+ } ;
636
+ ( @inner_finish
637
+ $lit_vec: expr;
638
+ $out: ident;
639
+ [ $( { $ty: ident ; $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr} ) ,+]
640
+ ) => {
641
+ match $lit_vec {
642
+ $(
643
+ #[ allow( unused_parens) ]
644
+ ( $( LiteralVector :: $ty( ref $var) ) ,+) => { Ok ( match_literal_vector!( @expand_ret $out; $ty $( ; $ret) ? ; $body) ) }
645
+ ) +
646
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
647
+ }
648
+ } ;
649
+ ( @expand_ret $out: ident; $ty: ident; $body: expr) => {
650
+ $out:: $ty( $body)
651
+ } ;
652
+ ( @expand_ret $out: ident; $_ty: ident; $ret: ident; $body: expr) => {
653
+ $out:: $ret( $body)
654
+ } ;
655
+ }
656
+
271
657
#[ derive( Debug ) ]
272
658
enum Behavior < ' a > {
273
659
Wgsl ( WgslRestrictions < ' a > ) ,
0 commit comments