Skip to content

Commit c750074

Browse files
mariusaefacebook-github-bot
authored andcommitted
ndslice: ViewExt::partition (#718)
Summary: Pull Request resolved: #718 This provides a way to partition a view on a dimension. The operation returns all resulting partitions in the reduced (suffix) extent. This will be used in the host allocator, where we delegate subspaces (e.g., gpus) to each host. Reviewed By: shayne-fletcher Differential Revision: D79404499 fbshipit-source-id: 451dc6e6b9b51a176cb573493bfe3fbd40d69bae
1 parent 982b413 commit c750074

File tree

1 file changed

+100
-2
lines changed

1 file changed

+100
-2
lines changed

ndslice/src/view.rs

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,43 @@ pub trait ViewExt: Viewable {
490490
/// );
491491
/// ```
492492
fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<View, ViewError>;
493+
494+
/// Partition the view on `dim`. The returned iterator enumerates all partitions
495+
/// as views in the extent of `dim` to the last dimension of the view.
496+
///
497+
/// ## Examples
498+
///
499+
/// ```
500+
/// use ndslice::ViewExt;
501+
/// use ndslice::extent;
502+
///
503+
/// let ext = extent!(zone = 4, host = 2, gpu = 8);
504+
///
505+
/// // We generate one view for each zone.
506+
/// assert_eq!(ext.partition("host").unwrap().count(), 4);
507+
///
508+
/// let mut parts = ext.partition("host").unwrap();
509+
///
510+
/// let zone0 = parts.next().unwrap();
511+
/// let mut zone0_points = zone0.iter();
512+
/// assert_eq!(zone0.extent(), extent!(host = 2, gpu = 8));
513+
/// assert_eq!(
514+
/// zone0_points.next().unwrap(),
515+
/// (extent!(host = 2, gpu = 8).point(vec![0, 0]).unwrap(), 0)
516+
/// );
517+
/// assert_eq!(
518+
/// zone0_points.next().unwrap(),
519+
/// (extent!(host = 2, gpu = 8).point(vec![0, 1]).unwrap(), 1)
520+
/// );
521+
///
522+
/// let zone1 = parts.next().unwrap();
523+
/// assert_eq!(zone1.extent(), extent!(host = 2, gpu = 8));
524+
/// assert_eq!(
525+
/// zone1.iter().next().unwrap(),
526+
/// (extent!(host = 2, gpu = 8).point(vec![0, 0]).unwrap(), 16)
527+
/// );
528+
/// ```
529+
fn partition(&self, dim: &str) -> Result<impl Iterator<Item = View>, ViewError>;
493530
}
494531

495532
impl<T: Viewable> ViewExt for T {
@@ -520,6 +557,32 @@ impl<T: Viewable> ViewExt for T {
520557
slice,
521558
})
522559
}
560+
561+
fn partition(&self, dim: &str) -> Result<impl Iterator<Item = View>, ViewError> {
562+
let dim = self
563+
.labels()
564+
.iter()
565+
.position(|l| dim == l)
566+
.ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
567+
568+
let (offset, sizes, strides) = self.slice().into_inner();
569+
let mut ranks = Slice::new(offset, sizes[..dim].to_vec(), strides[..dim].to_vec())
570+
.unwrap()
571+
.iter();
572+
573+
let labels = self.labels()[dim..].to_vec();
574+
let sizes = sizes[dim..].to_vec();
575+
let strides = strides[dim..].to_vec();
576+
577+
Ok(std::iter::from_fn(move || {
578+
let rank = ranks.next()?;
579+
let slice = Slice::new(rank, sizes.clone(), strides.clone()).unwrap();
580+
Some(View {
581+
labels: labels.clone(),
582+
slice,
583+
})
584+
}))
585+
}
523586
}
524587

525588
/// Construct a new extent with the given set of dimension-size pairs.
@@ -590,9 +653,10 @@ mod test {
590653

591654
macro_rules! assert_view {
592655
($view:expr, $extent:expr, $( $($coord:expr),+ => $rank:expr );* $(;)?) => {
593-
assert_eq!($view.extent(), $extent);
656+
let view = $view;
657+
assert_eq!(view.extent(), $extent);
594658
let expected: Vec<_> = vec![$(($extent.point(vec![$($coord),+]).unwrap(), $rank)),*];
595-
let actual: Vec<_> = $view.iter().collect();
659+
let actual: Vec<_> = view.iter().collect();
596660
assert_eq!(actual, expected);
597661
};
598662
}
@@ -784,4 +848,38 @@ mod test {
784848
vec![0, 1, 2, 3, 4, 5, 6, 7]
785849
);
786850
}
851+
852+
#[test]
853+
fn test_iter_subviews() {
854+
let extent = extent!(zone = 4, host = 4, gpu = 8);
855+
856+
assert_eq!(extent.partition("gpu").unwrap().count(), 16);
857+
assert_eq!(extent.partition("zone").unwrap().count(), 1);
858+
859+
let mut parts = extent.partition("gpu").unwrap();
860+
assert_view!(
861+
parts.next().unwrap(),
862+
extent!(gpu = 8),
863+
0 => 0;
864+
1 => 1;
865+
2 => 2;
866+
3 => 3;
867+
4 => 4;
868+
5 => 5;
869+
6 => 6;
870+
7 => 7;
871+
);
872+
assert_view!(
873+
parts.next().unwrap(),
874+
extent!(gpu = 8),
875+
0 => 8;
876+
1 => 9;
877+
2 => 10;
878+
3 => 11;
879+
4 => 12;
880+
5 => 13;
881+
6 => 14;
882+
7 => 15;
883+
);
884+
}
787885
}

0 commit comments

Comments
 (0)