
Commit 867f294

Author: Magdalena Kotynia (committed)
feat: added vlm bench config to test many models
1 parent 2a59c06 commit 867f294

2 files changed: +31 additions, −0 deletions


src/rai_bench/rai_bench/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 from .test_models import (
     ManipulationO3DEBenchmarkConfig,
     ToolCallingAgentBenchmarkConfig,
+    VLMBenchmarkConfig,
     test_dual_agents,
     test_models,
 )
@@ -28,6 +29,7 @@
 __all__ = [
     "ManipulationO3DEBenchmarkConfig",
     "ToolCallingAgentBenchmarkConfig",
+    "VLMBenchmarkConfig",
     "define_benchmark_logger",
     "get_llm_for_benchmark",
     "parse_manipulation_o3de_benchmark_args",

src/rai_bench/rai_bench/test_models.py

Lines changed: 29 additions & 0 deletions
@@ -23,6 +23,7 @@
 
 import rai_bench.manipulation_o3de as manipulation_o3de
 import rai_bench.tool_calling_agent as tool_calling_agent
+import rai_bench.vlm_benchmark as vlm_benchmark
 from rai_bench.utils import (
     define_benchmark_logger,
     get_llm_for_benchmark,
@@ -77,6 +78,25 @@ def name(self) -> str:
         return "tool_calling_agent"
 
 
+class VLMBenchmarkConfig(BenchmarkConfig):
+    complexities: List[Literal["easy", "medium", "hard"]] = ["easy", "medium", "hard"]
+    task_types: List[
+        Literal[
+            "bool_response_image_task",
+            "quantity_response_image_task",
+            "multiple_choice_image_task",
+        ]
+    ] = [
+        "bool_response_image_task",
+        "quantity_response_image_task",
+        "multiple_choice_image_task",
+    ]
+
+    @property
+    def name(self) -> str:
+        return "vlm"
+
+
 def test_dual_agents(
     multimodal_llms: List[BaseChatModel],
     tool_calling_models: List[BaseChatModel],
@@ -211,6 +231,15 @@ def test_models(
                     experiment_id=experiment_id,
                     bench_logger=bench_logger,
                 )
+
+            elif isinstance(bench_conf, VLMBenchmarkConfig):
+                vlm_tasks = vlm_benchmark.get_spatial_tasks()
+                vlm_benchmark.run_benchmark(
+                    llm=llm,
+                    out_dir=Path(curr_out_dir),
+                    tasks=vlm_tasks,
+                    bench_logger=bench_logger,
+                )
         except Exception as e:
             bench_logger.critical(f"BENCHMARK RUN FAILED: {e}")
             bench_logger.critical(
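For context on the test_models hunk above: the function dispatches on the config type, so opting into the VLM benchmark amounts to passing a VLMBenchmarkConfig instance. The sketch below is hypothetical: the full signature of test_models is not shown in this diff, so the keyword arguments passed to it (model_names, vendors, benchmark_configs, out_dir) and the pydantic-style keyword construction of the config are assumptions; only the complexities and task_types fields come from the class definition added above.

# Hypothetical usage sketch; the test_models parameters below are assumptions
# for illustration, only VLMBenchmarkConfig's own fields come from this diff.
from rai_bench import VLMBenchmarkConfig, test_models

vlm_conf = VLMBenchmarkConfig(
    complexities=["easy", "medium"],          # subset of the default ["easy", "medium", "hard"]
    task_types=["bool_response_image_task"],  # subset of the three task types defined above
)

test_models(
    model_names=["gpt-4o"],            # assumed parameter: models to benchmark
    vendors=["openai"],                # assumed parameter: matching vendor per model
    benchmark_configs=[vlm_conf],      # this config selects the new VLM branch in test_models
    out_dir="out/vlm_bench",           # assumed parameter: where results are written
)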
