|
| 1 | +//! Core DAG data structures, validator, and async work-stealing scheduler. |
| 2 | +
|
| 3 | +use std::collections::{HashMap, HashSet, VecDeque}; |
| 4 | +use std::time::Instant; |
| 5 | + |
| 6 | +use pyo3::prelude::*; |
| 7 | +use pyo3::types::PyDict; |
| 8 | +use tokio::sync::mpsc; |
| 9 | + |
| 10 | +// ───────────── Node status ───────────── |
| 11 | + |
/// Terminal status of a DAG node after scheduling has finished.
///
/// A node either ran to completion, failed (Python exception or task
/// panic), or was cancelled because an upstream dependency failed.
// Fieldless enum: `Copy` and `Hash` are free and idiomatic to derive;
// `Clone` is kept so existing `.clone()` call sites stay valid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeStatus {
    /// The node's callable returned successfully.
    Completed,
    /// The node's callable raised, or its task panicked.
    Failed,
    /// The node never ran: a transitive upstream dependency failed.
    Cancelled,
}
| 18 | + |
| 19 | +impl NodeStatus { |
| 20 | + pub fn as_str(&self) -> &'static str { |
| 21 | + match self { |
| 22 | + Self::Completed => "completed", |
| 23 | + Self::Failed => "failed", |
| 24 | + Self::Cancelled => "cancelled", |
| 25 | + } |
| 26 | + } |
| 27 | +} |
| 28 | + |
| 29 | +// ───────────── Internal types ───────────── |
| 30 | + |
/// A single unit of work in the DAG, as registered from Python.
pub struct DagNode {
    /// Unique identifier; also the key in the scheduler's node map.
    pub id: String,
    // NOTE(review): `name` is never read in this file — presumably
    // display metadata for callers; confirm before removing.
    pub name: String,
    /// Ids of nodes whose results must exist before this node runs.
    pub dependencies: Vec<String>,
    /// Python callable invoked with a dict of {dep_id: dep_result}.
    pub callable: PyObject,
}
| 37 | + |
/// Message a worker task sends back to the scheduling loop when a node
/// finishes. `ms` is the node's wall-clock duration in milliseconds.
enum Completion {
    /// Node `id` returned successfully with `result`.
    Ok { id: String, result: PyObject, ms: f64 },
    /// Node `id` raised or panicked; `error` is the rendered message.
    Err { id: String, error: String, ms: f64 },
}
| 42 | + |
/// Final outcome of one node, as returned to the caller of `execute_dag`.
pub struct ExecResult {
    /// Id of the node this result describes.
    pub node_id: String,
    /// Completed, Failed, or Cancelled.
    pub status: NodeStatus,
    /// The Python return value; `Some` only when `status` is Completed.
    pub result: Option<PyObject>,
    /// Error text; `Some` for Failed (exception/panic message) and for
    /// Cancelled (names the failed upstream node).
    pub error: Option<String>,
    /// Wall-clock runtime in milliseconds; 0.0 for cancelled nodes.
    pub duration_ms: f64,
}
| 50 | + |
| 51 | +// ───────────── DAG validation (Kahn's algorithm) ───────────── |
| 52 | + |
| 53 | +pub fn validate_dag(nodes: &HashMap<String, DagNode>) -> Result<Vec<String>, String> { |
| 54 | + let mut in_deg: HashMap<&str, usize> = HashMap::with_capacity(nodes.len()); |
| 55 | + let mut adj: HashMap<&str, Vec<&str>> = HashMap::with_capacity(nodes.len()); |
| 56 | + |
| 57 | + for id in nodes.keys() { |
| 58 | + in_deg.entry(id.as_str()).or_insert(0); |
| 59 | + } |
| 60 | + for (id, node) in nodes { |
| 61 | + for dep in &node.dependencies { |
| 62 | + if !nodes.contains_key(dep) { |
| 63 | + return Err(format!("Node '{}' depends on unknown node '{}'", id, dep)); |
| 64 | + } |
| 65 | + adj.entry(dep.as_str()).or_default().push(id.as_str()); |
| 66 | + *in_deg.entry(id.as_str()).or_insert(0) += 1; |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | + let mut queue: VecDeque<&str> = in_deg.iter() |
| 71 | + .filter(|(_, &d)| d == 0).map(|(&id, _)| id).collect(); |
| 72 | + let mut order: Vec<String> = Vec::with_capacity(nodes.len()); |
| 73 | + |
| 74 | + while let Some(cur) = queue.pop_front() { |
| 75 | + order.push(cur.to_string()); |
| 76 | + if let Some(nxt) = adj.get(cur) { |
| 77 | + for &n in nxt { |
| 78 | + let d = in_deg.get_mut(n).unwrap(); |
| 79 | + *d -= 1; |
| 80 | + if *d == 0 { queue.push_back(n); } |
| 81 | + } |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + if order.len() != nodes.len() { |
| 86 | + let stuck: Vec<&str> = in_deg.iter() |
| 87 | + .filter(|(_, &d)| d > 0).map(|(&id, _)| id).collect(); |
| 88 | + Err(format!("Cycle detected involving nodes: {:?}", stuck)) |
| 89 | + } else { |
| 90 | + Ok(order) |
| 91 | + } |
| 92 | +} |
| 93 | + |
| 94 | +// ───────────── Cascade-cancel downstream ───────────── |
| 95 | + |
| 96 | +fn cascade_cancel( |
| 97 | + failed_id: &str, |
| 98 | + dependents: &HashMap<String, Vec<String>>, |
| 99 | + failed_set: &mut HashSet<String>, |
| 100 | + remaining: &mut HashMap<String, DagNode>, |
| 101 | +) -> Vec<String> { |
| 102 | + let mut cancelled: Vec<String> = Vec::new(); |
| 103 | + let mut queue: VecDeque<String> = VecDeque::new(); |
| 104 | + if let Some(ds) = dependents.get(failed_id) { |
| 105 | + for d in ds { queue.push_back(d.clone()); } |
| 106 | + } |
| 107 | + while let Some(id) = queue.pop_front() { |
| 108 | + if failed_set.contains(&id) { continue; } |
| 109 | + failed_set.insert(id.clone()); |
| 110 | + remaining.remove(&id); |
| 111 | + cancelled.push(id.clone()); |
| 112 | + if let Some(ds) = dependents.get(&id) { |
| 113 | + for d in ds { |
| 114 | + if !failed_set.contains(d) { queue.push_back(d.clone()); } |
| 115 | + } |
| 116 | + } |
| 117 | + } |
| 118 | + cancelled |
| 119 | +} |
| 120 | + |
| 121 | +// ───────────── Execute single node (spawn_blocking + GIL) ───────────── |
| 122 | + |
| 123 | +async fn run_node( |
| 124 | + node_id: String, |
| 125 | + callable: PyObject, |
| 126 | + dep_results: Vec<(String, PyObject)>, // Vec of (key, value) — avoids Clone |
| 127 | + tx: mpsc::Sender<Completion>, |
| 128 | +) { |
| 129 | + let start = Instant::now(); |
| 130 | + let id = node_id.clone(); |
| 131 | + |
| 132 | + let outcome = tokio::task::spawn_blocking(move || { |
| 133 | + Python::with_gil(|py| -> PyResult<PyObject> { |
| 134 | + let deps = PyDict::new(py); |
| 135 | + for (key, val) in &dep_results { |
| 136 | + deps.set_item(key, val.clone_ref(py))?; |
| 137 | + } |
| 138 | + let result = callable.call1(py, (&deps,))?; |
| 139 | + |
| 140 | + // Handle async callables transparently |
| 141 | + let inspect = py.import("inspect")?; |
| 142 | + let is_coro: bool = inspect |
| 143 | + .call_method1("iscoroutine", (result.bind(py),))? |
| 144 | + .extract()?; |
| 145 | + if is_coro { |
| 146 | + let asyncio = py.import("asyncio")?; |
| 147 | + let awaited = asyncio.call_method1("run", (result.bind(py),))?; |
| 148 | + Ok(awaited.unbind()) |
| 149 | + } else { |
| 150 | + Ok(result) |
| 151 | + } |
| 152 | + }) |
| 153 | + }) |
| 154 | + .await; |
| 155 | + |
| 156 | + let ms = start.elapsed().as_secs_f64() * 1000.0; |
| 157 | + |
| 158 | + let msg = match outcome { |
| 159 | + Ok(Ok(r)) => Completion::Ok { id, result: r, ms }, |
| 160 | + Ok(Err(e)) => { |
| 161 | + let err = Python::with_gil(|py| format!("{}", e.value(py))); |
| 162 | + Completion::Err { id, error: err, ms } |
| 163 | + } |
| 164 | + Err(e) => Completion::Err { id, error: format!("Task panicked: {}", e), ms }, |
| 165 | + }; |
| 166 | + let _ = tx.send(msg).await; |
| 167 | +} |
| 168 | + |
| 169 | +// ───────────── Main entry: execute_dag ───────────── |
| 170 | + |
/// Validates the DAG, then executes every node on a dedicated
/// multi-threaded tokio runtime, blocking the calling thread.
///
/// Scheduling is event-driven: root nodes (no dependencies) are spawned
/// immediately; each completion message decrements its dependents'
/// pending counts and spawns any node that reaches zero. A failure
/// cascades — all transitive dependents are marked Cancelled and never
/// run — while independent branches continue executing.
///
/// Returns one `ExecResult` per node in completion order (cancelled
/// nodes appear immediately after the failure that caused them).
///
/// # Errors
/// Returns `Err` for an invalid graph (unknown dependency or cycle) or
/// if the tokio runtime cannot be built.
pub fn execute_dag(mut nodes: HashMap<String, DagNode>) -> Result<Vec<ExecResult>, String> {
    // Validation is for the side effect only; the topological order is
    // unused because scheduling is driven by the pending-counts below.
    let _topo = validate_dag(&nodes)?;
    let total = nodes.len();
    if total == 0 { return Ok(Vec::new()); }

    // Reverse edges (dep -> dependents) and unmet-dependency counts.
    let mut dependents: HashMap<String, Vec<String>> = HashMap::new();
    let mut pending: HashMap<String, usize> = HashMap::new();
    for (id, node) in &nodes {
        pending.insert(id.clone(), node.dependencies.len());
        for dep in &node.dependencies {
            dependents.entry(dep.clone()).or_default().push(id.clone());
        }
    }

    // Size the runtime to the host, clamped to [2, 16] workers.
    let cpus = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
    let rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(cpus.max(2).min(16))
        .enable_all()
        .build()
        .map_err(|e| format!("Failed to create tokio runtime: {}", e))?;

    rt.block_on(async move {
        // Capacity `total` means worker sends never block the pool.
        let (tx, mut rx) = mpsc::channel::<Completion>(total.max(1));

        // Results of completed nodes, retained to feed their dependents.
        let mut results_store: HashMap<String, PyObject> = HashMap::new();
        let mut exec_results: Vec<ExecResult> = Vec::with_capacity(total);
        // Count of nodes with ANY final status (done/failed/cancelled);
        // the loop below runs until this reaches `total`.
        let mut completed: usize = 0;
        // Failed or cancelled nodes; never scheduled or decremented.
        let mut failed_set: HashSet<String> = HashSet::new();

        // Spawn root nodes (zero unmet dependencies) straight away.
        let roots: Vec<String> = pending.iter()
            .filter(|(_, &c)| c == 0).map(|(id, _)| id.clone()).collect();
        for id in roots {
            if let Some(node) = nodes.remove(&id) {
                let sender = tx.clone();
                tokio::task::spawn(async move {
                    run_node(node.id, node.callable, Vec::new(), sender).await;
                });
            }
        }

        // Scheduling loop: one iteration per completion message.
        while completed < total {
            match rx.recv().await {
                Some(Completion::Ok { id, result, ms }) => {
                    // Keep a reference for dependents (clone_ref needs the GIL).
                    Python::with_gil(|py| {
                        results_store.insert(id.clone(), result.clone_ref(py));
                    });
                    exec_results.push(ExecResult {
                        node_id: id.clone(),
                        status: NodeStatus::Completed,
                        result: Some(result),
                        error: None,
                        duration_ms: ms,
                    });
                    completed += 1;

                    // Propagate: decrement each dependent's pending count
                    // and spawn any node that becomes ready.
                    if let Some(ds) = dependents.get(&id) {
                        for dep_id in ds {
                            // Skip nodes already doomed by an upstream failure.
                            if failed_set.contains(dep_id) { continue; }
                            if let Some(cnt) = pending.get_mut(dep_id) {
                                *cnt -= 1;
                                if *cnt == 0 {
                                    if let Some(node) = nodes.remove(dep_id) {
                                        // Snapshot this node's dependency results as
                                        // (id, ref) pairs to move into the task.
                                        let dep_data: Vec<(String, PyObject)> =
                                            Python::with_gil(|py| {
                                                node.dependencies.iter().filter_map(|d| {
                                                    results_store.get(d).map(|r| {
                                                        (d.clone(), r.clone_ref(py))
                                                    })
                                                }).collect()
                                            });
                                        let sender = tx.clone();
                                        tokio::task::spawn(async move {
                                            run_node(
                                                node.id, node.callable,
                                                dep_data, sender,
                                            ).await;
                                        });
                                    }
                                }
                            }
                        }
                    }
                }

                Some(Completion::Err { id, error, ms }) => {
                    failed_set.insert(id.clone());
                    exec_results.push(ExecResult {
                        node_id: id.clone(),
                        status: NodeStatus::Failed,
                        result: None,
                        error: Some(error.clone()),
                        duration_ms: ms,
                    });
                    completed += 1;
                    // Cancel every transitive dependent; each counts toward
                    // `total` so the loop still terminates.
                    let cancelled = cascade_cancel(&id, &dependents, &mut failed_set, &mut nodes);
                    for cid in cancelled {
                        exec_results.push(ExecResult {
                            node_id: cid,
                            status: NodeStatus::Cancelled,
                            result: None,
                            error: Some(format!("Cancelled: upstream '{}' failed", id)),
                            duration_ms: 0.0,
                        });
                        completed += 1;
                    }
                }

                // Defensive: `tx` is still held by this scope, so recv()
                // cannot yield None before the loop condition ends it.
                None => break,
            }
        }

        Ok(exec_results)
    })
}
0 commit comments