Skip to content

Commit 24a25c6

Browse files
shuhuayumeta-codesync[bot]
authored andcommitted
Add Extent::concat (meta-pytorch#1419)
Summary: Pull Request resolved: meta-pytorch#1419 1. Addressing Github issue: https://fburl.com/ns9ekdb7. Add Extent concatenation method `Extent::concat` with duplicate label validation. Returns `ExtentError::OverlappingLabel` when duplicate labels are detected. 2. Add tests for several cases of concatenation, including cases with overlapping labels. 3. Simplify an manual concatenation in `host_mesh.rs` by calling 'Extent::concat`. ghstack-source-id: 314332973 exported-using-ghexport Reviewed By: shayne-fletcher Differential Revision: D83869447 fbshipit-source-id: 1dd83b5f8705287e2e3b659eee815c967524b9c8
1 parent bf58a9b commit 24a25c6

File tree

2 files changed

+130
-15
lines changed

2 files changed

+130
-15
lines changed

hyperactor_mesh/src/v1/host_mesh.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -511,23 +511,11 @@ impl HostMeshRef {
511511
)));
512512
}
513513

514-
let labels = self
515-
.region
516-
.labels()
517-
.to_vec()
518-
.into_iter()
519-
.chain(per_host.labels().to_vec().into_iter())
520-
.collect();
521-
let sizes = self
514+
let extent = self
522515
.region
523516
.extent()
524-
.sizes()
525-
.to_vec()
526-
.into_iter()
527-
.chain(per_host.sizes().to_vec().into_iter())
528-
.collect();
529-
let extent =
530-
Extent::new(labels, sizes).map_err(|err| v1::Error::ConfigurationError(err.into()))?;
517+
.concat(&per_host)
518+
.map_err(|err| v1::Error::ConfigurationError(err.into()))?;
531519

532520
let mesh_name = Name::new(name);
533521
let mut procs = Vec::new();

ndslice/src/view.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ pub enum ExtentError {
5959
/// Number of dimension sizes provided.
6060
num_sizes: usize,
6161
},
62+
63+
/// An overlapping label was found.
64+
///
65+
/// This occurs when attempting to combine extents that
66+
/// share one or more dimension labels, which is not allowed.
67+
#[error("overlapping label found: {label}")]
68+
OverlappingLabel {
69+
/// The label that appears in both extents.
70+
label: String,
71+
},
6272
}
6373

6474
/// `Extent` defines the logical shape of a multidimensional space by
@@ -307,6 +317,30 @@ impl Extent {
307317
pub fn points(&self) -> ExtentPointsIterator<'_> {
308318
ExtentPointsIterator::new(self)
309319
}
320+
321+
/// Append the dimensions of `other` to this extent, preserving order.
322+
///
323+
/// Duplicate labels are not allowed: if any label in `other` already appears
324+
/// in `self`, this returns `ExtentError::OverlappingLabel`.
325+
///
326+
/// This operation is not commutative: `a.concat(&b)` may differ from
327+
/// `b.concat(&a)`.
328+
pub fn concat(&self, other: &Extent) -> Result<Self, ExtentError> {
329+
use std::collections::HashSet;
330+
// Check for any overlapping labels in linear time using hash set
331+
let lhs: HashSet<&str> = self.labels().iter().map(|s| s.as_str()).collect();
332+
if let Some(dup) = other.labels().iter().find(|l| lhs.contains(l.as_str())) {
333+
return Err(ExtentError::OverlappingLabel { label: dup.clone() });
334+
}
335+
// Combine labels and sizes from both extents with pre-allocated memory
336+
let mut labels = self.labels().to_vec();
337+
let mut sizes = self.sizes().to_vec();
338+
labels.reserve(other.labels().len());
339+
sizes.reserve(other.sizes().len());
340+
labels.extend(other.labels().iter().cloned());
341+
sizes.extend(other.sizes().iter().copied());
342+
Extent::new(labels, sizes)
343+
}
310344
}
311345

312346
/// Label formatting utilities shared across `Extent`, `Region`, and
@@ -2111,6 +2145,99 @@ mod test {
21112145
assert!(it.next().is_none()); // fused
21122146
}
21132147

