@@ -490,6 +490,43 @@ pub trait ViewExt: Viewable {
490
490
/// );
491
491
/// ```
492
492
fn range < R : Into < Range > > ( & self , dim : & str , range : R ) -> Result < View , ViewError > ;
493
+
494
+ /// Partition the view on `dim`. The returned iterator enumerates all partitions
495
+ /// as views in the extent of `dim` to the last dimension of the view.
496
+ ///
497
+ /// ## Examples
498
+ ///
499
+ /// ```
500
+ /// use ndslice::ViewExt;
501
+ /// use ndslice::extent;
502
+ ///
503
+ /// let ext = extent!(zone = 4, host = 2, gpu = 8);
504
+ ///
505
+ /// // We generate one view for each zone.
506
+ /// assert_eq!(ext.partition("host").unwrap().count(), 4);
507
+ ///
508
+ /// let mut parts = ext.partition("host").unwrap();
509
+ ///
510
+ /// let zone0 = parts.next().unwrap();
511
+ /// let mut zone0_points = zone0.iter();
512
+ /// assert_eq!(zone0.extent(), extent!(host = 2, gpu = 8));
513
+ /// assert_eq!(
514
+ /// zone0_points.next().unwrap(),
515
+ /// (extent!(host = 2, gpu = 8).point(vec![0, 0]).unwrap(), 0)
516
+ /// );
517
+ /// assert_eq!(
518
+ /// zone0_points.next().unwrap(),
519
+ /// (extent!(host = 2, gpu = 8).point(vec![0, 1]).unwrap(), 1)
520
+ /// );
521
+ ///
522
+ /// let zone1 = parts.next().unwrap();
523
+ /// assert_eq!(zone1.extent(), extent!(host = 2, gpu = 8));
524
+ /// assert_eq!(
525
+ /// zone1.iter().next().unwrap(),
526
+ /// (extent!(host = 2, gpu = 8).point(vec![0, 0]).unwrap(), 16)
527
+ /// );
528
+ /// ```
529
+ fn partition ( & self , dim : & str ) -> Result < impl Iterator < Item = View > , ViewError > ;
493
530
}
494
531
495
532
impl < T : Viewable > ViewExt for T {
@@ -520,6 +557,32 @@ impl<T: Viewable> ViewExt for T {
520
557
slice,
521
558
} )
522
559
}
560
+
561
+ fn partition ( & self , dim : & str ) -> Result < impl Iterator < Item = View > , ViewError > {
562
+ let dim = self
563
+ . labels ( )
564
+ . iter ( )
565
+ . position ( |l| dim == l)
566
+ . ok_or_else ( || ViewError :: InvalidDim ( dim. to_string ( ) ) ) ?;
567
+
568
+ let ( offset, sizes, strides) = self . slice ( ) . into_inner ( ) ;
569
+ let mut ranks = Slice :: new ( offset, sizes[ ..dim] . to_vec ( ) , strides[ ..dim] . to_vec ( ) )
570
+ . unwrap ( )
571
+ . iter ( ) ;
572
+
573
+ let labels = self . labels ( ) [ dim..] . to_vec ( ) ;
574
+ let sizes = sizes[ dim..] . to_vec ( ) ;
575
+ let strides = strides[ dim..] . to_vec ( ) ;
576
+
577
+ Ok ( std:: iter:: from_fn ( move || {
578
+ let rank = ranks. next ( ) ?;
579
+ let slice = Slice :: new ( rank, sizes. clone ( ) , strides. clone ( ) ) . unwrap ( ) ;
580
+ Some ( View {
581
+ labels : labels. clone ( ) ,
582
+ slice,
583
+ } )
584
+ } ) )
585
+ }
523
586
}
524
587
525
588
/// Construct a new extent with the given set of dimension-size pairs.
@@ -590,9 +653,10 @@ mod test {
590
653
591
654
macro_rules! assert_view {
592
655
( $view: expr, $extent: expr, $( $( $coord: expr) ,+ => $rank: expr ) ;* $( ; ) ?) => {
593
- assert_eq!( $view. extent( ) , $extent) ;
656
+ let view = $view;
657
+ assert_eq!( view. extent( ) , $extent) ;
594
658
let expected: Vec <_> = vec![ $( ( $extent. point( vec![ $( $coord) ,+] ) . unwrap( ) , $rank) ) ,* ] ;
595
- let actual: Vec <_> = $ view. iter( ) . collect( ) ;
659
+ let actual: Vec <_> = view. iter( ) . collect( ) ;
596
660
assert_eq!( actual, expected) ;
597
661
} ;
598
662
}
@@ -784,4 +848,38 @@ mod test {
784
848
vec![ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]
785
849
) ;
786
850
}
851
+
852
+ #[ test]
853
+ fn test_iter_subviews ( ) {
854
+ let extent = extent ! ( zone = 4 , host = 4 , gpu = 8 ) ;
855
+
856
+ assert_eq ! ( extent. partition( "gpu" ) . unwrap( ) . count( ) , 16 ) ;
857
+ assert_eq ! ( extent. partition( "zone" ) . unwrap( ) . count( ) , 1 ) ;
858
+
859
+ let mut parts = extent. partition ( "gpu" ) . unwrap ( ) ;
860
+ assert_view ! (
861
+ parts. next( ) . unwrap( ) ,
862
+ extent!( gpu = 8 ) ,
863
+ 0 => 0 ;
864
+ 1 => 1 ;
865
+ 2 => 2 ;
866
+ 3 => 3 ;
867
+ 4 => 4 ;
868
+ 5 => 5 ;
869
+ 6 => 6 ;
870
+ 7 => 7 ;
871
+ ) ;
872
+ assert_view ! (
873
+ parts. next( ) . unwrap( ) ,
874
+ extent!( gpu = 8 ) ,
875
+ 0 => 8 ;
876
+ 1 => 9 ;
877
+ 2 => 10 ;
878
+ 3 => 11 ;
879
+ 4 => 12 ;
880
+ 5 => 13 ;
881
+ 6 => 14 ;
882
+ 7 => 15 ;
883
+ ) ;
884
+ }
787
885
}
0 commit comments