Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion airtbench/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ def build_container(
force_rebuild: bool = False,
memory_limit: str = "4g",
) -> str:
docker_client = docker.DockerClient()
try:
docker_client = docker.DockerClient()
except docker.errors.DockerException as e:
raise RuntimeError(
"Docker connection failed: Docker is not running or not accessible",
) from e

docker_file = Path(docker_file)
if not docker_file.exists():
Expand Down
32 changes: 25 additions & 7 deletions airtbench/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import dreadnode as dn
import litellm
import rigging as rg
from dotenv import load_dotenv
from loguru import logger

from airtbench.container import build_container
Expand Down Expand Up @@ -44,7 +45,7 @@
class AIRTBenchArgs:
model: str
"""Model to use for inference"""
platform_api_key: str
platform_api_key: str | None = None
"""Platform API key"""
include_thoughts: bool = False
"""Include thoughts in the reasoning"""
Expand Down Expand Up @@ -691,6 +692,17 @@ async def main(
dn_args: DreadnodeArgs
| None = None, # Has to be None even though not interior fields are required
) -> None:
# Load environment variables from .env file
load_dotenv()

# Set platform_api_key from environment if not provided via command line
if not args.platform_api_key:
args.platform_api_key = os.environ.get("PLATFORM_API_KEY") or os.environ.get("DREADNODE_API_TOKEN")

if not args.platform_api_key:
logger.error("Platform API key is required. Set it via --platform-api-key or PLATFORM_API_KEY environment variable.")
return

dn_args = dn_args or DreadnodeArgs()
dn.configure(
server=dn_args.server,
Expand All @@ -709,12 +721,18 @@ async def main(
logger.info("API key validated successfully")

# Build the container
image = build_container(
"airtbench",
g_container_dir / "Dockerfile",
g_container_dir,
memory_limit=args.memory_limit,
)
try:
image = build_container(
"airtbench",
g_container_dir / "Dockerfile",
g_container_dir,
memory_limit=args.memory_limit,
)
except RuntimeError as e:
if "Docker connection failed" in str(e):
logger.error("Cannot proceed without Docker. Please start Docker and try again.")
return
raise

challenges = load_challenges()

Expand Down
Loading