2148+
#[test]
2149+
fn test_extent_concat() {
2150+
// Test basic concatenation of two extents with preserved order of labels
2151+
let extent1 = extent!(x = 2, y = 3);
2152+
let extent2 = extent!(z = 4, w = 5);
2153+
2154+
let result = extent1.concat(&extent2).unwrap();
2155+
assert_eq!(result.labels(), &["x", "y", "z", "w"]);
2156+
assert_eq!(result.sizes(), &[2, 3, 4, 5]);
2157+
assert_eq!(result.num_ranks(), 2 * 3 * 4 * 5);
2158+
2159+
// Test concatenating with empty extent
2160+
let empty = extent!();
2161+
let result = extent1.concat(&empty).unwrap();
2162+
assert_eq!(result.labels(), &["x", "y"]);
2163+
assert_eq!(result.sizes(), &[2, 3]);
2164+
2165+
let result = empty.concat(&extent1).unwrap();
2166+
assert_eq!(result.labels(), &["x", "y"]);
2167+
assert_eq!(result.sizes(), &[2, 3]);
2168+
2169+
// Test concatenating two empty extents
2170+
let result = empty.concat(&empty).unwrap();
2171+
assert_eq!(result.labels(), &[] as &[String]);
2172+
assert_eq!(result.sizes(), &[] as &[usize]);
2173+
assert_eq!(result.num_ranks(), 1); // 0-dimensional extent has 1 rank
2174+
2175+
// Test self-concatenation (overlapping labels should cause error)
2176+
let result = extent1.concat(&extent1);
2177+
assert!(
2178+
result.is_err(),
2179+
"Self-concatenation should error due to overlapping labels"
2180+
);
2181+
match result.unwrap_err() {
2182+
ExtentError::OverlappingLabel { label } => {
2183+
assert!(label == "x"); // Overlapping label should be "x"
2184+
}
2185+
other => panic!("Expected OverlappingLabel error, got {:?}", other),
2186+
}
2187+
2188+
// Test concatenation creates valid points
2189+
let result = extent1.concat(&extent2).unwrap();
2190+
let point = result.point(vec![1, 2, 3, 4]).unwrap();
2191+
assert_eq!(point.coords(), vec![1, 2, 3, 4]);
2192+
assert_eq!(point.extent(), &result);
2193+
2194+
// Test error case: overlapping labels with same size (should error)
2195+
let extent_a = extent!(x = 2, y = 3);
2196+
let extent_b = extent!(y = 3, z = 4); // y overlaps with same size
2197+
let result = extent_a.concat(&extent_b);
2198+
assert!(
2199+
result.is_err(),
2200+
"Should error on overlapping labels even with same size"
2201+
);
2202+
match result.unwrap_err() {
2203+
ExtentError::OverlappingLabel { label } => {
2204+
assert_eq!(label, "y"); // the overlapping label
2205+
}
2206+
other => panic!("Expected OverlappingLabel error, got {:?}", other),
2207+
}
2208+
2209+
// Test that Extent::concat preserves order and is not commutative
2210+
let extent_x = extent!(x = 2, y = 3);
2211+
let extent_y = extent!(z = 4);
2212+
assert_eq!(
2213+
extent_x.concat(&extent_y).unwrap().labels(),
2214+
&["x", "y", "z"]
2215+
);
2216+
assert_eq!(
2217+
extent_y.concat(&extent_x).unwrap().labels(),
2218+
&["z", "x", "y"]
2219+
);
2220+
2221+
// Test associativity: (a ⊕ b) ⊕ c == a ⊕ (b ⊕ c) for disjoint labels
2222+
let extent_m = extent!(x = 2);
2223+
let extent_n = extent!(y = 3);
2224+
let extent_o = extent!(z = 4);
2225+
2226+
let left_assoc = extent_m
2227+
.concat(&extent_n)
2228+
.unwrap()
2229+
.concat(&extent_o)
2230+
.unwrap();
2231+
let right_assoc = extent_m
2232+
.concat(&extent_n.concat(&extent_o).unwrap())
2233+
.unwrap();
2234+
2235+
assert_eq!(left_assoc, right_assoc);
2236+
assert_eq!(left_assoc.labels(), &["x", "y", "z"]);
2237+
assert_eq!(left_assoc.sizes(), &[2, 3, 4]);
2238+
assert_eq!(left_assoc.num_ranks(), 2 * 3 * 4);
2239+
}
2240+
21142241
#[test]
21152242
fn extent_unity_equiv_to_0d() {
21162243
let e = Extent::unity();

0 commit comments

Comments
 (0)