Skip to content

Commit e3355ba

Browse files
authored
Merge pull request #105 from imbue-ai/danver/cost-estimate
Implement Cost Estimate function
2 parents 4f8477a + f1a60d4 commit e3355ba

File tree

13 files changed

+524
-24
lines changed

13 files changed

+524
-24
lines changed

.beads/issues.jsonl

Lines changed: 8 additions & 1 deletion
Large diffs are not rendered by default.

.beads/last-touched

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
code-102
1+
code-109

scripts/modal_sandbox.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,17 @@ def run(command: str):
436436
multiple=True,
437437
help="Environment variable (format: KEY=VALUE)",
438438
)
439+
@click.option(
440+
"--cpu",
441+
type=float,
442+
default=None,
443+
help="CPU cores per sandbox",
444+
)
439445
def create_from_image(
440-
image_id: str, copy_dirs: tuple[str, ...] = (), env_vars: tuple[str, ...] = ()
446+
image_id: str,
447+
copy_dirs: tuple[str, ...] = (),
448+
env_vars: tuple[str, ...] = (),
449+
cpu: float | None = None,
441450
):
442451
"""Create sandbox using existing image_id.
443452
@@ -483,13 +492,16 @@ def create_from_image(
483492

484493
logger.debug("[%.2fs] Creating sandbox...", time.time() - t0)
485494
try:
486-
sandbox = modal.Sandbox.create(
495+
create_kwargs = dict(
487496
app=app,
488497
image=image,
489498
workdir="/app",
490499
timeout=3600,
491500
secrets=secrets,
492501
)
502+
if cpu is not None:
503+
create_kwargs["cpu"] = cpu
504+
sandbox = modal.Sandbox.create(**create_kwargs)
493505
except Exception as e:
494506
logger.error("Failed to create sandbox with image %s: %s", image_id, e)
495507
logger.error(

src/config/schema.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ pub struct ModalProviderConfig {
178178
/// making sandbox creation faster.
179179
#[serde(default)]
180180
pub copy_dirs: Vec<String>,
181+
182+
/// CPU cores per sandbox (default: 0.125).
183+
#[serde(default = "default_modal_cpu_cores")]
184+
pub cpu_cores: f64,
181185
}
182186

183187
/// Configuration for custom remote execution provider.
@@ -272,12 +276,24 @@ pub struct DefaultProviderConfig {
272276
/// These are merged with (and override) the current environment.
273277
#[serde(default)]
274278
pub env: HashMap<String, String>,
279+
280+
/// CPU cores per sandbox (default: 1.0).
281+
#[serde(default = "default_cpu_cores")]
282+
pub cpu_cores: f64,
275283
}
276284

277285
fn default_remote_timeout() -> u64 {
278286
3600 // 1 hour
279287
}
280288

289+
fn default_cpu_cores() -> f64 {
290+
1.0
291+
}
292+
293+
fn default_modal_cpu_cores() -> f64 {
294+
0.125
295+
}
296+
281297
/// Configuration for a test group.
282298
///
283299
/// Groups allow segmenting tests for different retry behaviors or filtering.

src/main.rs

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ enum Commands {
7272
/// Emit a Perfetto trace to {output_dir}/trace.json
7373
#[arg(long)]
7474
trace: bool,
75+
76+
/// Show estimated sandbox cost after run.
77+
///
78+
/// Note: This is calculated client-side using simple formulas and
79+
/// may not reflect actual billing, discounts, or pricing adjustments.
80+
#[arg(long)]
81+
show_estimated_cost: bool,
7582
},
7683

7784
/// Discover tests without running them
@@ -139,6 +146,7 @@ async fn main() -> Result<()> {
139146
env_vars,
140147
no_cache,
141148
trace,
149+
show_estimated_cost,
142150
} => {
143151
run_tests(
144152
&cli.config,
@@ -149,6 +157,7 @@ async fn main() -> Result<()> {
149157
no_cache,
150158
cli.verbose,
151159
trace,
160+
show_estimated_cost,
152161
)
153162
.await
154163
}
@@ -247,6 +256,7 @@ async fn dispatch_framework<P: offload::provider::SandboxProvider>(
247256
copy_dirs: &[CopyDir],
248257
verbose: bool,
249258
tracer: &offload::trace::Tracer,
259+
show_estimated_cost: bool,
250260
) -> Result<i32> {
251261
match &config.framework {
252262
FrameworkConfig::Pytest(f_cfg) => {
@@ -258,6 +268,7 @@ async fn dispatch_framework<P: offload::provider::SandboxProvider>(
258268
copy_dirs,
259269
verbose,
260270
tracer,
271+
show_estimated_cost,
261272
)
262273
.await
263274
}
@@ -270,6 +281,7 @@ async fn dispatch_framework<P: offload::provider::SandboxProvider>(
270281
copy_dirs,
271282
verbose,
272283
tracer,
284+
show_estimated_cost,
273285
)
274286
.await
275287
}
@@ -282,6 +294,7 @@ async fn dispatch_framework<P: offload::provider::SandboxProvider>(
282294
copy_dirs,
283295
verbose,
284296
tracer,
297+
show_estimated_cost,
285298
)
286299
.await
287300
}
@@ -294,6 +307,7 @@ async fn dispatch_framework<P: offload::provider::SandboxProvider>(
294307
copy_dirs,
295308
verbose,
296309
tracer,
310+
show_estimated_cost,
297311
)
298312
.await
299313
}
@@ -310,6 +324,7 @@ async fn run_tests(
310324
no_cache: bool,
311325
verbose: bool,
312326
trace: bool,
327+
show_estimated_cost: bool,
313328
) -> Result<()> {
314329
let tracer = if trace {
315330
offload::trace::Tracer::new()
@@ -435,6 +450,7 @@ async fn run_tests(
435450
&copy_dirs,
436451
verbose,
437452
&tracer,
453+
show_estimated_cost,
438454
)
439455
.await?
440456
}
@@ -465,7 +481,16 @@ async fn run_tests(
465481
info!("No tests to run");
466482
return Ok(());
467483
}
468-
dispatch_framework(&config, &all_tests, provider, &copy_dirs, verbose, &tracer).await?
484+
dispatch_framework(
485+
&config,
486+
&all_tests,
487+
provider,
488+
&copy_dirs,
489+
verbose,
490+
&tracer,
491+
show_estimated_cost,
492+
)
493+
.await?
469494
}
470495
ProviderConfig::Modal(p_cfg) => {
471496
// Run discovery and image preparation concurrently
@@ -494,7 +519,16 @@ async fn run_tests(
494519
info!("No tests to run");
495520
return Ok(());
496521
}
497-
dispatch_framework(&config, &all_tests, provider, &copy_dirs, verbose, &tracer).await?
522+
dispatch_framework(
523+
&config,
524+
&all_tests,
525+
provider,
526+
&copy_dirs,
527+
verbose,
528+
&tracer,
529+
show_estimated_cost,
530+
)
531+
.await?
498532
}
499533
};
500534

@@ -515,6 +549,7 @@ async fn run_tests(
515549

516550
/// Run all tests with a single orchestrator call.
517551
/// Returns the exit code (0 = success, 1 = failures/not run, 2 = flaky only).
552+
#[allow(clippy::too_many_arguments)]
518553
async fn run_all_tests<P, D>(
519554
config: &config::Config,
520555
tests: &[TestRecord],
@@ -523,6 +558,7 @@ async fn run_all_tests<P, D>(
523558
copy_dirs: &[CopyDir],
524559
verbose: bool,
525560
tracer: &offload::trace::Tracer,
561+
show_estimated_cost: bool,
526562
) -> Result<i32>
527563
where
528564
P: offload::provider::SandboxProvider,
@@ -565,7 +601,13 @@ where
565601
.context("Failed to create sandboxes")?;
566602
drop(_pool_span);
567603

568-
let orchestrator = Orchestrator::new(config.clone(), framework, verbose, tracer.clone());
604+
let orchestrator = Orchestrator::new(
605+
config.clone(),
606+
framework,
607+
verbose,
608+
tracer.clone(),
609+
show_estimated_cost,
610+
);
569611

570612
let result = orchestrator.run_with_tests(tests, sandbox_pool).await?;
571613

@@ -663,6 +705,7 @@ fn init_config(provider: &str, framework: &str) -> Result<()> {
663705
timeout_secs: 3600,
664706
copy_dirs: vec![],
665707
env: HashMap::new(),
708+
cpu_cores: 1.0,
666709
}),
667710
_ => {
668711
eprintln!("Unknown provider: {}. Use: local, default", provider);

src/orchestrator.rs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use tracing::{debug, error, info, warn};
1414

1515
use crate::config::Config;
1616
use crate::framework::{TestFramework, TestInstance, TestRecord};
17-
use crate::provider::Sandbox;
17+
use crate::provider::{CostEstimate, Sandbox};
1818
use crate::report::{MasterJunitReport, load_test_durations, print_summary};
1919

2020
pub use pool::SandboxPool;
@@ -66,6 +66,9 @@ pub struct RunResult {
6666

6767
/// Wall-clock duration of the entire test run.
6868
pub duration: Duration,
69+
70+
/// Estimated cost of the test run (aggregated from all sandboxes).
71+
pub estimated_cost: CostEstimate,
6972
}
7073

7174
impl RunResult {
@@ -109,6 +112,7 @@ pub struct Orchestrator<S, D> {
109112
framework: D,
110113
verbose: bool,
111114
tracer: crate::trace::Tracer,
115+
show_cost: bool,
112116
_sandbox: std::marker::PhantomData<S>,
113117
}
114118

@@ -125,12 +129,20 @@ where
125129
/// * `framework` - Test framework for running tests
126130
/// * `verbose` - Whether to show verbose output (streaming test output)
127131
/// * `tracer` - Performance tracer for emitting trace events
128-
pub fn new(config: Config, framework: D, verbose: bool, tracer: crate::trace::Tracer) -> Self {
132+
/// * `show_cost` - Whether to display cost estimate in summary
133+
pub fn new(
134+
config: Config,
135+
framework: D,
136+
verbose: bool,
137+
tracer: crate::trace::Tracer,
138+
show_cost: bool,
139+
) -> Self {
129140
Self {
130141
config,
131142
framework,
132143
verbose,
133144
tracer,
145+
show_cost,
134146
_sandbox: std::marker::PhantomData,
135147
}
136148
}
@@ -199,6 +211,7 @@ where
199211
flaky: 0,
200212
not_run: 0,
201213
duration: start.elapsed(),
214+
estimated_cost: CostEstimate::default(),
202215
});
203216
}
204217

@@ -413,20 +426,22 @@ where
413426

414427
// Use the JUnit total as the authoritative count (passed + failed + flaky = total)
415428
// This ensures passed can never exceed total
429+
// Note: estimated_cost is set to default here and updated after sandbox cleanup
416430
let run_result = RunResult {
417431
total_tests: total_in_junit,
418432
passed: passed + flaky_count, // Flaky tests count as passed
419433
failed,
420434
flaky: flaky_count,
421435
not_run,
422436
duration: start.elapsed(),
437+
estimated_cost: CostEstimate::default(),
423438
};
424439
drop(_agg_span);
425440

426441
progress.finish_and_clear();
427-
print_summary(&run_result);
428442

429443
// Terminate all sandboxes in parallel (after printing results)
444+
// Aggregate cost estimates BEFORE terminating (cost_estimate uses elapsed time)
430445
let _cleanup_span = self.tracer.span(
431446
"sandbox_cleanup",
432447
"orchestrator",
@@ -440,6 +455,17 @@ where
440455
Vec::new()
441456
}
442457
};
458+
459+
// Aggregate cost estimates before terminating sandboxes
460+
let estimated_cost = sandboxes
461+
.iter()
462+
.fold(CostEstimate::default(), |mut acc, sb| {
463+
let cost = sb.cost_estimate();
464+
acc.cpu_seconds += cost.cpu_seconds;
465+
acc.estimated_cost_usd += cost.estimated_cost_usd;
466+
acc
467+
});
468+
443469
let terminate_futures = sandboxes.into_iter().map(|sandbox| async move {
444470
if let Err(e) = sandbox.terminate().await {
445471
warn!("Failed to terminate sandbox {}: {}", sandbox.id(), e);
@@ -448,6 +474,14 @@ where
448474
futures::future::join_all(terminate_futures).await;
449475
drop(_cleanup_span);
450476

477+
// Update run_result with estimated_cost
478+
let run_result = RunResult {
479+
estimated_cost,
480+
..run_result
481+
};
482+
483+
print_summary(&run_result, self.show_cost);
484+
451485
Ok(run_result)
452486
}
453487
}

src/orchestrator/pool.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ impl<S: Sandbox> Default for SandboxPool<S> {
7070
#[cfg(test)]
7171
mod tests {
7272
use super::*;
73-
use crate::provider::{OutputStream, ProviderResult};
73+
use crate::provider::{CostEstimate, OutputStream, ProviderResult};
7474
use async_trait::async_trait;
7575
use std::path::Path;
7676

@@ -95,6 +95,9 @@ mod tests {
9595
async fn terminate(&self) -> ProviderResult<()> {
9696
Ok(())
9797
}
98+
fn cost_estimate(&self) -> CostEstimate {
99+
CostEstimate::default()
100+
}
98101
}
99102

100103
struct FakeProvider;

0 commit comments

Comments
 (0)