diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index a0dd02b80..68b42f5ad 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -330,21 +330,6 @@ fn create_map(py: Python) -> HashMap { to_stream: p.parseStreamRef("to_stream")?, }) }); - m.insert(key("CreatePipe"), |p| { - let function = p.parseFunction("function")?; - let args = p.parse("args")?; - let kwargs = p.parse("kwargs")?; - let (args, kwargs) = func_call_args_to_wire_values(Some(&function), &args, &kwargs)?; - Ok(WorkerMessage::CreatePipe { - result: p.parseRef("result")?, - key: p.parse("key")?, - function, - max_messages: p.parse("max_messages")?, - mesh: p.parseRef("device_mesh")?, - args, - kwargs, - }) - }); m.insert(key("SendValue"), |p| { let function = p.parseOptionalFunction("function")?; let args: Bound<'_, PyTuple> = p.parse("args")?; diff --git a/monarch_messages/src/wire_value.rs b/monarch_messages/src/wire_value.rs index 3e0b1fc1b..ddbab72b2 100644 --- a/monarch_messages/src/wire_value.rs +++ b/monarch_messages/src/wire_value.rs @@ -171,110 +171,12 @@ impl From for WireValue { } } -impl WireValue { - fn from_pyobject_with_torch_op_arg_type( - obj: Bound<'_, PyAny>, - type_: &torch_sys::call_op::TypePtr, - num_elements: i32, - allow_nums_as_tensors: bool, - ) -> PyResult { - if type_.is_tensor() || type_.is_optional_tensor() { - if type_.is_optional_tensor() && obj.is_none() { - return Ok(WireValue::None(())); - } else if let Ok(ref_) = Ref::from_py_object(&obj) { - return Ok(WireValue::Ref(ref_)); - } - } - if type_.is_tensor_list() || type_.is_optional_tensor_list() { - if type_.is_optional_tensor_list() && obj.is_none() { - return Ok(WireValue::None(())); - } - let list = obj.downcast::()?; - let len = list.len(); - if len == 0 { - return Ok(WireValue::RefList(vec![])); - } - // SAFETY: We know it is within bounds - let item = unsafe { list.get_item_unchecked(0) }; - if let Ok(ref_) = Ref::from_py_object(&item) { - let mut ref_list = Vec::with_capacity(len); - ref_list.push(ref_); - for item in list.iter().skip(1) { - ref_list.push(Ref::from_py_object(&item).map_err(|_| { - PyValueError::new_err(format!( - "Expected homogeneous list of refs got: {:?}", - list - )) - })?); - } - return Ok(WireValue::RefList(ref_list)); - } - } - OpaqueIValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors) - .map(WireValue::IValue) - } -} - pub fn func_call_args_to_wire_values( - func: Option<&ResolvableFunction>, - args: &Bound<'_, PyTuple>, - kwargs: &Bound<'_, PyDict>, -) -> PyResult<(Vec, HashMap)> { - if let Some((op, overload)) = func.and_then(|func| func.as_torch_op()) { - torch_op_args_to_wire_values(&op, &overload, args, kwargs) - } else { - python_func_args_to_wire_value(args, kwargs) - } -} - -fn torch_op_args_to_wire_values( - op: &str, - overload: &str, + _func: Option<&ResolvableFunction>, args: &Bound<'_, PyTuple>, kwargs: &Bound<'_, PyDict>, ) -> PyResult<(Vec, HashMap)> { - let args_info = torch_sys::call_op::get_schema_args_info(op, overload).map_err(|err| { - PyValueError::new_err(format!( - "Failed to get the operator schema for {}::{}: {}", - op, overload, err - )) - })?; - - let args = args - .iter() - .zip(&args_info) - .map(|(arg, arg_info)| { - WireValue::from_pyobject_with_torch_op_arg_type( - arg, - arg_info.type_, - arg_info.num_elements, - arg_info.allows_number_as_tensor, - ) - }) - .collect::, _>>()?; - let kwargs = kwargs - .iter() - .map(|(k, v)| { - let key = k.extract::()?; - let arg_info = args_info - .iter() - .find(|arg_info| arg_info.name == key) - .ok_or_else(|| { - PyValueError::new_err(format!( - "Torch op {}::{} does not support kwarg {}", - op, overload, key - )) - })?; - let val = WireValue::from_pyobject_with_torch_op_arg_type( - v, - arg_info.type_, - arg_info.num_elements, - arg_info.allows_number_as_tensor, - )?; - Ok((key, val)) - }) - .collect::, PyErr>>()?; - Ok((args, kwargs)) + python_func_args_to_wire_value(args, kwargs) } fn python_func_args_to_wire_value( diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index 255571b90..43c1a822c 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -340,21 +340,6 @@ impl ResolvableFunction { } } - pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> { - match self { - Self::FunctionPath(func) => match func.path.split(".").collect::>().as_slice() { - ["torch", "ops", namespace, op_name, "default"] => { - Some((format!("{}::{}", namespace, op_name), String::new())) - } - ["torch", "ops", namespace, op_name, overload] => { - Some((format!("{}::{}", namespace, op_name), overload.to_string())) - } - _ => None, - }, - _ => None, - } - } - /// For testing: this is a special remote function path that induces a panic /// when called. pub fn panic_if_requested(&self) { @@ -367,13 +352,6 @@ impl ResolvableFunction { _ => (), } } - - pub fn supports_pytree_args(&self) -> bool { - match self { - Self::Cloudpickle(_) => true, - Self::FunctionPath(_) => self.as_torch_op().is_none(), - } - } } impl> From for ResolvableFunction { @@ -800,16 +778,6 @@ pub enum WorkerMessage { to_stream: StreamRef, }, - CreatePipe { - result: Ref, - key: String, - function: ResolvableFunction, - max_messages: i64, - mesh: Ref, - args: Vec, - kwargs: HashMap, - }, - SendValue { seq: Seq, /// Pipe to send value to. If `None`, value is sent to controller. diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index bc8a5bbbb..892685e94 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -79,7 +79,6 @@ use monarch_messages::worker::StreamRef; use monarch_messages::worker::WorkerMessage; use monarch_messages::worker::WorkerMessageHandler; use monarch_messages::worker::WorkerParams; -use monarch_types::PyTree; use ndslice::Slice; use pyo3::Python; use pyo3::types::PyAnyMethods; @@ -92,7 +91,6 @@ use stream::StreamParams; use torch_sys::CudaDevice; use torch_sys::DeviceIndex; use torch_sys::Layout; -use torch_sys::RValue; use torch_sys::ScalarType; use torch_sys::TensorCell; use torch_sys::factory_zeros; @@ -383,14 +381,10 @@ impl WorkerMessageHandler for WorkerActor { self.maybe_add_stream_to_recording(cx, params.stream) .await?; - let device_meshes = if params.function.as_torch_op().is_some() { - HashMap::new() - } else { - self.device_meshes - .iter() - .map(|(k, v)| (k.clone(), v.0.clone())) - .collect() - }; + let device_meshes = self.device_meshes + .iter() + .map(|(k, v)| (k.clone(), v.0.clone())) + .collect(); let mut remote_process_groups = HashMap::new(); for remote_process_group_ref in ¶ms.remote_process_groups { @@ -638,22 +632,6 @@ impl WorkerMessageHandler for WorkerActor { Ok(()) } - async fn create_pipe( - &mut self, - _cx: &hyperactor::Context, - _result: Ref, - // TODO(agallagher): This is used in the python impl to name the socket - // path to use for comms, but we don't currently use a named socket. - _key: String, - _function: ResolvableFunction, - _max_messages: i64, - _device_mesh: Ref, - _args: Vec, - _kwargs: HashMap, - ) -> Result<()> { - panic!("create_pipe is no longer implemented") - } - async fn send_tensor( &mut self, cx: &hyperactor::Context, @@ -772,7 +750,7 @@ impl WorkerMessageHandler for WorkerActor { // Resolve the stream. let stream = self.try_get_stream(stream)?; - let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) { + let device_meshes = if function.is_none() { HashMap::new() } else { self.device_meshes diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 605cda893..30129da78 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -54,7 +54,6 @@ use monarch_types::PyTree; use monarch_types::SerializablePyErr; use monarch_types::TryIntoPyObjectUnsafe; use pyo3::prelude::*; -use pyo3::types::PyTuple; use tokio::runtime::Handle; use tokio::sync::Mutex; use tokio::task::JoinHandle; @@ -740,34 +739,6 @@ impl StreamActor { Ok(()) } - fn call_torch_op( - &self, - op: String, - overload: String, - args: Vec, - kwargs: HashMap, - ) -> Result, CallFunctionError> { - let args = args - .into_iter() - .map(|arg| self.wire_to_rvalue(arg)) - .collect::, _>>()?; - let kwargs = kwargs - .into_iter() - .map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue))) - .collect::, CallFunctionError>>()?; - - let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?; - - // Handle the case where the op returns nothing and convert it to a list of None. - // This is to ensure handle results does not error out as the client will call - // such a function with expected results of size 1. - Ok(if results.is_empty() { - vec![RValue::None] - } else { - results - }) - } - fn call_python_fn<'py>( &mut self, py: Python<'py>, @@ -1118,21 +1089,17 @@ impl StreamMessageHandler for StreamActor { params.results, ¶ms.mutates, async |self| { - tokio::task::block_in_place(|| match params.function.as_torch_op() { - Some((op, overload)) => { - self.call_torch_op(op, overload, params.args, params.kwargs) - } - _ => self - .call_python_fn_pytree( - cx, - params.function, - params.args, - params.kwargs, - ¶ms.mutates, - device_meshes, - remote_process_groups, - ) - .map(|results| results.into_leaves()), + tokio::task::block_in_place(|| { + self.call_python_fn_pytree( + cx, + params.function, + params.args, + params.kwargs, + ¶ms.mutates, + device_meshes, + remote_process_groups, + ) + .map(|results| results.into_leaves()) }) }, ) @@ -1562,44 +1529,17 @@ impl StreamMessageHandler for StreamActor { } let result = if let Some(function) = function { // If a function was provided, use that to resolve the value. - match function.as_torch_op() { - Some((op, overload)) => { - self.call_torch_op(op, overload, args, kwargs) - .map(|rvalues| { - if rvalues.len() == 1 { - Ok(rvalues[0].clone().into()) - } else { - // TODO: Replace with native pytrees when possible - Python::with_gil(|py| { - Ok((|| { - let py_rvalues = rvalues - .into_iter() - // SAFETY: This inherits the unsafety of `try_to_object_unsafe`. - .map(|rvalue| unsafe { - rvalue.try_to_object_unsafe(py) - }) - .collect::, _>>()?; - PyTuple::new(py, &py_rvalues)?.extract::>() - })() - .map_err(SerializablePyErr::from_fn(py))?) - }) - } - })? - } - // Use block-in-place to allow nested callbacks to re-enter the - // runtime to run async code. - _ => tokio::task::block_in_place(|| { - self.call_python_fn_pytree( - cx, - function, - args, - kwargs, - &mutates, - device_meshes, - HashMap::new(), - ) - }), - } + tokio::task::block_in_place(|| { + self.call_python_fn_pytree( + cx, + function, + args, + kwargs, + &mutates, + device_meshes, + HashMap::new(), + ) + }) } else { // If there's no function provided, there should be exactly one arg // and no kwargs.