
import rai_bench.manipulation_o3de as manipulation_o3de
import rai_bench.tool_calling_agent as tool_calling_agent
+import rai_bench.vlm_benchmark as vlm_benchmark
from rai_bench.utils import (
    define_benchmark_logger,
    get_llm_for_benchmark,
@@ -77,6 +78,25 @@ def name(self) -> str:
        return "tool_calling_agent"


+class VLMBenchmarkConfig(BenchmarkConfig):
+    complexities: List[Literal["easy", "medium", "hard"]] = ["easy", "medium", "hard"]
+    task_types: List[
+        Literal[
+            "bool_response_image_task",
+            "quantity_response_image_task",
+            "multiple_choice_image_task",
+        ]
+    ] = [
+        "bool_response_image_task",
+        "quantity_response_image_task",
+        "multiple_choice_image_task",
+    ]
+
+    @property
+    def name(self) -> str:
+        return "vlm"
+
+
def test_dual_agents(
    multimodal_llms: List[BaseChatModel],
    tool_calling_models: List[BaseChatModel],
@@ -211,6 +231,15 @@ def test_models(
                    experiment_id=experiment_id,
                    bench_logger=bench_logger,
                )
+
+            elif isinstance(bench_conf, VLMBenchmarkConfig):
+                vlm_tasks = vlm_benchmark.get_spatial_tasks()
+                vlm_benchmark.run_benchmark(
+                    llm=llm,
+                    out_dir=Path(curr_out_dir),
+                    tasks=vlm_tasks,
+                    bench_logger=bench_logger,
+                )
        except Exception as e:
            bench_logger.critical(f"BENCHMARK RUN FAILED: {e}")
            bench_logger.critical(
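Taken together, these hunks register the VLM benchmark alongside the existing manipulation and tool-calling paths: a `VLMBenchmarkConfig` selects the benchmark by `name`, and the new `elif` branch fetches the spatial tasks and hands them to `vlm_benchmark.run_benchmark`. Below is a minimal standalone sketch of the same call path. It is an illustration only: the module path for `VLMBenchmarkConfig`, the keyword arguments of `get_llm_for_benchmark` and `define_benchmark_logger`, and the model/vendor choice are assumptions, not taken from this diff.

```python
from pathlib import Path

import rai_bench.vlm_benchmark as vlm_benchmark
from rai_bench.test_models import VLMBenchmarkConfig  # module path is an assumption
from rai_bench.utils import define_benchmark_logger, get_llm_for_benchmark

# Default config: all complexities and all three spatial task types.
# Note that the branch in the diff always calls get_spatial_tasks() without
# arguments, so these fields are not used for filtering in the hunk shown.
config = VLMBenchmarkConfig()

out_dir = Path("out") / config.name  # config.name == "vlm"; path is a placeholder
bench_logger = define_benchmark_logger(out_dir=out_dir)  # keyword name assumed

# Model name and vendor are placeholders; any multimodal chat model should work.
llm = get_llm_for_benchmark(model_name="gpt-4o", vendor="openai")  # signature assumed

vlm_benchmark.run_benchmark(
    llm=llm,
    out_dir=out_dir,
    tasks=vlm_benchmark.get_spatial_tasks(),
    bench_logger=bench_logger,
)
```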