9
9
use hyperactor:: channel:: ChannelTransport ;
10
10
pub mod mesh_agent;
11
11
12
+ use std:: collections:: HashSet ;
12
13
use std:: ops:: Deref ;
13
14
use std:: str:: FromStr ;
14
15
use std:: sync:: Arc ;
@@ -460,31 +461,69 @@ impl HostMeshRef {
460
461
}
461
462
}
462
463
463
- /// Spawn a ProcMesh onto this host mesh.
464
- // TODO: add an "additional dims" API
465
- pub async fn spawn ( & self , cx : & impl context:: Actor , name : & str ) -> v1:: Result < ProcMesh > {
466
- let name = Name :: new ( name) ;
464
+ /// Spawn a ProcMesh onto this host mesh. The per_host extent specifies the shape
465
+ /// of the procs to spawn on each host.
466
+ pub async fn spawn (
467
+ & self ,
468
+ cx : & impl context:: Actor ,
469
+ name : & str ,
470
+ per_host : Extent ,
471
+ ) -> v1:: Result < ProcMesh > {
472
+ let per_host_labels = per_host. labels ( ) . iter ( ) . collect :: < HashSet < _ > > ( ) ;
473
+ let host_labels = self . region . labels ( ) . iter ( ) . collect :: < HashSet < _ > > ( ) ;
474
+ if !per_host_labels
475
+ . intersection ( & host_labels)
476
+ . collect :: < Vec < _ > > ( )
477
+ . is_empty ( )
478
+ {
479
+ return Err ( v1:: Error :: ConfigurationError ( anyhow:: anyhow!(
480
+ "per_host dims overlap with existing dims when spawning proc mesh"
481
+ ) ) ) ;
482
+ }
483
+
484
+ let labels = self
485
+ . region
486
+ . labels ( )
487
+ . to_vec ( )
488
+ . into_iter ( )
489
+ . chain ( per_host. labels ( ) . to_vec ( ) . into_iter ( ) )
490
+ . collect ( ) ;
491
+ let sizes = self
492
+ . region
493
+ . extent ( )
494
+ . sizes ( )
495
+ . to_vec ( )
496
+ . into_iter ( )
497
+ . chain ( per_host. sizes ( ) . to_vec ( ) . into_iter ( ) )
498
+ . collect ( ) ;
499
+ let extent =
500
+ Extent :: new ( labels, sizes) . map_err ( |err| v1:: Error :: ConfigurationError ( err. into ( ) ) ) ?;
501
+
502
+ let mesh_name = Name :: new ( name) ;
467
503
let mut procs = Vec :: new ( ) ;
468
- for ( rank, host) in self . ranks . iter ( ) . enumerate ( ) {
469
- let _ok = host
470
- . mesh_agent ( )
471
- . create_or_update ( cx, name. clone ( ) , ( ) )
472
- . await
473
- . map_err ( |e| {
474
- v1:: Error :: HostMeshAgentConfigurationError (
475
- host. mesh_agent ( ) . actor_id ( ) . clone ( ) ,
476
- format ! ( "failed while creating proc: {}" , e) ,
477
- )
478
- } ) ?;
479
- procs. push ( ProcRef :: new (
480
- host. named_proc ( & name) ,
481
- rank,
482
- // TODO: specify or retrieve from state instead, to avoid attestation.
483
- ActorRef :: attest ( host. named_proc ( & name) . actor_id ( "agent" , 0 ) ) ,
484
- ) ) ;
504
+ for ( host_rank, host) in self . ranks . iter ( ) . enumerate ( ) {
505
+ for per_host_rank in 0 ..per_host. num_ranks ( ) {
506
+ let proc_name = Name :: new ( format ! ( "{}-{}" , name, per_host_rank) ) ;
507
+ let _ok = host
508
+ . mesh_agent ( )
509
+ . create_or_update ( cx, proc_name. clone ( ) , ( ) )
510
+ . await
511
+ . map_err ( |e| {
512
+ v1:: Error :: HostMeshAgentConfigurationError (
513
+ host. mesh_agent ( ) . actor_id ( ) . clone ( ) ,
514
+ format ! ( "failed while creating proc: {}" , e) ,
515
+ )
516
+ } ) ?;
517
+ procs. push ( ProcRef :: new (
518
+ host. named_proc ( & proc_name) ,
519
+ per_host. num_ranks ( ) * host_rank + per_host_rank,
520
+ // TODO: specify or retrieve from state instead, to avoid attestation.
521
+ ActorRef :: attest ( host. named_proc ( & proc_name) . actor_id ( "agent" , 0 ) ) ,
522
+ ) ) ;
523
+ }
485
524
}
486
525
487
- ProcMesh :: create_owned_unchecked ( cx, name , self . clone ( ) , procs) . await
526
+ ProcMesh :: create_owned_unchecked ( cx, mesh_name , extent , self . clone ( ) , procs) . await
488
527
}
489
528
}
490
529
@@ -621,12 +660,28 @@ mod tests {
621
660
. await
622
661
. unwrap ( ) ;
623
662
624
- let proc_mesh1 = host_mesh. spawn ( instance, "test_1" ) . await . unwrap ( ) ;
663
+ let proc_mesh1 = host_mesh
664
+ . spawn ( instance, "test_1" , Extent :: unity ( ) )
665
+ . await
666
+ . unwrap ( ) ;
625
667
let actor_mesh1: ActorMesh < testactor:: TestActor > =
626
668
proc_mesh1. spawn ( instance, "test" , & ( ) ) . await . unwrap ( ) ;
627
- let proc_mesh2 = host_mesh. spawn ( instance, "test_2" ) . await . unwrap ( ) ;
669
+ let proc_mesh2 = host_mesh
670
+ . spawn ( instance, "test_2" , extent ! ( gpus = 3 , extra = 2 ) )
671
+ . await
672
+ . unwrap ( ) ;
673
+ assert_eq ! (
674
+ proc_mesh2. extent( ) ,
675
+ extent!( replicas = 4 , gpus = 3 , extra = 2 )
676
+ ) ;
677
+ assert_eq ! ( proc_mesh2. values( ) . count( ) , 24 ) ;
628
678
let actor_mesh2: ActorMesh < testactor:: TestActor > =
629
679
proc_mesh2. spawn ( instance, "test" , & ( ) ) . await . unwrap ( ) ;
680
+ assert_eq ! (
681
+ actor_mesh2. extent( ) ,
682
+ extent!( replicas = 4 , gpus = 3 , extra = 2 )
683
+ ) ;
684
+ assert_eq ! ( actor_mesh2. values( ) . count( ) , 24 ) ;
630
685
631
686
// Host meshes can be dereferenced to produce a concrete ref.
632
687
let host_mesh_ref: HostMeshRef = host_mesh. clone ( ) ;
@@ -637,23 +692,24 @@ mod tests {
637
692
) ;
638
693
639
694
// Validate we can cast:
640
-
641
- let ( port, mut rx) = instance. mailbox ( ) . open_port ( ) ;
642
- actor_mesh1
643
- . cast ( instance, testactor:: GetActorId ( port. bind ( ) ) )
644
- . unwrap ( ) ;
645
-
646
- let mut expected_actor_ids: HashSet < _ > = actor_mesh1
647
- . values ( )
648
- . map ( |actor_ref| actor_ref. actor_id ( ) . clone ( ) )
649
- . collect ( ) ;
650
-
651
- while !expected_actor_ids. is_empty ( ) {
652
- let actor_id = rx. recv ( ) . await . unwrap ( ) ;
653
- assert ! (
654
- expected_actor_ids. remove( & actor_id) ,
655
- "got {actor_id}, expect {expected_actor_ids:?}"
656
- ) ;
695
+ for actor_mesh in [ & actor_mesh1, & actor_mesh2] {
696
+ let ( port, mut rx) = instance. mailbox ( ) . open_port ( ) ;
697
+ actor_mesh
698
+ . cast ( instance, testactor:: GetActorId ( port. bind ( ) ) )
699
+ . unwrap ( ) ;
700
+
701
+ let mut expected_actor_ids: HashSet < _ > = actor_mesh
702
+ . values ( )
703
+ . map ( |actor_ref| actor_ref. actor_id ( ) . clone ( ) )
704
+ . collect ( ) ;
705
+
706
+ while !expected_actor_ids. is_empty ( ) {
707
+ let actor_id = rx. recv ( ) . await . unwrap ( ) ;
708
+ assert ! (
709
+ expected_actor_ids. remove( & actor_id) ,
710
+ "got {actor_id}, expect {expected_actor_ids:?}"
711
+ ) ;
712
+ }
657
713
}
658
714
659
715
// Now forward a message through all directed edges across the two meshes.
@@ -719,7 +775,10 @@ mod tests {
719
775
720
776
let instance = testing:: instance ( ) . await ;
721
777
let host_mesh = HostMeshRef :: from_hosts ( hosts) ;
722
- let proc_mesh = host_mesh. spawn ( & instance, "test" ) . await . unwrap ( ) ;
778
+ let proc_mesh = host_mesh
779
+ . spawn ( & testing:: instance ( ) . await , "test" , Extent :: unity ( ) )
780
+ . await
781
+ . unwrap ( ) ;
723
782
let actor_mesh: ActorMesh < testactor:: TestActor > = proc_mesh
724
783
. spawn ( & testing:: instance ( ) . await , "test" , & ( ) )
725
784
. await
0 commit comments