Skip to content

Commit 4d2e334

Browse files
zdevitometa-codesync[bot]
authored andcommitted
Removing remaining pipe functionality (#1966)
Summary: Pull Request resolved: #1966 ghstack-source-id: 325007412 Reviewed By: mariusae Differential Revision: D87599488 fbshipit-source-id: 796b71b9c844e280b50171c973d063df35198b01
1 parent 426225e commit 4d2e334

File tree

5 files changed

+25
-907
lines changed

5 files changed

+25
-907
lines changed

monarch_tensor_worker/src/lib.rs

Lines changed: 19 additions & 195 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
mod borrow;
3131
mod comm;
3232
pub mod device_mesh;
33-
pub mod pipe;
34-
pub mod py_pipe;
3533
pub mod stream;
3634
pub mod test_util;
3735

@@ -71,7 +69,6 @@ use monarch_messages::controller::Seq;
7169
use monarch_messages::wire_value::WireValue;
7270
use monarch_messages::worker::ActorCallParams;
7371
use monarch_messages::worker::ActorMethodParams;
74-
use monarch_messages::worker::CallFunctionError;
7572
use monarch_messages::worker::CallFunctionParams;
7673
use monarch_messages::worker::Factory;
7774
use monarch_messages::worker::Reduction;
@@ -84,8 +81,6 @@ use monarch_messages::worker::WorkerMessageHandler;
8481
use monarch_messages::worker::WorkerParams;
8582
use monarch_types::PyTree;
8683
use ndslice::Slice;
87-
use pipe::PipeActor;
88-
use pipe::PipeParams;
8984
use pyo3::Python;
9085
use pyo3::types::PyAnyMethods;
9186
use serde::Deserialize;
@@ -173,8 +168,6 @@ pub struct WorkerActor {
173168
borrows: HashMap<u64, Borrow>,
174169
comm: Option<ActorHandle<NcclCommActor>>,
175170
controller_actor: ActorRef<ControllerActor>,
176-
/// Pipes created for the worker.
177-
pipes: HashMap<Ref, ActorHandle<PipeActor>>,
178171
/// Remember the process groups "created" via `CreateRemoteProcessGroup` for
179172
/// subsequent `CallFunction` calls, as this is where the actual allocation
180173
/// will happen.
@@ -244,7 +237,6 @@ impl Actor for WorkerActor {
244237
borrows: HashMap::new(),
245238
comm: None,
246239
controller_actor,
247-
pipes: HashMap::new(),
248240
remote_process_groups: HashMap::new(),
249241
send_recv_comms: HashMap::new(),
250242
recordings: HashMap::new(),
@@ -648,47 +640,18 @@ impl WorkerMessageHandler for WorkerActor {
648640

649641
async fn create_pipe(
650642
&mut self,
651-
cx: &hyperactor::Context<Self>,
652-
result: Ref,
643+
_cx: &hyperactor::Context<Self>,
644+
_result: Ref,
653645
// TODO(agallagher): This is used in the python impl to name the socket
654646
// path to use for comms, but we don't currently use a named socket.
655647
_key: String,
656-
function: ResolvableFunction,
657-
max_messages: i64,
658-
device_mesh: Ref,
659-
args: Vec<WireValue>,
660-
kwargs: HashMap<String, WireValue>,
648+
_function: ResolvableFunction,
649+
_max_messages: i64,
650+
_device_mesh: Ref,
651+
_args: Vec<WireValue>,
652+
_kwargs: HashMap<String, WireValue>,
661653
) -> Result<()> {
662-
println!("CREATE PIPE1 {}", result);
663-
let args: Vec<PyTree<RValue>> = args
664-
.into_iter()
665-
.map(|object| RValue::PyObject(object.into_py_object().unwrap()).into())
666-
.collect();
667-
let kwargs: HashMap<_, PyTree<RValue>> = kwargs
668-
.into_iter()
669-
.map(|(k, object)| (k, RValue::PyObject(object.into_py_object().unwrap()).into()))
670-
.collect();
671-
let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| {
672-
CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh))
673-
})?;
674-
println!("CREATE PIPE2 {}", result);
675-
// TODO(agallagher): Fix error prop. (When pipe is read from the pipes dict if it had an error it should cause a dependent error in send_value not an actor error as it does now)
676-
let pipe = PipeActor::spawn(
677-
cx,
678-
PipeParams {
679-
function,
680-
max_messages,
681-
ranks: device_mesh.0.ranks(),
682-
sizes: device_mesh.0.sizes(),
683-
args,
684-
kwargs,
685-
},
686-
)
687-
.await?;
688-
println!("AFTER CREATE PIPE {}", result);
689-
690-
self.pipes.insert(result, pipe);
691-
Ok(())
654+
panic!("create_pipe is no longer implemented")
692655
}
693656

694657
async fn send_tensor(
@@ -818,18 +781,11 @@ impl WorkerMessageHandler for WorkerActor {
818781
.collect()
819782
};
820783

821-
let pipe = if let Some(destination) = destination {
822-
let pipe = self
823-
.pipes
824-
.get(&destination)
825-
.ok_or_else(|| anyhow::anyhow!("invalid pipe id: {:#?}", destination))?
826-
.port();
827-
Some(pipe)
828-
} else {
829-
None
830-
};
831-
// Resolve the value on the stream, then send the value to the pipe if provided,
832-
// or back to the controller if not.
784+
if destination.is_some() {
785+
panic!("send_value with pipe destination is no longer implemented")
786+
}
787+
788+
// Resolve the value on the stream, then send the value back to the controller.
833789
stream
834790
.send_value(
835791
cx,
@@ -840,7 +796,6 @@ impl WorkerMessageHandler for WorkerActor {
840796
args,
841797
kwargs,
842798
device_meshes,
843-
pipe,
844799
)
845800
.await
846801
}
@@ -971,24 +926,13 @@ impl WorkerMessageHandler for WorkerActor {
971926

972927
async fn pipe_recv(
973928
&mut self,
974-
cx: &hyperactor::Context<Self>,
975-
seq: Seq,
976-
results: Vec<Option<Ref>>,
977-
pipe: Ref,
978-
stream: StreamRef,
929+
_cx: &hyperactor::Context<Self>,
930+
_seq: Seq,
931+
_results: Vec<Option<Ref>>,
932+
_pipe: Ref,
933+
_stream: StreamRef,
979934
) -> Result<()> {
980-
self.maybe_add_stream_to_recording(cx, stream).await?;
981-
982-
// Get a port for the pipe
983-
let pipe = self
984-
.pipes
985-
.get(&pipe)
986-
.ok_or_else(|| anyhow::anyhow!("ref not found: {}", pipe))?;
987-
let pipe = pipe.port();
988-
// Resolve the stream.
989-
let stream = self.try_get_stream(stream)?;
990-
// Push result into the stream.
991-
stream.set_value(cx, seq, results, pipe).await
935+
panic!("pipe_recv is no longer implemented")
992936
}
993937

994938
async fn set_ref_unit_tests_only(
@@ -2186,126 +2130,6 @@ mod tests {
21862130
Ok(())
21872131
}
21882132

2189-
#[async_timed_test(timeout_secs = 60)]
2190-
async fn pipe_send_recv() -> Result<()> {
2191-
test_setup()?;
2192-
2193-
let proc = Proc::local();
2194-
let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
2195-
2196-
let handle = proc
2197-
.spawn::<WorkerActor>(
2198-
"worker",
2199-
WorkerParams {
2200-
world_size: 1,
2201-
rank: 0,
2202-
device_index: None,
2203-
controller_actor: controller_ref,
2204-
},
2205-
)
2206-
.await
2207-
.unwrap();
2208-
let (resolve_value_arg, torch_eq_arg1, torch_eq_arg2): (
2209-
PickledPyObject,
2210-
PickledPyObject,
2211-
PickledPyObject,
2212-
) = Python::with_gil(|py| {
2213-
PyResult::Ok((
2214-
PyList::new(py, [2, 3])?.into_any().try_into()?,
2215-
Ref { id: 2 }.into_bound_py_any(py)?.try_into()?,
2216-
Ref { id: 4 }.into_bound_py_any(py)?.try_into()?,
2217-
))
2218-
})?;
2219-
2220-
handle
2221-
.command_group(
2222-
&client,
2223-
vec![
2224-
WorkerMessage::CreateStream {
2225-
id: 0.into(),
2226-
stream_creation: StreamCreationMode::UseDefaultStream,
2227-
},
2228-
WorkerMessage::CreateDeviceMesh {
2229-
result: 1.into(),
2230-
names: vec!["x".into()],
2231-
ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
2232-
},
2233-
// Create a tensor value which we'll send through the pipe.
2234-
WorkerMessage::CallFunction(CallFunctionParams {
2235-
seq: 0.into(),
2236-
results: vec![Some(2.into())],
2237-
mutates: vec![],
2238-
function: "torch.ops.aten.ones.default".into(),
2239-
args: vec![WireValue::IntList(vec![2, 3])],
2240-
kwargs: HashMap::new(),
2241-
stream: 0.into(),
2242-
remote_process_groups: vec![],
2243-
}),
2244-
WorkerMessage::CreatePipe {
2245-
result: 3.into(),
2246-
key: "unused".into(),
2247-
function: "monarch.monarch_tensor_worker.test_utils.handler".into(),
2248-
max_messages: 1,
2249-
mesh: 1.into(),
2250-
args: vec![],
2251-
kwargs: HashMap::new(),
2252-
},
2253-
WorkerMessage::SendValue {
2254-
seq: 1.into(),
2255-
destination: Some(3.into()),
2256-
mutates: vec![],
2257-
function: Some(
2258-
"monarch.monarch_tensor_worker.test_utils.resolve_value".into(),
2259-
),
2260-
args: vec![resolve_value_arg.into()],
2261-
kwargs: HashMap::new(),
2262-
stream: 0.into(),
2263-
},
2264-
WorkerMessage::PipeRecv {
2265-
seq: 2.into(),
2266-
results: vec![Some(4.into())],
2267-
pipe: 3.into(),
2268-
stream: 0.into(),
2269-
},
2270-
WorkerMessage::CallFunction(CallFunctionParams {
2271-
seq: 0.into(),
2272-
results: vec![Some(5.into())],
2273-
mutates: vec![],
2274-
function: "torch.equal".into(),
2275-
args: vec![torch_eq_arg1.into(), torch_eq_arg2.into()],
2276-
kwargs: HashMap::new(),
2277-
stream: 0.into(),
2278-
remote_process_groups: vec![],
2279-
}),
2280-
],
2281-
)
2282-
.await
2283-
.unwrap();
2284-
2285-
let matches: bool = handle
2286-
.get_ref_unit_tests_only(&client, 5.into(), 0.into())
2287-
.await
2288-
.unwrap()
2289-
.unwrap()
2290-
.unwrap()
2291-
.try_into()
2292-
.unwrap();
2293-
assert!(matches);
2294-
2295-
handle.drain_and_stop()?;
2296-
assert_matches!(handle.await, ActorStatus::Stopped);
2297-
2298-
let responses = controller_rx.drain();
2299-
assert_eq!(
2300-
responses.len(),
2301-
0,
2302-
"Expected one response, got: {:#?}",
2303-
responses
2304-
);
2305-
2306-
Ok(())
2307-
}
2308-
23092133
fn get_random_channel_addr() -> ChannelAddr {
23102134
let random_string = rand::thread_rng()
23112135
.sample_iter(&Alphanumeric)

0 commit comments

Comments
 (0)