Skip to content

Commit 0f7cd8f

Browse files
authored
Merge pull request #44 from imbue-ai/factor-out-sandbox-creation
Factor out sandbox creation from orchestrator
2 parents 5f91183 + 7533574 commit 0f7cd8f

File tree

4 files changed

+131
-85
lines changed

src/lib.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@
5353
//! ## Quick Start
5454
//!
5555
//! ```no_run
56-
//! use tokio::sync::Mutex;
57-
//! use offload::config::load_config;
56+
//! use offload::config::{load_config, SandboxConfig};
5857
//! use offload::orchestrator::{Orchestrator, SandboxPool};
5958
//! use offload::provider::local::LocalProvider;
6059
//! use offload::framework::{TestFramework, pytest::PytestFramework};
@@ -73,10 +72,19 @@
7372
//! // Discover tests using the framework
7473
//! let tests = framework.discover(&[]).await?;
7574
//!
75+
//! // Pre-populate sandbox pool
76+
//! let sandbox_config = SandboxConfig {
77+
//! id: "sandbox".to_string(),
78+
//! working_dir: None,
79+
//! env: vec![],
80+
//! copy_dirs: vec![],
81+
//! };
82+
//! let mut sandbox_pool = SandboxPool::new();
83+
//! sandbox_pool.populate(config.offload.max_parallel, &provider, &sandbox_config).await?;
84+
//!
7685
//! // Run tests using the orchestrator
77-
//! let orchestrator = Orchestrator::new(config, provider, framework, &[], false);
78-
//! let sandbox_pool = Mutex::new(SandboxPool::new());
79-
//! let result = orchestrator.run_with_tests(&tests, &sandbox_pool).await?;
86+
//! let orchestrator = Orchestrator::new(config, framework, false);
87+
//! let result = orchestrator.run_with_tests(&tests, sandbox_pool).await?;
8088
//!
8189
//! std::process::exit(result.exit_code());
8290
//! }

src/main.rs

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ use clap::{Parser, Subcommand};
77
use tracing::{Level, info};
88
use tracing_subscriber::FmtSubscriber;
99

