diff --git a/airtbench/container.py b/airtbench/container.py index a151f07..38bebe6 100644 --- a/airtbench/container.py +++ b/airtbench/container.py @@ -13,7 +13,12 @@ def build_container( force_rebuild: bool = False, memory_limit: str = "4g", ) -> str: - docker_client = docker.DockerClient() + try: + docker_client = docker.DockerClient() + except docker.errors.DockerException as e: + raise RuntimeError( + "Docker connection failed: Docker is not running or not accessible", + ) from e docker_file = Path(docker_file) if not docker_file.exists(): diff --git a/airtbench/main.py b/airtbench/main.py index d7e6f30..0f2f047 100644 --- a/airtbench/main.py +++ b/airtbench/main.py @@ -13,6 +13,7 @@ import dreadnode as dn import litellm import rigging as rg +from dotenv import load_dotenv from loguru import logger from airtbench.container import build_container @@ -44,7 +45,7 @@ class AIRTBenchArgs: model: str """Model to use for inference""" - platform_api_key: str + platform_api_key: str | None = None """Platform API key""" include_thoughts: bool = False """Include thoughts in the reasoning""" @@ -691,6 +692,17 @@ async def main( dn_args: DreadnodeArgs | None = None, # Has to be None even though not interior fields are required ) -> None: + # Load environment variables from .env file + load_dotenv() + + # Set platform_api_key from environment if not provided via command line + if not args.platform_api_key: + args.platform_api_key = os.environ.get("PLATFORM_API_KEY") or os.environ.get("DREADNODE_API_TOKEN") + + if not args.platform_api_key: + logger.error("Platform API key is required. Set it via --platform-api-key or PLATFORM_API_KEY environment variable.") + return + dn_args = dn_args or DreadnodeArgs() dn.configure( server=dn_args.server, @@ -709,12 +721,18 @@ async def main( logger.info("API key validated successfully") # Build the container - image = build_container( - "airtbench", - g_container_dir / "Dockerfile", - g_container_dir, - memory_limit=args.memory_limit, - ) + try: + image = build_container( + "airtbench", + g_container_dir / "Dockerfile", + g_container_dir, + memory_limit=args.memory_limit, + ) + except RuntimeError as e: + if "Docker connection failed" in str(e): + logger.error("Cannot proceed without Docker. Please start Docker and try again.") + return + raise challenges = load_challenges()