@@ -55,12 +55,15 @@ use tokio_stream::wrappers::WatchStream;
55
55
use tokio_util:: sync:: CancellationToken ;
56
56
57
57
use crate :: alloc:: Alloc ;
58
+ use crate :: alloc:: AllocConstraints ;
58
59
use crate :: alloc:: AllocSpec ;
59
60
use crate :: alloc:: Allocator ;
60
61
use crate :: alloc:: AllocatorError ;
61
62
use crate :: alloc:: ProcState ;
62
63
use crate :: alloc:: ProcStopReason ;
63
64
use crate :: alloc:: ProcessAllocator ;
65
+ use crate :: alloc:: process:: CLIENT_TRACE_ID_LABEL ;
66
+ use crate :: alloc:: process:: ClientContext ;
64
67
65
68
/// Control messages sent from remote process allocator to local allocator.
66
69
#[ derive( Debug , Clone , Serialize , Deserialize , Named ) ]
@@ -74,6 +77,10 @@ pub enum RemoteProcessAllocatorMessage {
74
77
/// Ordered list of hosts in this allocation. Can be used to
75
78
/// pre-populate any local configurations such as torch.dist.
76
79
hosts : Vec < String > ,
80
+ /// Client context which is passed to the ProcessAlloc
81
+ /// TODO: Once RemoteProcessAllocator moves to mailbox,
82
+ /// the client_context will go to the message header instead.
83
+ client_context : Option < ClientContext > ,
77
84
} ,
78
85
/// Stop allocation.
79
86
Stop ,
@@ -196,15 +203,25 @@ impl RemoteProcessAllocator {
196
203
view,
197
204
bootstrap_addr,
198
205
hosts,
206
+ client_context,
199
207
} ) => {
200
208
tracing:: info!( "received allocation request for view: {}" , view) ;
201
209
ensure_previous_alloc_stopped( & mut active_allocation) . await ;
202
210
tracing:: info!( "allocating..." ) ;
203
211
204
212
// Create the corresponding local allocation spec.
213
+ let mut constraints: AllocConstraints = Default :: default ( ) ;
214
+ if let Some ( context) = & client_context {
215
+ constraints = AllocConstraints {
216
+ match_labels: HashMap :: from( [ (
217
+ CLIENT_TRACE_ID_LABEL . to_string( ) ,
218
+ context. trace_id. to_string( ) ,
219
+ ) ]
220
+ ) } ;
221
+ }
205
222
let spec = AllocSpec {
206
223
extent: view. extent( ) ,
207
- constraints: Default :: default ( ) ,
224
+ constraints,
208
225
} ;
209
226
210
227
match process_allocator. allocate( spec. clone( ) ) . await {
@@ -749,10 +766,15 @@ impl RemoteProcessAlloc {
749
766
"failed to dial remote {} for host {}" ,
750
767
remote_addr, host. id
751
768
) ) ?;
769
+
770
+ let trace_id = hyperactor_telemetry:: trace:: get_or_create_trace_id ( ) ;
771
+ let client_context = Some ( ClientContext { trace_id } ) ;
772
+
752
773
tx. post ( RemoteProcessAllocatorMessage :: Allocate {
753
774
view,
754
775
bootstrap_addr : self . bootstrap_addr . clone ( ) ,
755
776
hosts : hostnames. clone ( ) ,
777
+ client_context,
756
778
} ) ;
757
779
758
780
self . hosts_by_offset . insert ( offset, host. id . clone ( ) ) ;
@@ -1280,6 +1302,7 @@ mod test {
1280
1302
view : extent. clone ( ) . into ( ) ,
1281
1303
bootstrap_addr,
1282
1304
hosts : vec ! [ ] ,
1305
+ client_context : None ,
1283
1306
} )
1284
1307
. await
1285
1308
. unwrap ( ) ;
@@ -1419,6 +1442,7 @@ mod test {
1419
1442
view : extent. clone ( ) . into ( ) ,
1420
1443
bootstrap_addr,
1421
1444
hosts : vec ! [ ] ,
1445
+ client_context : None ,
1422
1446
} )
1423
1447
. await
1424
1448
. unwrap ( ) ;
@@ -1519,6 +1543,7 @@ mod test {
1519
1543
view : extent. clone ( ) . into ( ) ,
1520
1544
bootstrap_addr : bootstrap_addr. clone ( ) ,
1521
1545
hosts : vec ! [ ] ,
1546
+ client_context : None ,
1522
1547
} )
1523
1548
. await
1524
1549
. unwrap ( ) ;
@@ -1539,6 +1564,7 @@ mod test {
1539
1564
view : extent. clone ( ) . into ( ) ,
1540
1565
bootstrap_addr,
1541
1566
hosts : vec ! [ ] ,
1567
+ client_context : None ,
1542
1568
} )
1543
1569
. await
1544
1570
. unwrap ( ) ;
@@ -1632,6 +1658,7 @@ mod test {
1632
1658
view : extent. clone ( ) . into ( ) ,
1633
1659
bootstrap_addr,
1634
1660
hosts : vec ! [ ] ,
1661
+ client_context : None ,
1635
1662
} )
1636
1663
. await
1637
1664
. unwrap ( ) ;
@@ -1721,6 +1748,7 @@ mod test {
1721
1748
view : extent. clone ( ) . into ( ) ,
1722
1749
bootstrap_addr,
1723
1750
hosts : vec ! [ ] ,
1751
+ client_context : None ,
1724
1752
} )
1725
1753
. await
1726
1754
. unwrap ( ) ;
@@ -1747,6 +1775,150 @@ mod test {
1747
1775
remote_allocator. terminate ( ) ;
1748
1776
handle. await . unwrap ( ) . unwrap ( ) ;
1749
1777
}
1778
+
1779
/// Checks that a trace id carried in `client_context` of an `Allocate`
/// message ends up in the `AllocSpec` handed to the underlying allocator,
/// stored in `constraints.match_labels` under `CLIENT_TRACE_ID_LABEL`.
/// The test then expects an `Allocated` message followed by a `Done`
/// message on the bootstrap channel.
// NOTE(review): the heartbeat interval is overridden to 60s, presumably so
// that no heartbeat interferes within the 15s test timeout — confirm.
+ #[ timed_test:: async_timed_test( timeout_secs = 15 ) ]
1780
+ async fn test_trace_id_propagation ( ) {
1781
+ let config = hyperactor:: config:: global:: lock ( ) ;
1782
+ let _guard = config. override_key (
1783
+ hyperactor:: config:: REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL ,
1784
+ Duration :: from_secs ( 60 ) ,
1785
+ ) ;
1786
+ hyperactor_telemetry:: initialize_logging ( ClockKind :: default ( ) ) ;
1787
+ let serve_addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1788
+ let bootstrap_addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1789
+ let ( _, mut rx) = channel:: serve ( bootstrap_addr. clone ( ) ) . await . unwrap ( ) ;
1790
+
1791
+ let extent = extent ! ( host = 1 , gpu = 1 ) ;
1792
+ let tx = channel:: dial ( serve_addr. clone ( ) ) . unwrap ( ) ;
1793
+ let test_world_id: WorldId = id ! ( test_world_id) ;
1794
+ let test_trace_id = "test_trace_id_12345" ;
1795
+
1796
+ // Create the mock alloc that the mock allocator below hands back; its `next()` returns `None`
1797
+ let mut alloc = MockAlloc :: new ( ) ;
1798
+ alloc. expect_world_id ( ) . return_const ( test_world_id. clone ( ) ) ;
1799
+ alloc. expect_extent ( ) . return_const ( extent. clone ( ) ) ;
1800
+ alloc. expect_next ( ) . return_const ( None ) ;
1801
+
1802
+ // Create a mock allocator whose `withf` predicate inspects the AllocSpec passed to `allocate`
1803
+ let mut allocator = MockAllocator :: new ( ) ;
1804
+ allocator
1805
+ . expect_allocate ( )
1806
+ . times ( 1 )
1807
+ . withf ( move |spec : & AllocSpec | {
1808
+ // Verify that the trace id is correctly set in the constraints
1809
+ spec. constraints
1810
+ . match_labels
1811
+ . get ( CLIENT_TRACE_ID_LABEL )
1812
+ . is_some_and ( |trace_id| trace_id == test_trace_id)
1813
+ } )
1814
+ . return_once ( |_| Ok ( MockAllocWrapper :: new ( alloc) ) ) ;
1815
+
1816
+ let remote_allocator = RemoteProcessAllocator :: new ( ) ;
1817
+ let handle = tokio:: spawn ( {
1818
+ let remote_allocator = remote_allocator. clone ( ) ;
1819
+ async move {
1820
+ remote_allocator
1821
+ . start_with_allocator ( serve_addr, allocator, None )
1822
+ . await
1823
+ }
1824
+ } ) ;
1825
+
1826
+ // Send allocate message with client context containing trace id
1827
+ tx. send ( RemoteProcessAllocatorMessage :: Allocate {
1828
+ view : extent. clone ( ) . into ( ) ,
1829
+ bootstrap_addr,
1830
+ hosts : vec ! [ ] ,
1831
+ client_context : Some ( ClientContext {
1832
+ trace_id : test_trace_id. to_string ( ) ,
1833
+ } ) ,
1834
+ } )
1835
+ . await
1836
+ . unwrap ( ) ;
1837
+
1838
+ // Verify we get the allocated message
1839
+ let m = rx. recv ( ) . await . unwrap ( ) ;
1840
+ assert_matches ! (
1841
+ m,
1842
+ RemoteProcessProcStateMessage :: Allocated { world_id, view }
1843
+ if world_id == test_world_id && view. extent( ) == extent
1844
+ ) ;
1845
+
1846
+ // Verify we get the done message since the mock alloc returns None immediately
1847
+ let m = rx. recv ( ) . await . unwrap ( ) ;
1848
+ assert_matches ! ( m, RemoteProcessProcStateMessage :: Done ( world_id) if world_id == test_world_id) ;
1849
+
1850
+ remote_allocator. terminate ( ) ;
1851
+ handle. await . unwrap ( ) . unwrap ( ) ;
1852
+ }
1853
+
1854
/// Checks that an `Allocate` message with `client_context: None` produces
/// an `AllocSpec` whose `constraints.match_labels` is empty, i.e. no trace
/// id label is attached. The test then expects an `Allocated` message
/// followed by a `Done` message on the bootstrap channel.
// NOTE(review): the heartbeat interval is overridden to 60s, presumably so
// that no heartbeat interferes within the 15s test timeout — confirm.
+ #[ timed_test:: async_timed_test( timeout_secs = 15 ) ]
1855
+ async fn test_trace_id_propagation_no_client_context ( ) {
1856
+ let config = hyperactor:: config:: global:: lock ( ) ;
1857
+ let _guard = config. override_key (
1858
+ hyperactor:: config:: REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL ,
1859
+ Duration :: from_secs ( 60 ) ,
1860
+ ) ;
1861
+ hyperactor_telemetry:: initialize_logging ( ClockKind :: default ( ) ) ;
1862
+ let serve_addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1863
+ let bootstrap_addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1864
+ let ( _, mut rx) = channel:: serve ( bootstrap_addr. clone ( ) ) . await . unwrap ( ) ;
1865
+
1866
+ let extent = extent ! ( host = 1 , gpu = 1 ) ;
1867
+ let tx = channel:: dial ( serve_addr. clone ( ) ) . unwrap ( ) ;
1868
+ let test_world_id: WorldId = id ! ( test_world_id) ;
1869
+
1870
+ // Create the mock alloc that the mock allocator below hands back; its `next()` returns `None`
1871
+ let mut alloc = MockAlloc :: new ( ) ;
1872
+ alloc. expect_world_id ( ) . return_const ( test_world_id. clone ( ) ) ;
1873
+ alloc. expect_extent ( ) . return_const ( extent. clone ( ) ) ;
1874
+ alloc. expect_next ( ) . return_const ( None ) ;
1875
+
1876
+ // Create a mock allocator that verifies no trace id is set when client_context is None
1877
+ let mut allocator = MockAllocator :: new ( ) ;
1878
+ allocator
1879
+ . expect_allocate ( )
1880
+ . times ( 1 )
1881
+ . withf ( move |spec : & AllocSpec | {
1882
+ // Verify that no trace id is set in the constraints when client_context is None
1883
+ spec. constraints . match_labels . is_empty ( )
1884
+ } )
1885
+ . return_once ( |_| Ok ( MockAllocWrapper :: new ( alloc) ) ) ;
1886
+
1887
+ let remote_allocator = RemoteProcessAllocator :: new ( ) ;
1888
+ let handle = tokio:: spawn ( {
1889
+ let remote_allocator = remote_allocator. clone ( ) ;
1890
+ async move {
1891
+ remote_allocator
1892
+ . start_with_allocator ( serve_addr, allocator, None )
1893
+ . await
1894
+ }
1895
+ } ) ;
1896
+
1897
+ // Send allocate message without client context
1898
+ tx. send ( RemoteProcessAllocatorMessage :: Allocate {
1899
+ view : extent. clone ( ) . into ( ) ,
1900
+ bootstrap_addr,
1901
+ hosts : vec ! [ ] ,
1902
+ client_context : None ,
1903
+ } )
1904
+ . await
1905
+ . unwrap ( ) ;
1906
+
1907
+ // Verify we get the allocated message
1908
+ let m = rx. recv ( ) . await . unwrap ( ) ;
1909
+ assert_matches ! (
1910
+ m,
1911
+ RemoteProcessProcStateMessage :: Allocated { world_id, view }
1912
+ if world_id == test_world_id && view. extent( ) == extent
1913
+ ) ;
1914
+
1915
+ // Verify we get the done message since the mock alloc returns None immediately
1916
+ let m = rx. recv ( ) . await . unwrap ( ) ;
1917
+ assert_matches ! ( m, RemoteProcessProcStateMessage :: Done ( world_id) if world_id == test_world_id) ;
1918
+
1919
+ remote_allocator. terminate ( ) ;
1920
+ handle. await . unwrap ( ) . unwrap ( ) ;
1921
+ }
1750
1922
}
1751
1923
1752
1924
#[ cfg( test) ]
0 commit comments