Skip to content

Commit 84034f3

Browse files
committed
allow for custom prompts
1 parent b30591a commit 84034f3

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

src/wraval/actions/action_llm_judge.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dynaconf import Dynaconf
88
from .data_utils import write_dataset_local, write_dataset_to_s3, load_latest_dataset
99
from .prompts_judge import generate_input_prompt, generate_system_prompt, get_rubric, rewrite_prompt
10+
1011
from .completion import batch_get_bedrock_completions
1112
import re
1213
import boto3
@@ -58,6 +59,10 @@ def process_tone_data(
5859
Returns:
5960
Processed DataFrame with scores
6061
"""
62+
63+
if settings.custom_prompts == True:
64+
from wraval.custom_prompts.prompts_judge import generate_input_prompt, generate_system_prompt
65+
6166
dmt = d[d.tone == tone].copy()
6267
rubrics = list(tone_rubrics.keys())
6368

@@ -111,6 +116,10 @@ def judge(
111116
data_dir: Directory containing input data
112117
endpoint_type: Type of endpoint to use
113118
"""
119+
120+
if settings.custom_prompts == True:
121+
from wraval.custom_prompts.prompts_judge import get_rubric
122+
114123
try:
115124
d = load_latest_dataset(settings.data_dir)
116125
print(f"Loaded dataset with {len(d)} rows")

src/wraval/main.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def get_settings(args):
3232
## add the AWS account you are logged into, if necessary.
3333
settings.model = settings.model.format(aws_account=settings.aws_account)
3434
settings.data_dir = settings.data_dir.format(aws_account=settings.aws_account)
35+
36+
if args.custom_prompts:
37+
settings.custom_prompts = True
38+
else:
39+
settings.custom_prompts = False
3540
return settings
3641

3742

@@ -74,6 +79,11 @@ def parse_args() -> argparse.Namespace:
7479
required=False,
7580
help="Allow for a local path to a tokenizer.",
7681
)
82+
83+
parser.add_argument(
84+
"--custom-prompts", default=False, help="Load custom prompts from a prompt folder"
85+
)
86+
7787
return parser.parse_args()
7888

7989

0 commit comments

Comments
 (0)