10-
use tokio::sync::Mutex;
11-
12-
use offload::config::{self, FrameworkConfig, ProviderConfig};
10+
use offload::config::{self, FrameworkConfig, ProviderConfig, SandboxConfig};
1311
use offload::framework::{
1412
TestFramework, TestRecord, cargo::CargoFramework, default::DefaultFramework,
1513
pytest::PytestFramework,
@@ -399,24 +397,33 @@ where
399397
P: offload::provider::SandboxProvider,
400398
D: TestFramework,
401399
{
402-
let sandbox_pool = Mutex::new(SandboxPool::new());
403-
404400
// Convert CopyDir to tuples
405401
let copy_dir_tuples: Vec<(PathBuf, PathBuf)> = copy_dirs
406402
.iter()
407403
.map(|cd| (cd.local.clone(), cd.remote.clone()))
408404
.collect();
409405

410-
let orchestrator = Orchestrator::new(
411-
config.clone(),
412-
provider,
413-
framework,
414-
&copy_dir_tuples,
415-
verbose,
416-
);
406+
// Pre-populate sandbox pool
407+
let sandbox_config = SandboxConfig {
408+
id: format!("offload-{}", uuid::Uuid::new_v4()),
409+
working_dir: config
410+
.offload
411+
.working_dir
412+
.as_ref()
413+
.map(|p| p.to_string_lossy().to_string()),
414+
env: Vec::new(),
415+
copy_dirs: copy_dir_tuples.clone(),
416+
};
417+
418+
let mut sandbox_pool = SandboxPool::new();
419+
sandbox_pool
420+
.populate(config.offload.max_parallel, &provider, &sandbox_config)
421+
.await
422+
.context("Failed to create sandboxes")?;
423+
424+
let orchestrator = Orchestrator::new(config.clone(), framework, verbose);
417425

418-
let result = orchestrator.run_with_tests(tests, &sandbox_pool).await?;
419-
sandbox_pool.lock().await.terminate_all().await;
426+
let result = orchestrator.run_with_tests(tests, sandbox_pool).await?;
420427

421428
Ok(result.exit_code())
422429
}

src/orchestrator.rs

Lines changed: 70 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,8 @@
5555
//! # Example
5656
//!
5757
//! ```no_run
58-
//! use tokio::sync::Mutex;
5958
//! use offload::orchestrator::{Orchestrator, SandboxPool};
60-
//! use offload::config::load_config;
59+
//! use offload::config::{load_config, SandboxConfig};
6160
//! use offload::provider::local::LocalProvider;
6261
//! use offload::framework::{TestFramework, pytest::PytestFramework};
6362
//!
@@ -71,10 +70,19 @@
7170
//! // Discover tests using the framework
7271
//! let tests = framework.discover(&[]).await?;
7372
//!
73+
//! // Pre-populate sandbox pool
74+
//! let sandbox_config = SandboxConfig {
75+
//! id: "sandbox".to_string(),
76+
//! working_dir: None,
77+
//! env: vec![],
78+
//! copy_dirs: vec![],
79+
//! };
80+
//! let mut sandbox_pool = SandboxPool::new();
81+
//! sandbox_pool.populate(config.offload.max_parallel, &provider, &sandbox_config).await?;
82+
//!
7483
//! // Run tests using the orchestrator
75-
//! let orchestrator = Orchestrator::new(config, provider, framework, &[], false);
76-
//! let sandbox_pool = Mutex::new(SandboxPool::new());
77-
//! let result = orchestrator.run_with_tests(&tests, &sandbox_pool).await?;
84+
//! let orchestrator = Orchestrator::new(config, framework, false);
85+
//! let result = orchestrator.run_with_tests(&tests, sandbox_pool).await?;
7886
//!
7987
//! if result.success() {
8088
//! println!("All tests passed!");
@@ -98,9 +106,9 @@ use tokio::sync::Mutex;
98106
use tokio_util::sync::CancellationToken;
99107
use tracing::{debug, error, warn};
100108

101-
use crate::config::{Config, SandboxConfig};
109+
use crate::config::Config;
102110
use crate::framework::{TestFramework, TestInstance, TestRecord, TestResult};
103-
use crate::provider::{OutputLine, SandboxProvider};
111+
use crate::provider::{OutputLine, Sandbox};
104112
use crate::report::{MasterJunitReport, print_summary};
105113

106114
pub use pool::SandboxPool;
@@ -198,23 +206,22 @@ impl RunResult {
198206
/// The main orchestrator that coordinates test execution.
199207
///
200208
/// The orchestrator is the top-level component that ties together:
201-
/// - A [`SandboxProvider`] for execution environments
202-
/// - A [`TestFramework`] for finding tests
209+
/// - A pre-populated [`SandboxPool`] of execution environments
210+
/// - A [`TestFramework`] for running tests
203211
///
204-
/// It manages the full lifecycle of a test run: discovery, scheduling,
212+
/// It manages the full lifecycle of a test run: scheduling,
205213
/// parallel execution, retries, and result aggregation.
206214
///
207215
/// # Type Parameters
208216
///
209-
/// - `P`: The sandbox provider type
217+
/// - `S`: The sandbox type (implements [`Sandbox`](crate::provider::Sandbox))
210218
/// - `D`: The test framework type
211219
///
212220
/// # Example
213221
///
214222
/// ```no_run
215-
/// use tokio::sync::Mutex;
216223
/// use offload::orchestrator::{Orchestrator, SandboxPool};
217-
/// use offload::config::load_config;
224+
/// use offload::config::{load_config, SandboxConfig};
218225
/// use offload::provider::local::LocalProvider;
219226
/// use offload::framework::{TestFramework, pytest::PytestFramework};
220227
///
@@ -229,49 +236,48 @@ impl RunResult {
229236
/// // Discover tests using the framework
230237
/// let tests = framework.discover(&[]).await?;
231238
///
239+
/// // Pre-populate sandbox pool
240+
/// let sandbox_config = SandboxConfig {
241+
/// id: "sandbox".to_string(),
242+
/// working_dir: None,
243+
/// env: vec![],
244+
/// copy_dirs: vec![],
245+
/// };
246+
/// let mut sandbox_pool = SandboxPool::new();
247+
/// sandbox_pool.populate(config.offload.max_parallel, &provider, &sandbox_config).await?;
248+
///
232249
/// // Create orchestrator and run tests
233-
/// let orchestrator = Orchestrator::new(config, provider, framework, &[], false);
234-
/// let sandbox_pool = Mutex::new(SandboxPool::new());
235-
/// let result = orchestrator.run_with_tests(&tests, &sandbox_pool).await?;
250+
/// let orchestrator = Orchestrator::new(config, framework, false);
251+
/// let result = orchestrator.run_with_tests(&tests, sandbox_pool).await?;
236252
///
237253
/// std::process::exit(result.exit_code());
238254
/// }
239255
/// ```
240-
pub struct Orchestrator<P, D> {
256+
pub struct Orchestrator<S, D> {
241257
config: Config,
242-
provider: P,
243258
framework: D,
244-
copy_dirs: Vec<(std::path::PathBuf, std::path::PathBuf)>,
245259
verbose: bool,
260+
_sandbox: std::marker::PhantomData<S>,
246261
}
247262

248-
impl<P, D> Orchestrator<P, D>
263+
impl<S, D> Orchestrator<S, D>
249264
where
250-
P: SandboxProvider,
265+
S: Sandbox,
251266
D: TestFramework,
252267
{
253268
/// Creates a new orchestrator with the given components.
254269
///
255270
/// # Arguments
256271
///
257272
/// * `config` - Configuration loaded from TOML
258-
/// * `provider` - Sandbox provider for creating execution environments
259273
/// * `framework` - Test framework for running tests
260-
/// * `copy_dirs` - Directories to copy to sandboxes (local_path, remote_path)
261274
/// * `verbose` - Whether to show verbose output (streaming test output)
262-
pub fn new(
263-
config: Config,
264-
provider: P,
265-
framework: D,
266-
copy_dirs: &[(std::path::PathBuf, std::path::PathBuf)],
267-
verbose: bool,
268-
) -> Self {
275+
pub fn new(config: Config, framework: D, verbose: bool) -> Self {
269276
Self {
270277
config,
271-
provider,
272278
framework,
273-
copy_dirs: copy_dirs.to_vec(),
274279
verbose,
280+
_sandbox: std::marker::PhantomData,
275281
}
276282
}
277283

@@ -296,7 +302,7 @@ where
296302
pub async fn run_with_tests(
297303
&self,
298304
tests: &[TestRecord],
299-
sandbox_pool: &Mutex<SandboxPool<P::Sandbox>>,
305+
mut sandbox_pool: SandboxPool<S>,
300306
) -> anyhow::Result<RunResult> {
301307
let start = std::time::Instant::now();
302308

@@ -352,10 +358,21 @@ where
352358
let scheduler = Scheduler::new(self.config.offload.max_parallel);
353359
let batches = scheduler.schedule(&tests_to_run);
354360

361+
// Take sandboxes from pool - must match batch count
362+
let sandboxes = sandbox_pool.take_all();
363+
assert_eq!(
364+
sandboxes.len(),
365+
batches.len(),
366+
"sandbox count ({}) must match batch count ({})",
367+
sandboxes.len(),
368+
batches.len()
369+
);
370+
355371
debug!(
356-
"Scheduled {} tests into {} batches",
372+
"Scheduled {} tests into {} batches with {} sandboxes",
357373
tests_to_run.len(),
358-
batches.len()
374+
batches.len(),
375+
sandboxes.len()
359376
);
360377

361378
// Shared JUnit report for accumulating results and early stopping
@@ -366,50 +383,29 @@ where
366383
let all_passed = Arc::new(AtomicBool::new(false));
367384
let cancellation_token = CancellationToken::new();
368385

386+
// Collect sandboxes back after use for termination
387+
let sandboxes_for_cleanup = Arc::new(Mutex::new(Vec::new()));
388+
369389
// Run tests in parallel
370390
// Execute batches concurrently using scoped spawns (no 'static required)
371391
tokio_scoped::scope(|scope| {
372-
for (batch_idx, batch) in batches.into_iter().enumerate() {
373-
let provider = &self.provider;
392+
for (batch_idx, (sandbox, batch)) in sandboxes.into_iter().zip(batches).enumerate() {
374393
let framework = &self.framework;
375394
let config = &self.config;
376395
let progress = &progress;
377396
let verbose = self.verbose;
378397
let junit_report = Arc::clone(&junit_report);
379398
let all_passed = Arc::clone(&all_passed);
380399
let cancellation_token = cancellation_token.clone();
400+
let sandboxes_for_cleanup = Arc::clone(&sandboxes_for_cleanup);
381401

382402
scope.spawn(async move {
383403
// Early exit if all tests have already passed
384404
if all_passed.load(Ordering::SeqCst) {
385405
debug!("Batch {} skipped - all tests have passed", batch_idx);
406+
sandboxes_for_cleanup.lock().await.push(sandbox);
386407
return;
387408
}
388-
// Take sandbox from pool or create new one
389-
let sandbox = {
390-
let existing = sandbox_pool.lock().await.take_one();
391-
if let Some(s) = existing {
392-
s
393-
} else {
394-
let sandbox_config = SandboxConfig {
395-
id: format!("offload-{}-{}", uuid::Uuid::new_v4(), batch_idx),
396-
working_dir: config
397-
.offload
398-
.working_dir
399-
.as_ref()
400-
.map(|p| p.to_string_lossy().to_string()),
401-
env: Vec::new(),
402-
copy_dirs: self.copy_dirs.clone(),
403-
};
404-
match provider.create_sandbox(&sandbox_config).await {
405-
Ok(s) => s,
406-
Err(e) => {
407-
error!("Failed to create sandbox: {}", e);
408-
return;
409-
}
410-
}
411-
}
412-
};
413409

414410
let mut runner = TestRunner::new(
415411
sandbox,
@@ -420,7 +416,7 @@ where
420416
.with_junit_report(Arc::clone(&junit_report));
421417

422418
// Enable output callback only in verbose mode
423-
if config.offload.stream_output && self.verbose {
419+
if config.offload.stream_output && verbose {
424420
let callback: OutputCallback = Arc::new(|test_id, line| match line {
425421
OutputLine::Stdout(s) => println!("[{}] {}", test_id, s),
426422
OutputLine::Stderr(s) => eprintln!("[{}] {}", test_id, s),
@@ -464,9 +460,9 @@ where
464460
// Update progress for completed batch
465461
progress.inc(batch.len() as u64);
466462

467-
// Return sandbox to pool for reuse (don't terminate)
463+
// Collect sandbox for cleanup
468464
let sandbox = runner.into_sandbox();
469-
sandbox_pool.lock().await.add(sandbox);
465+
sandboxes_for_cleanup.lock().await.push(sandbox);
470466
});
471467
}
472468
});
@@ -516,6 +512,15 @@ where
516512
progress.finish_and_clear();
517513
print_summary(&run_result);
518514

515+
// Terminate all sandboxes in parallel (after printing results)
516+
let sandboxes: Vec<_> = sandboxes_for_cleanup.lock().await.drain(..).collect();
517+
let terminate_futures = sandboxes.into_iter().map(|sandbox| async move {
518+
if let Err(e) = sandbox.terminate().await {
519+
warn!("Failed to terminate sandbox {}: {}", sandbox.id(), e);
520+
}
521+
});
522+
futures::future::join_all(terminate_futures).await;
523+
519524
Ok(run_result)
520525
}
521526
}

0 commit comments

Comments
 (0)