11 changes: 10 additions & 1 deletion pyproject.toml
@@ -51,4 +51,13 @@ dev = [

[tool.setuptools.packages.find]
where = ["src"]
include = ["kernelbench*"]
include = ["kernelbench*"]

[tool.setuptools]
include-package-data = true

[tool.setuptools.package-data]
kernelbench = [
"prompts/*.toml",
"prompts/**/*.py",
]
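
For reference, a minimal sketch (not part of this PR) of how data files declared under [tool.setuptools.package-data] can be read back through importlib.resources once the package is installed; the filenames mirror the globs above:

import importlib.resources
import tomli  # the project already parses TOML with tomli

# files() returns a Traversable rooted at the installed kernelbench package,
# so packaged data can be opened without assuming a source-checkout layout.
prompts_dir = importlib.resources.files("kernelbench") / "prompts"
with (prompts_dir / "prompts.toml").open("rb") as f:  # tomli requires binary mode
    prompts = tomli.load(f)
print(sorted(prompts))
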
24 changes: 14 additions & 10 deletions src/kernelbench/prompt_constructor_toml.py
@@ -1,5 +1,7 @@
# src/prompt_constructor_toml.py | toml based prompt constructor
import importlib.resources
import os
from pathlib import Path
import runpy
import tomli
from dataclasses import dataclass
@@ -15,10 +17,12 @@
"""

REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
PROMPTS_TOML = os.path.join(REPO_TOP_PATH, "src/kernelbench/prompts/prompts.toml")
PROMPTS_DIR = importlib.resources.files("kernelbench") / "prompts"
assert PROMPTS_DIR.is_dir(), f"Prompts directory not found at {PROMPTS_DIR}"

assert os.path.exists(PROMPTS_TOML), f"Prompts.toml not found at {PROMPTS_TOML}"
GPU_SPECS_PY = "src/kernelbench/prompts/hardware/gpu_specs.py"

GPU_SPECS_PY = PROMPTS_DIR / "hardware" / "gpu_specs.py"
assert GPU_SPECS_PY.is_file(), f"GPU specs file not found at {GPU_SPECS_PY}"
HARDWARE_COMPONENT_KEYS = [
"hardware_header",
"hardware_specs",
@@ -50,7 +54,7 @@ class PromptConfig:
data: Dict[str, Any] # The raw parsed TOML data as nested dictionaries

@classmethod
def from_toml(cls, path: str) -> "PromptConfig":
def from_toml(cls, path: str | Path) -> "PromptConfig":
"""
Load and parse a TOML configuration file.

@@ -134,11 +138,11 @@ def _gpu_context_from_gpu_specs(py_path: str, gpu_name: str) -> Dict[str, str]:

def render_prompt_by_option(
*,
prompts_toml: str,
prompts_toml: Path,
backend: str,
option: str,
context: Dict[str, str],
gpu_specs_py: Optional[str] = None,
gpu_specs_py: Optional[Path] = None,
gpu_name: Optional[str] = None,
precision: Optional[str] = None,
include_hardware: bool = False,
@@ -292,11 +296,11 @@ def render_example_entry(input_code: str, output_code: str, example_label: str)

# Load GPU details if requested
if option_data.get("requires_gpu") or include_hardware:
if not (gpu_specs_py and gpu_name):
if not (gpu_specs_py and gpu_name and gpu_specs_py.is_file()):
raise ValueError(
f"Hardware info requested for option '{option}'; provide gpu_specs_py and gpu_name"
)
context = {**context, **_gpu_context_from_gpu_specs(_abs_path(gpu_specs_py), gpu_name)}
context = {**context, **_gpu_context_from_gpu_specs(gpu_specs_py, gpu_name)}

# Builds the prompt from the components in the toml file.
prompt_parts = []
@@ -347,7 +351,7 @@ def get_prompt_for_backend(
gpu_name: GPU identifier used when include_hardware is True (e.g., "A100")
"""
return render_prompt_by_option(
prompts_toml=PROMPTS_TOML,
prompts_toml=PROMPTS_DIR / "prompts.toml",
backend=backend.lower(),
option=option.lower(),
context={"ref_arch_src": ref_arch_src},
@@ -367,7 +371,7 @@ def get_custom_prompt(
precision: Optional[str] = None,
include_hardware: bool = False,
gpu_name: Optional[str] = None,
prompts_toml: str = PROMPTS_TOML,
prompts_toml: Path = PROMPTS_DIR / "prompts.toml",
) -> str:
"""
Render a prompt defined under [custom_prompts.<custom_key>] in prompts.toml.
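
As a usage note, a hedged sketch of how the packaged gpu_specs.py could be resolved to a real filesystem path and executed with runpy (which the module imports); as_file and the variable names here are illustrative assumptions, not code from this PR:

import runpy
from importlib.resources import files, as_file

gpu_specs = files("kernelbench") / "prompts" / "hardware" / "gpu_specs.py"
# as_file yields a concrete on-disk path even for zip-based installs,
# which runpy.run_path requires.
with as_file(gpu_specs) as path:
    spec_globals = runpy.run_path(str(path))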