Skip to content

Commit 5b3ebba

Browse files
[python/knowpro] Various improvements to tools (#1426)
* 595de6d vizcmp.py: Get scores from Score: headers instead of file trailer * fc5dcb6 utools.py: Always show output diff even when success differs * 8147b7f utool.py: assign score 0.001 if we have an unexpected answer * fa4053a utool.py: Add support for Pydantic's Logfire * fca0f57 vizcmp.py: Simplify footer * e29246e vizcmp.py: Sort by average score, not first file's * f0e7ebf vizcmp.py: Only sort filenames when using default glob * d2a6be4 vizcmp.py: Print only basename * 6b564a2 vizcmp.py: show file names in footer * 4e72780 vizcmp.py: Display N/A results in bright yellow
1 parent a765744 commit 5b3ebba

File tree

2 files changed

+62
-32
lines changed

2 files changed

+62
-32
lines changed

python/ta/tools/utool.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,33 @@
4848
from typeagent.podcasts import podcast
4949

5050

51+
### Logfire setup ###
52+
53+
54+
def setup_logfire():
55+
import logfire
56+
57+
def scrubbing_callback(m: logfire.ScrubMatch):
58+
# if m.path == ('attributes', 'http.request.header.authorization'):
59+
# return m.value
60+
61+
# if m.path == ('attributes', 'http.request.header.api-key'):
62+
# return m.value
63+
64+
if (
65+
m.path == ("attributes", "http.request.body.text", "messages", 0, "content")
66+
and m.pattern_match.group(0) == "secret"
67+
):
68+
return m.value
69+
70+
# if m.path == ('attributes', 'http.response.header.azureml-model-session'):
71+
# return m.value
72+
73+
logfire.configure(scrubbing=logfire.ScrubbingOptions(callback=scrubbing_callback))
74+
logfire.instrument_pydantic_ai()
75+
logfire.instrument_httpx(capture_all=True)
76+
77+
5178
### Classes ###
5279

5380

@@ -116,6 +143,8 @@ def main():
116143
parser = make_arg_parser("TypeAgent Query Tool")
117144
args = parser.parse_args()
118145
fill_in_debug_defaults(parser, args)
146+
if args.logfire:
147+
setup_logfire()
119148
settings = importing.ConversationSettings()
120149
query_context = load_podcast_index(args.podcast, settings)
121150
ar_list, ar_index = load_index_file(args.qafile, "question", QuestionAnswerData)
@@ -488,6 +517,11 @@ def make_arg_parser(description: str) -> argparse.ArgumentParser:
488517
action="store_true",
489518
help="Show the TypeScript schema computed by typechat.",
490519
)
520+
debug.add_argument(
521+
"--logfire",
522+
action="store_true",
523+
help="Upload log events to Pydantic's Logfire server",
524+
)
491525

492526
return parser
493527

@@ -744,14 +778,16 @@ async def compare_answers(
744778
actual_text, actual_success = actual
745779

746780
if expected_success != actual_success:
747-
print(f"Expected success: {expected_success}; actual: {actual_success}")
748-
return 0.000
781+
print(
782+
f"Expected success: {Fore.RED}{expected_success}{Fore.RESET}; "
783+
f"actual: {Fore.GREEN}{actual_success}{Fore.RESET}"
784+
)
749785

750-
if not actual_success:
786+
elif not actual_success:
751787
print(Fore.GREEN + f"Both failed" + Fore.RESET)
752788
return 1.001
753789

754-
if expected_text == actual_text:
790+
elif expected_text == actual_text:
755791
print(Fore.GREEN + f"Both equal" + Fore.RESET)
756792
return 1.000
757793

@@ -760,7 +796,11 @@ async def compare_answers(
760796
else:
761797
n = 2
762798
print_diff(expected_text, actual_text, n=n)
763-
return await equality_score(context, expected_text, actual_text)
799+
800+
if expected_success != actual_success:
801+
return 0.000 if expected_success else 0.001 # 0.001 == Answer not expected
802+
else:
803+
return await equality_score(context, expected_text, actual_text)
764804

765805

766806
def print_diff(a: str, b: str, n: int) -> None:

python/ta/tools/vizcmp.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
def main():
14-
files = sys.argv[1:] or glob.glob("evals/eval-*.txt")
14+
files = sys.argv[1:] or sorted(glob.glob("evals/eval-*.txt"))
1515
table = {} # {file: {counter: score, ...}, ...}
1616
questions = {} # {counter: question, ...}
1717

@@ -20,56 +20,41 @@ def main():
2020
with open(file, "r") as f:
2121
lines = f.readlines()
2222

23+
scores = {}
2324
counter = None
2425
for i, line in enumerate(lines):
2526
if m := re.match(r"^(?:-+|\*+)\s+(\d+)\s+", line):
2627
counter = int(m.group(1))
27-
elif m := re.match(r"^.*; Question:\s+(.*)$", line):
28-
question = m.group(1)
28+
elif m := re.match(r"^Score:\s+([\d.]+); Question:\s+(.*)$", line):
29+
score = float(m.group(1))
30+
scores[counter] = score
31+
question = m.group(2)
2932
if counter not in questions:
3033
questions[counter] = question
3134
elif questions[counter] != question:
3235
print(f"File {file} has a different question for {counter}:")
3336
print(f"< {questions[counter]}")
3437
print(f"> {question}")
3538

36-
i = lines.index("==================================================\n")
37-
if i < 0:
38-
print(f"File {file} does not contain a separator line")
39-
continue
40-
lines = lines[i + 1 :]
41-
text = "".join(lines)
42-
matches = re.findall(r"\d\.\d\d\d\(\d+\)", text)
43-
if not matches:
44-
print(f"File {file} does not contain any scores")
45-
continue
46-
# print(len(matches), matches)
47-
data = {}
48-
for match in matches:
49-
m = re.match(r"(\d\.\d\d\d)\((\d+)\)", match)
50-
assert m
51-
score = float(m.group(1))
52-
counter = int(m.group(2))
53-
data[counter] = score
54-
assert len(data) == len(matches)
55-
table[file] = data
39+
table[file] = scores
5640

5741
# Print header
58-
all_files = sorted(table.keys())
42+
all_files = list(table.keys())
5943
print_header(all_files)
6044

6145
# Print data
6246
all_counters = sorted(
6347
{counter for data in table.values() for counter in data.keys()},
64-
key=lambda x: table[all_files[0]].get(x, 0.0),
48+
key=lambda x: statistics.mean(table[file].get(x, 0.0) for file in all_files),
6549
reverse=True,
6650
)
6751
for counter in all_counters:
6852
print(f"{counter:>3}:", end="")
6953
for file in all_files:
7054
score = table[file].get(counter, None)
7155
if score is None:
72-
output = " N/A "
56+
output = Fore.YELLOW + " N/A " + Fore.RESET
57+
output = Style.BRIGHT + output + Style.RESET_ALL
7358
else:
7459
output = f"{score:.3f}"
7560
output = f"{output:>6}"
@@ -87,7 +72,7 @@ def main():
8772
print(f" {questions.get(counter)}")
8873

8974
# Print header again
90-
print_header(all_files)
75+
print_footer(all_files)
9176

9277

9378
def print_header(all_files):
@@ -103,5 +88,10 @@ def print_header(all_files):
10388
print()
10489

10590

91+
def print_footer(all_files):
92+
for i, file in reversed(list(enumerate(all_files))):
93+
print(" |" * i + " " + os.path.basename(file))
94+
95+
10696
if __name__ == "__main__":
10797
main()

0 commit comments

Comments
 (0)