@@ -268,6 +268,387 @@ gen_component_wise_extractor! {
268
268
] ,
269
269
}
270
270
271
+ /// Vectors with a concrete element type.
272
+ #[ derive( Debug ) ]
273
+ enum LiteralVector {
274
+ F64 ( ArrayVec < f64 , { crate :: VectorSize :: MAX } > ) ,
275
+ F32 ( ArrayVec < f32 , { crate :: VectorSize :: MAX } > ) ,
276
+ F16 ( ArrayVec < f16 , { crate :: VectorSize :: MAX } > ) ,
277
+ U32 ( ArrayVec < u32 , { crate :: VectorSize :: MAX } > ) ,
278
+ I32 ( ArrayVec < i32 , { crate :: VectorSize :: MAX } > ) ,
279
+ U64 ( ArrayVec < u64 , { crate :: VectorSize :: MAX } > ) ,
280
+ I64 ( ArrayVec < i64 , { crate :: VectorSize :: MAX } > ) ,
281
+ Bool ( ArrayVec < bool , { crate :: VectorSize :: MAX } > ) ,
282
+ AbstractInt ( ArrayVec < i64 , { crate :: VectorSize :: MAX } > ) ,
283
+ AbstractFloat ( ArrayVec < f64 , { crate :: VectorSize :: MAX } > ) ,
284
+ }
285
+
286
+ impl LiteralVector {
287
+ #[ allow( clippy:: missing_const_for_fn, reason = "MSRV" ) ]
288
+ fn len ( & self ) -> usize {
289
+ match * self {
290
+ LiteralVector :: F64 ( ref v) => v. len ( ) ,
291
+ LiteralVector :: F32 ( ref v) => v. len ( ) ,
292
+ LiteralVector :: F16 ( ref v) => v. len ( ) ,
293
+ LiteralVector :: U32 ( ref v) => v. len ( ) ,
294
+ LiteralVector :: I32 ( ref v) => v. len ( ) ,
295
+ LiteralVector :: U64 ( ref v) => v. len ( ) ,
296
+ LiteralVector :: I64 ( ref v) => v. len ( ) ,
297
+ LiteralVector :: Bool ( ref v) => v. len ( ) ,
298
+ LiteralVector :: AbstractInt ( ref v) => v. len ( ) ,
299
+ LiteralVector :: AbstractFloat ( ref v) => v. len ( ) ,
300
+ }
301
+ }
302
+
303
+ /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
304
+ fn from_literal ( literal : Literal ) -> Self {
305
+ match literal {
306
+ Literal :: F64 ( e) => Self :: F64 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
307
+ Literal :: F32 ( e) => Self :: F32 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
308
+ Literal :: U32 ( e) => Self :: U32 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
309
+ Literal :: I32 ( e) => Self :: I32 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
310
+ Literal :: U64 ( e) => Self :: U64 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
311
+ Literal :: I64 ( e) => Self :: I64 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
312
+ Literal :: Bool ( e) => Self :: Bool ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
313
+ Literal :: AbstractInt ( e) => Self :: AbstractInt ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
314
+ Literal :: AbstractFloat ( e) => Self :: AbstractFloat ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
315
+ Literal :: F16 ( e) => Self :: F16 ( ArrayVec :: from_iter ( iter:: once ( e) ) ) ,
316
+ }
317
+ }
318
+
319
+ /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
320
+ /// Returns error if components types do not match.
321
+ /// # Panics
322
+ /// Panics if vector is empty
323
+ fn from_literal_vec (
324
+ components : ArrayVec < Literal , { crate :: VectorSize :: MAX } > ,
325
+ ) -> Result < Self , ConstantEvaluatorError > {
326
+ assert ! ( !components. is_empty( ) ) ;
327
+ Ok ( match components[ 0 ] {
328
+ Literal :: I32 ( _) => Self :: I32 (
329
+ components
330
+ . iter ( )
331
+ . map ( |l| match l {
332
+ & Literal :: I32 ( v) => Ok ( v) ,
333
+ // TODO: should we handle abstract int here?
334
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
335
+ } )
336
+ . collect :: < Result < _ , _ > > ( ) ?,
337
+ ) ,
338
+ Literal :: U32 ( _) => Self :: U32 (
339
+ components
340
+ . iter ( )
341
+ . map ( |l| match l {
342
+ & Literal :: U32 ( v) => Ok ( v) ,
343
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
344
+ } )
345
+ . collect :: < Result < _ , _ > > ( ) ?,
346
+ ) ,
347
+ Literal :: I64 ( _) => Self :: I64 (
348
+ components
349
+ . iter ( )
350
+ . map ( |l| match l {
351
+ & Literal :: I64 ( v) => Ok ( v) ,
352
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
353
+ } )
354
+ . collect :: < Result < _ , _ > > ( ) ?,
355
+ ) ,
356
+ Literal :: U64 ( _) => Self :: U64 (
357
+ components
358
+ . iter ( )
359
+ . map ( |l| match l {
360
+ & Literal :: U64 ( v) => Ok ( v) ,
361
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
362
+ } )
363
+ . collect :: < Result < _ , _ > > ( ) ?,
364
+ ) ,
365
+ Literal :: F32 ( _) => Self :: F32 (
366
+ components
367
+ . iter ( )
368
+ . map ( |l| match l {
369
+ & Literal :: F32 ( v) => Ok ( v) ,
370
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
371
+ } )
372
+ . collect :: < Result < _ , _ > > ( ) ?,
373
+ ) ,
374
+ Literal :: F64 ( _) => Self :: F64 (
375
+ components
376
+ . iter ( )
377
+ . map ( |l| match l {
378
+ & Literal :: F64 ( v) => Ok ( v) ,
379
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
380
+ } )
381
+ . collect :: < Result < _ , _ > > ( ) ?,
382
+ ) ,
383
+ Literal :: Bool ( _) => Self :: Bool (
384
+ components
385
+ . iter ( )
386
+ . map ( |l| match l {
387
+ & Literal :: Bool ( v) => Ok ( v) ,
388
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
389
+ } )
390
+ . collect :: < Result < _ , _ > > ( ) ?,
391
+ ) ,
392
+ Literal :: AbstractInt ( _) => Self :: AbstractInt (
393
+ components
394
+ . iter ( )
395
+ . map ( |l| match l {
396
+ & Literal :: AbstractInt ( v) => Ok ( v) ,
397
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
398
+ } )
399
+ . collect :: < Result < _ , _ > > ( ) ?,
400
+ ) ,
401
+ Literal :: AbstractFloat ( _) => Self :: AbstractFloat (
402
+ components
403
+ . iter ( )
404
+ . map ( |l| match l {
405
+ & Literal :: AbstractFloat ( v) => Ok ( v) ,
406
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
407
+ } )
408
+ . collect :: < Result < _ , _ > > ( ) ?,
409
+ ) ,
410
+ Literal :: F16 ( _) => Self :: F16 (
411
+ components
412
+ . iter ( )
413
+ . map ( |l| match l {
414
+ & Literal :: F16 ( v) => Ok ( v) ,
415
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
416
+ } )
417
+ . collect :: < Result < _ , _ > > ( ) ?,
418
+ ) ,
419
+ } )
420
+ }
421
+
422
+ #[ allow( dead_code) ]
423
+ /// Returns [`ArrayVec`] of [`Literal`]s
424
+ fn to_literal_vec ( & self ) -> ArrayVec < Literal , { crate :: VectorSize :: MAX } > {
425
+ match * self {
426
+ LiteralVector :: F64 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: F64 ( * e) ) ) . collect ( ) ,
427
+ LiteralVector :: F32 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: F32 ( * e) ) ) . collect ( ) ,
428
+ LiteralVector :: F16 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: F16 ( * e) ) ) . collect ( ) ,
429
+ LiteralVector :: U32 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: U32 ( * e) ) ) . collect ( ) ,
430
+ LiteralVector :: I32 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: I32 ( * e) ) ) . collect ( ) ,
431
+ LiteralVector :: U64 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: U64 ( * e) ) ) . collect ( ) ,
432
+ LiteralVector :: I64 ( ref v) => v. iter ( ) . map ( |e| ( Literal :: I64 ( * e) ) ) . collect ( ) ,
433
+ LiteralVector :: Bool ( ref v) => v. iter ( ) . map ( |e| ( Literal :: Bool ( * e) ) ) . collect ( ) ,
434
+ LiteralVector :: AbstractInt ( ref v) => {
435
+ v. iter ( ) . map ( |e| ( Literal :: AbstractInt ( * e) ) ) . collect ( )
436
+ }
437
+ LiteralVector :: AbstractFloat ( ref v) => {
438
+ v. iter ( ) . map ( |e| ( Literal :: AbstractFloat ( * e) ) ) . collect ( )
439
+ }
440
+ }
441
+ }
442
+
443
+ #[ allow( dead_code) ]
444
+ /// Puts self into eval's expressions arena and returns handle to it
445
+ fn register_as_evaluated_expr (
446
+ & self ,
447
+ eval : & mut ConstantEvaluator < ' _ > ,
448
+ span : Span ,
449
+ ) -> Result < Handle < Expression > , ConstantEvaluatorError > {
450
+ let lit_vec = self . to_literal_vec ( ) ;
451
+ assert ! ( !lit_vec. is_empty( ) ) ;
452
+ let expr = if lit_vec. len ( ) == 1 {
453
+ Expression :: Literal ( lit_vec[ 0 ] )
454
+ } else {
455
+ Expression :: Compose {
456
+ ty : eval. types . insert (
457
+ Type {
458
+ name : None ,
459
+ inner : TypeInner :: Vector {
460
+ size : match lit_vec. len ( ) {
461
+ 2 => crate :: VectorSize :: Bi ,
462
+ 3 => crate :: VectorSize :: Tri ,
463
+ 4 => crate :: VectorSize :: Quad ,
464
+ _ => unreachable ! ( ) ,
465
+ } ,
466
+ scalar : lit_vec[ 0 ] . scalar ( ) ,
467
+ } ,
468
+ } ,
469
+ Span :: UNDEFINED ,
470
+ ) ,
471
+ components : lit_vec
472
+ . iter ( )
473
+ . map ( |& l| eval. register_evaluated_expr ( Expression :: Literal ( l) , span) )
474
+ . collect :: < Result < _ , _ > > ( ) ?,
475
+ }
476
+ } ;
477
+ eval. register_evaluated_expr ( expr, span)
478
+ }
479
+ }
480
+
481
+ /// ```rust
482
+ /// match_literal_vector!(match v => Literal {
483
+ /// F16 => |v| {v.sum()},
484
+ /// Integer => |v| {v.sum()},
485
+ /// U32 => |v| -> I32 {v.sum()}, // optionally override return type
486
+ /// })
487
+ /// ```
488
+ ///
489
+ /// ```rust
490
+ /// match_literal_vector!(match (e1, e2) => LiteralVector {
491
+ /// F16 => |e1, e2| {e1+e2},
492
+ /// Integer => |e1, e2| {e1+e2},
493
+ /// U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
494
+ /// })
495
+ /// ```
496
+ ///
497
+ /// `Float` expands to `F16`, `F32`, `F64` and `AbstractFloat`.
498
+ /// `Integer` expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
499
+ ///
500
+ macro_rules! match_literal_vector {
501
+ ( match $lit_vec: expr => $out: ident {
502
+ $(
503
+ $ty: ident => |$( $var: ident) ,+| $( -> $ret: ident) ? { $body: expr }
504
+ ) ,+
505
+ $( , ) ?
506
+ } ) => {
507
+ match_literal_vector!( @inner_start $lit_vec; $out; [ $( $ty) ,+] ; [ $( { $( $var) ,+ ; $( $ret) ? ; $body } ) ,+] )
508
+ } ;
509
+
510
+ ( @inner_start
511
+ $lit_vec: expr;
512
+ $out: ident;
513
+ [ $( $ty: ident) ,+] ;
514
+ [ $( { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr } ) ,+]
515
+ ) => {
516
+ match_literal_vector!( @inner
517
+ $lit_vec;
518
+ $out;
519
+ [ $( $ty) ,+] ;
520
+ [ ] <> [ $( { $( $var) ,+ ; $( $ret) ? ; $body } ) ,+]
521
+ )
522
+ } ;
523
+
524
+ ( @inner
525
+ $lit_vec: expr;
526
+ $out: ident;
527
+ [ $ty: ident $( , $ty1: ident) * ] ;
528
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <>
529
+ [ $( { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr } ) ,+]
530
+ ) => {
531
+ match_literal_vector!( @inner
532
+ $ty;
533
+ $lit_vec;
534
+ $out;
535
+ [ $( $ty1) ,* ] ;
536
+ [ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body} ) ,* ] <>
537
+ [ $( { $( $var) ,+ ; $( $ret) ? ; $body } ) ,+]
538
+ )
539
+ } ;
540
+ ( @inner
541
+ Integer ;
542
+ $lit_vec: expr;
543
+ $out: ident;
544
+ [ $( $ty: ident) ,* ] ;
545
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <>
546
+ [
547
+ { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr }
548
+ $( , { $( $var1: ident) ,+ ; $( $ret1: ident) ? ; $body1: expr } ) *
549
+ ]
550
+ ) => {
551
+ match_literal_vector!( @inner
552
+ $lit_vec;
553
+ $out;
554
+ [ U32 , I32 , U64 , I64 , AbstractInt $( , $ty) * ] ;
555
+ [ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body} ) ,* ] <>
556
+ [
557
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // U32
558
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // I32
559
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // U64
560
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // I64
561
+ { $( $var) ,+ ; $( $ret) ? ; $body } // AbstractInt
562
+ $( , { $( $var1) ,+ ; $( $ret1) ? ; $body1 } ) *
563
+ ]
564
+ )
565
+ } ;
566
+ ( @inner
567
+ Float ;
568
+ $lit_vec: expr;
569
+ $out: ident;
570
+ [ $( $ty: ident) ,* ] ;
571
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <>
572
+ [
573
+ { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr }
574
+ $( , { $( $var1: ident) ,+ ; $( $ret1: ident) ? ; $body1: expr } ) *
575
+ ]
576
+ ) => {
577
+ match_literal_vector!( @inner
578
+ $lit_vec;
579
+ $out;
580
+ [ F16 , F32 , F64 , AbstractFloat $( , $ty) * ] ;
581
+ [ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body} ) ,* ] <>
582
+ [
583
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // F16
584
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // F32
585
+ { $( $var) ,+ ; $( $ret) ? ; $body } , // F64
586
+ { $( $var) ,+ ; $( $ret) ? ; $body } // AbstractFloat
587
+ $( , { $( $var1) ,+ ; $( $ret1) ? ; $body1 } ) *
588
+ ]
589
+ )
590
+ } ;
591
+ ( @inner
592
+ $ty: ident;
593
+ $lit_vec: expr;
594
+ $out: ident;
595
+ [ $ty1: ident $( , $ty2: ident) * ] ;
596
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <> [
597
+ { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr }
598
+ $( , { $( $var1: ident) ,+ ; $( $ret1: ident) ? ; $body1: expr } ) *
599
+ ]
600
+ ) => {
601
+ match_literal_vector!( @inner
602
+ $ty1;
603
+ $lit_vec;
604
+ $out;
605
+ [ $( $ty2) ,* ] ;
606
+ [
607
+ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body} , ) *
608
+ { $ty; $( $var) ,+ ; $( $ret) ? ; $body }
609
+ ] <>
610
+ [ $( { $( $var1) ,+ ; $( $ret1) ? ; $body1 } ) ,* ]
611
+
612
+ )
613
+ } ;
614
+ ( @inner
615
+ $ty: ident;
616
+ $lit_vec: expr;
617
+ $out: ident;
618
+ [ ] ;
619
+ [ $( { $_ty: ident ; $( $_var: ident) ,+ ; $( $_ret: ident) ? ; $_body: expr} ) ,* ] <>
620
+ [ { $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr } ]
621
+ ) => {
622
+ match_literal_vector!( @inner_finish
623
+ $lit_vec;
624
+ $out;
625
+ [
626
+ $( { $_ty ; $( $_var) ,+ ; $( $_ret) ? ; $_body } , ) *
627
+ { $ty; $( $var) ,+ ; $( $ret) ? ; $body }
628
+ ]
629
+ )
630
+ } ;
631
+ ( @inner_finish
632
+ $lit_vec: expr;
633
+ $out: ident;
634
+ [ $( { $ty: ident ; $( $var: ident) ,+ ; $( $ret: ident) ? ; $body: expr} ) ,+]
635
+ ) => {
636
+ match $lit_vec {
637
+ $(
638
+ #[ allow( unused_parens) ]
639
+ ( $( LiteralVector :: $ty( ref $var) ) ,+) => { Ok ( match_literal_vector!( @expand_ret $out; $ty $( ; $ret) ? ; $body) ) }
640
+ ) +
641
+ _ => Err ( ConstantEvaluatorError :: InvalidMathArg ) ,
642
+ }
643
+ } ;
644
+ ( @expand_ret $out: ident; $ty: ident; $body: expr) => {
645
+ $out:: $ty( $body)
646
+ } ;
647
+ ( @expand_ret $out: ident; $_ty: ident; $ret: ident; $body: expr) => {
648
+ $out:: $ret( $body)
649
+ } ;
650
+ }
651
+
271
652
#[ derive( Debug ) ]
272
653
enum Behavior < ' a > {
273
654
Wgsl ( WgslRestrictions < ' a > ) ,
0 commit comments