|
11 | 11 | /// and defines methods on them. |
12 | 12 | use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; |
13 | 13 |
|
14 | | -use evaluations::{EvaluationCoreArgs, EvaluationVariant, run_evaluation_core_streaming}; |
| 14 | +use evaluations::{ |
| 15 | + EvaluationCoreArgs, EvaluationFunctionConfig, EvaluationFunctionConfigTable, EvaluationVariant, |
| 16 | + run_evaluation_core_streaming, |
| 17 | +}; |
15 | 18 | use futures::StreamExt; |
16 | 19 | use pyo3::{ |
17 | 20 | IntoPyObjectExt, |
@@ -43,8 +46,8 @@ use tensorzero_core::{ |
43 | 46 | OptimizationJobInfoPyClass, OptimizationJobStatus, UninitializedOptimizerInfo, |
44 | 47 | dicl::UninitializedDiclOptimizationConfig, fireworks_sft::UninitializedFireworksSFTConfig, |
45 | 48 | gcp_vertex_gemini_sft::UninitializedGCPVertexGeminiSFTConfig, |
46 | | - openai_rft::UninitializedOpenAIRFTConfig, openai_sft::UninitializedOpenAISFTConfig, |
47 | | - together_sft::UninitializedTogetherSFTConfig, |
| 49 | + gepa::UninitializedGEPAConfig, openai_rft::UninitializedOpenAIRFTConfig, |
| 50 | + openai_sft::UninitializedOpenAISFTConfig, together_sft::UninitializedTogetherSFTConfig, |
48 | 51 | }, |
49 | 52 | tool::ProviderTool, |
50 | 53 | variant::{ |
@@ -107,6 +110,7 @@ fn tensorzero(m: &Bound<'_, PyModule>) -> PyResult<()> { |
107 | 110 | m.add_class::<UninitializedFireworksSFTConfig>()?; |
108 | 111 | m.add_class::<UninitializedDiclOptimizationConfig>()?; |
109 | 112 | m.add_class::<UninitializedGCPVertexGeminiSFTConfig>()?; |
| 113 | + m.add_class::<UninitializedGEPAConfig>()?; |
110 | 114 | m.add_class::<UninitializedTogetherSFTConfig>()?; |
111 | 115 | m.add_class::<Datapoint>()?; |
112 | 116 | m.add_class::<ResolvedInput>()?; |
@@ -1431,10 +1435,32 @@ impl TensorZeroGateway { |
1431 | 1435 | }) |
1432 | 1436 | .transpose()?; |
1433 | 1437 |
|
| 1438 | + // Extract evaluation config from app_state |
| 1439 | + let evaluation_config = app_state |
| 1440 | + .config |
| 1441 | + .evaluations |
| 1442 | + .get(&evaluation_name) |
| 1443 | + .ok_or_else(|| { |
| 1444 | + pyo3::exceptions::PyValueError::new_err(format!( |
| 1445 | + "evaluation '{evaluation_name}' not found" |
| 1446 | + )) |
| 1447 | + })? |
| 1448 | + .clone(); |
| 1449 | + |
| 1450 | + // Build function configs table from all functions in the config |
| 1451 | + let function_configs: EvaluationFunctionConfigTable = app_state |
| 1452 | + .config |
| 1453 | + .functions |
| 1454 | + .iter() |
| 1455 | + .map(|(name, func)| (name.clone(), EvaluationFunctionConfig::from(func.as_ref()))) |
| 1456 | + .collect(); |
| 1457 | + let function_configs = Arc::new(function_configs); |
| 1458 | + |
1434 | 1459 | let core_args = EvaluationCoreArgs { |
1435 | 1460 | tensorzero_client: client.clone(), |
1436 | 1461 | clickhouse_client: app_state.clickhouse_connection_info.clone(), |
1437 | | - config: app_state.config.clone(), |
| 1462 | + evaluation_config, |
| 1463 | + function_configs, |
1438 | 1464 | evaluation_name, |
1439 | 1465 | evaluation_run_id, |
1440 | 1466 | dataset_name, |
@@ -2646,10 +2672,32 @@ impl AsyncTensorZeroGateway { |
2646 | 2672 |
|
2647 | 2673 | let evaluation_run_id = uuid::Uuid::now_v7(); |
2648 | 2674 |
|
| 2675 | + // Extract evaluation config from app_state |
| 2676 | + let evaluation_config = app_state |
| 2677 | + .config |
| 2678 | + .evaluations |
| 2679 | + .get(&evaluation_name) |
| 2680 | + .ok_or_else(|| { |
| 2681 | + pyo3::exceptions::PyValueError::new_err(format!( |
| 2682 | + "evaluation '{evaluation_name}' not found" |
| 2683 | + )) |
| 2684 | + })? |
| 2685 | + .clone(); |
| 2686 | + |
| 2687 | + // Build function configs table from all functions in the config |
| 2688 | + let function_configs: EvaluationFunctionConfigTable = app_state |
| 2689 | + .config |
| 2690 | + .functions |
| 2691 | + .iter() |
| 2692 | + .map(|(name, func)| (name.clone(), EvaluationFunctionConfig::from(func.as_ref()))) |
| 2693 | + .collect(); |
| 2694 | + let function_configs = Arc::new(function_configs); |
| 2695 | + |
2649 | 2696 | let core_args = EvaluationCoreArgs { |
2650 | 2697 | tensorzero_client: client.clone(), |
2651 | 2698 | clickhouse_client: app_state.clickhouse_connection_info.clone(), |
2652 | | - config: app_state.config.clone(), |
| 2699 | + evaluation_config, |
| 2700 | + function_configs, |
2653 | 2701 | evaluation_name, |
2654 | 2702 | evaluation_run_id, |
2655 | 2703 | dataset_name, |
|
0 commit comments