Commit 655d19d

added rng to all chat_completion functions and relevant tests

1 parent 46ca895 commit 655d19d

8 files changed: +796 -501 lines changed

README.md

Lines changed: 12 additions & 1 deletion
@@ -95,7 +95,7 @@ tabmemcheck.run_all_tests("adult-test.csv", "gpt-4-0613")
 
 # How do the tests work?
 
-We use few-shot learning to condition chat models on the task of regurgitating their training data. This works well for GPT-3.5 and GPT-4, and also for many other LLMs (but not necessarily for all LLMs).
+We use few-shot learning to condition chat models on the desired task. This works well for GPT-3.5 and GPT-4, and also for many other LLMs (but not necessarily for all LLMs).
 
 You can set ```tabmemcheck.config.print_prompts = True``` to see the prompts.
 
@@ -114,6 +114,17 @@ Because one needs to weigh the completions of the LLM against the entropy in th
 
 While this all sounds very complex, the practical evidence for memorization is often very clear. This can also be seen in the examples above.
 
+
+# Can I use this package to write my own tests?
+
+This package provides two fairly general functions:
+
+- ```tabmemcheck.chat_completion```
+- ```tabmemcheck.prefix_suffix_chat_completion```
+
+
+
 # Using the package with your own LLM
 
 To test your own LLM, simply implement ```tabmemcheck.LLM_Interface```. We use the OpenAI message format.
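For illustration, a custom test built on ```tabmemcheck.prefix_suffix_chat_completion``` might look roughly like the sketch below. This is not part of the commit: the system prompt and synthetic rows are made up, and `llm` is a placeholder for any ```tabmemcheck.LLM_Interface``` implementation.

```python
import numpy as np
import tabmemcheck

llm = ...  # placeholder: any tabmemcheck.LLM_Interface implementation

# split some synthetic csv rows into prefixes (first two fields) and suffixes (last field)
rows = [f"{i}, feature_{i % 7}, {i * i}" for i in range(50)]
prefixes = [", ".join(row.split(", ")[:2]) + ", " for row in rows]
suffixes = [row.split(", ")[2] for row in rows]

test_prefixes, test_suffixes, responses = tabmemcheck.prefix_suffix_chat_completion(
    llm,
    prefixes,
    suffixes,
    system_prompt="You complete partial csv rows.",  # made-up prompt
    few_shot=5,  # draw 5 few-shot examples from the same lists
    num_queries=10,
    rng=np.random.default_rng(42),  # the new rng argument: reproducible sampling
)
```

A custom test would then compare `responses` against `test_suffixes`, for example with a Levenshtein distance, as the built-in tests do.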

examples/MLE-bench-contamination.ipynb

Lines changed: 652 additions & 61 deletions
Large diffs are not rendered by default.

tabmemcheck/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ def __delattr__(self, key):
 # csv file loading options
 config.csv_max_rows = 100000  # maximum number of rows to load from a csv file
 
+# how to display test output. "cmd" or "html" for jupyter notebook html display
+config.display = "cmd"
+
 # default: no prompt/response logging
 config.current_logging_task = None
 config.current_logging_folder = None
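With the new ```config.display``` option, a notebook user can switch the display mode before running tests. A one-line sketch, assuming the package-level ```tabmemcheck.config``` object shown in the README:

```python
import tabmemcheck

# render test results as html in a jupyter notebook instead of terminal colors
tabmemcheck.config.display = "html"
```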

tabmemcheck/chat_completion.py

Lines changed: 31 additions & 11 deletions
@@ -35,6 +35,7 @@ def feature_values_chat_completion(
     fs_cond_feature_names=[],  # a list of lists of conditional feature names for each few-shot example
     add_description=True,
     out_file=None,
+    rng=None,
 ):
     """Feature chat completion task. This task asks the LLM to complete the feature values of observations in the dataset.
 
@@ -124,6 +125,7 @@ def feature_values_chat_completion(
         few_shot=few_shot_prefixes_suffixes,
         num_queries=num_queries,
         out_file=out_file,
+        rng=rng,
     )
 
     return test_prefixes, test_suffixes, responses
@@ -145,6 +147,7 @@ def row_chat_completion(
     few_shot=7,
     out_file=None,
     print_levenshtein=False,
+    rng=None,
 ):
     """Row chat completion task. This task asks the LLM to predict the next row in the
     csv file, given the previous rows. This task is the basis for the row completion
@@ -171,6 +174,7 @@ def row_chat_completion(
         num_queries=num_queries,
         out_file=out_file,
         print_levenshtein=print_levenshtein,
+        rng=rng,
     )
 
     return test_prefixes, test_suffixes, responses
@@ -183,6 +187,7 @@ def row_completion(
     num_queries=100,
     out_file=None,  # TODO support out_file
     print_levenshtein=False,
+    rng=None,
 ):
     """Plain language model variant of row_chat_completion"""
     # load the file as a list of strings
@@ -192,7 +197,11 @@ def row_completion(
     prefixes = []
     suffixes = []
     responses = []
-    for idx in np.random.choice(
+
+    if rng is None:
+        rng = np.random.default_rng()
+
+    for idx in rng.choice(
         len(rows) - num_prefix_rows, num_queries, replace=False
     ):
         # prepare query
@@ -408,9 +417,7 @@ def chat_completion(
 
 
 ####################################################################################
-# Almost all of the different tests that we perform
-# can be cast in the prompt structure of
-# 'prefix-suffix chat completion'.
+# Many tests can be cast in the prompt structure of 'prefix-suffix chat completion'.
 # This is implemented by the following function.
 ####################################################################################
 
@@ -426,8 +433,7 @@ def prefix_suffix_chat_completion(
     out_file=None,
     rng=None,
 ):
-    """A basic chat completion function. Takes a list of prefixes and suffixes and a system prompt.
-    Sends {num_queries} prompts of the format
+    """A general-purpose chat completion function. Given prefixes, suffixes, and few-shot examples, this function sends {num_queries} LLM queries of the format
 
         System: <system_prompt>
         User: <prefix> |
@@ -438,13 +444,27 @@ def prefix_suffix_chat_completion(
         User: <prefix>
         Assistant: <response> (= test suffix?)
 
-    The num_queries prefixes and suffixes are randomly selected from the respective lists.
-    The function guarantees that the test suffix (as a complete string) is not contained in any of the few-shot prefixes or suffixes.
+    The prefixes, suffixes, and few-shot examples are randomly selected.
+
+    This function guarantees that the test suffix (as a complete string) is not contained in any of the few-shot prefixes or suffixes (a useful sanity check: we don't want to provide the desired response anywhere in the context).
 
-    Stores the results in a csv file.
+    Args:
+        llm (LLM_Interface): The LLM.
+        prefixes (list[str]): A list of prefixes.
+        suffixes (list[str]): A list of suffixes.
+        system_prompt (str): The system prompt.
+        few_shot (int or list, optional): Either an integer, to select the given number of few-shot examples from the lists of prefixes and suffixes, or a list [([prefixes], [suffixes]), ..., ([prefixes], [suffixes])] to select one few-shot example from each pair of lists. Defaults to None.
+        num_queries (int, optional): The number of queries. Defaults to 100.
+        print_levenshtein (bool, optional): Visualize the Levenshtein string distance between test suffixes and LLM responses. Defaults to False.
+        out_file (str, optional): Save all queries to a CSV file. Defaults to None.
+        rng (np.random.Generator, optional): Random number generator for reproducible sampling. Defaults to None.
 
-    Returns: the test prefixes, test suffixes, and responses
-    """
+    Raises:
+        Exception: If an error occurs.
+
+    Returns:
+        tuple: The test prefixes, test suffixes, and responses.
+    """
     assert len(prefixes) == len(
         suffixes
     ), "prefixes and suffixes must have the same length"
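The rng threading added above follows numpy's Generator pattern: each function accepts an optional generator and falls back to a fresh one, so callers can opt into reproducible query sampling. A minimal standalone sketch of the pattern (not code from this commit):

```python
import numpy as np

# The pattern used throughout this commit: accept an optional Generator
# and fall back to a fresh one, so callers can opt into reproducibility.
def sample_indices(n_rows, num_queries, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    return rng.choice(n_rows, num_queries, replace=False)

# Two calls with the same seed draw the same query rows.
a = sample_indices(100, 5, rng=np.random.default_rng(42))
b = sample_indices(100, 5, rng=np.random.default_rng(42))
assert (a == b).all()
```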

tabmemcheck/functions.py

Lines changed: 41 additions & 52 deletions
@@ -14,7 +14,6 @@
 from tabmemcheck.llm import (
     LLM_Interface,
     ChatWrappedLLM,
-    send_chat_completion,
     send_completion,
     bcolors,
 )
@@ -71,6 +70,15 @@ def __llm_setup(llm: Union[LLM_Interface, str]):
     return llm
 
 
+def __print_file_name(csv_file):
+    print(
+        bcolors.BOLD
+        + "File: "
+        + bcolors.ENDC
+        + f"{os.path.basename(csv_file)}"
+    )
+
+
 def __print_info(csv_file, llm, few_shot_csv_files):
     """Print some information about the csv file and the model."""
     print(
@@ -155,6 +163,8 @@ def feature_names_test(
     num_prefix_features: int = None,
     few_shot_csv_files=DEFAULT_FEW_SHOT_CSV_FILES,
     system_prompt: str = "default",
+    verbose: bool = True,
+    return_result=False,
 ):
     """Test if the model knows the names of the features in a csv file.
 
@@ -246,20 +256,13 @@ def feature_names_test(
         if idx != -1:
             response = response[:idx]
 
-    print(
-        bcolors.BOLD
-        + "Dataset: "
-        + bcolors.ENDC
-        + os.path.basename(csv_file)
-        + bcolors.BOLD
-        + "\nFeature Names: "
-        + bcolors.ENDC
-        + ", ".join(feature_names)
-        + bcolors.BOLD
-        + "\nFeature Names Test: "
-        + bcolors.ENDC
-        + utils.levenshtein_cmd(", ".join(feature_names[num_prefix_features:]), response)
-    )
+    # prompt, continuation, response
+    test_triplet = ", ".join(feature_names[:num_prefix_features]) + ", ", ", ".join(feature_names[num_prefix_features:]), response
+    if verbose:
+        utils.display_test_result(*test_triplet, "Feature Names Test", csv_file)
+
+    if return_result:
+        return test_triplet
 
 
 ####################################################################################
@@ -297,12 +300,11 @@ def feature_values_test(
     pd.set_option('display.max_columns', None)  # Show all columns
     pd.set_option('display.width', 1000)  # Set the width to avoid wrapping
 
+    __print_file_name(csv_file)
     print(
         bcolors.BOLD
         + "Feature Values Test"
-        + "\nDataset: "
         + bcolors.ENDC
-        + os.path.basename(csv_file)
     )
     print_df = pd.concat([pd.DataFrame(sample_row).T.head(1), pd.DataFrame(row).head(1)])
     print_df.reset_index(drop=True, inplace=True)
@@ -311,7 +313,7 @@ def feature_values_test(
 
 
 ####################################################################################
-# Dataset Name (from the first rows of the csv file)
+# Dataset Name
 ####################################################################################
 
 
@@ -377,13 +379,10 @@ def dataset_name_test(
     else:
         raise NotImplementedError  # TODO
 
+    __print_file_name(csv_file)
     print(
         bcolors.BOLD
-        + "Dataset: "
-        + bcolors.ENDC
-        + os.path.basename(csv_file)
-        + bcolors.BOLD
-        + "\nGenerated Dataset Name: "
+        + "Generated Dataset Name: "
         + bcolors.ENDC
         + response
     )
@@ -402,6 +401,8 @@ def header_test(
     few_shot_csv_files: list[str] = DEFAULT_FEW_SHOT_CSV_FILES,
     system_prompt: str = "default",
     verbose: bool = True,
+    return_result=False,
+    rng=None,
 ):
     """Header test for memorization.
 
@@ -423,6 +424,10 @@ def header_test(
     if system_prompt == "default":
         system_prompt = tabmem.config.system_prompts["header"]
 
+    # rng
+    if rng is None:
+        rng = np.random.default_rng()
+
     # load the csv file as a single contiguous string. also load the rows to determine offsets within the string
     data = utils.load_csv_string(csv_file, header=True)
     csv_rows = utils.load_csv_rows(csv_file, header=True)
@@ -438,7 +443,7 @@ def header_test(
     header_prompt, llm_completion = None, None
     for i_row in split_rows:
         offset = np.sum([len(row) for row in csv_rows[: i_row - 1]])
-        offset += np.random.randint(
+        offset += rng.integers(
            len(csv_rows[i_row]) // 3, 2 * len(csv_rows[i_row]) // 3
        )
        prefixes = [data[:offset]]
@@ -451,7 +456,7 @@ def header_test(
         # chat mode: use few-shot examples
         if llm.chat_mode:
             _, _, response = prefix_suffix_chat_completion(
-                llm, prefixes, suffixes, system_prompt, few_shot=few_shot, num_queries=1
+                llm, prefixes, suffixes, system_prompt, few_shot=few_shot, num_queries=1, rng=rng
             )
             response = response[0]
         else:  # otherwise, plain completion
@@ -472,34 +477,12 @@ def header_test(
         llm_completion = response
     header_completion = data[offset : offset + len(llm_completion)]
 
+    test_triplet = header_prompt, header_completion, llm_completion
     if verbose:  # print test result to console
-        print(
-            bcolors.BOLD
-            + "Dataset: "
-            + bcolors.ENDC
-            + os.path.basename(csv_file)
-            + bcolors.BOLD
-            + "\nHeader Test: "
-            + bcolors.ENDC
-            + bcolors.Black
-            + header_prompt
-            + utils.levenshtein_cmd(header_completion, llm_completion)
-            + bcolors.ENDC
-            + bcolors.BOLD
-            + "\nHeader Test Legend: "
-            + bcolors.ENDC
-            + "Prompt "
-            + bcolors.Green
-            + "Correct "
-            + bcolors.Red
-            + "Incorrect "
-            + bcolors.ENDC
-            + bcolors.Purple
-            + "Missing"
-            + bcolors.ENDC
-        )
+        utils.display_test_result(*test_triplet, "Header Test", csv_file)
 
-    return header_prompt, header_completion, llm_completion
+    if return_result:
+        return test_triplet
 
 
 ####################################################################################
@@ -516,6 +499,7 @@ def row_completion_test(
     out_file=None,
     system_prompt: str = "default",
     print_levenshtein: bool = True,
+    rng=None,
 ):
     """Row completion test for memorization. The test reports the number of correctly completed rows.
 
@@ -571,10 +555,11 @@ def row_completion_test(
             few_shot,
             out_file,
             print_levenshtein,
+            rng=rng,
         )
     else:
         _, test_suffixes, responses = row_completion(
-            llm, csv_file, num_prefix_rows, num_queries, out_file, print_levenshtein=print_levenshtein
+            llm, csv_file, num_prefix_rows, num_queries, out_file, print_levenshtein=print_levenshtein, rng=rng
         )
 
     # count the number of verbatim completed rows
@@ -617,6 +602,7 @@ def feature_completion_test(
     few_shot=5,
     out_file=None,
     system_prompt: str = "default",
+    rng=None,
 ):
     """Feature completion test for memorization. The test reports the number of correctly completed features.
 
@@ -674,6 +660,7 @@ def build_prompt(messages):
         cond_feature_names,
         add_description=False,
         out_file=out_file,
+        rng=rng,
     )
 
     # parse the model responses
@@ -715,6 +702,7 @@ def first_token_test(
     few_shot=7,
     out_file=None,
     system_prompt: str = "default",
+    rng=None,
 ):
     """First token test for memorization. We ask the model to complete the first token of the next row of the csv file, given the previous rows. The test reports the number of correctly completed tokens.
 
@@ -781,6 +769,7 @@ def first_token_test(
             num_queries,
             few_shot,
             out_file,
+            rng=rng,
         )
     else:
         _, test_suffixes, responses = row_completion(
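Taken together, the new ```rng``` and ```return_result``` parameters allow reproducible, programmatic test runs. A hedged sketch, assuming ```header_test``` is importable at the package level and accepts a model-name string as the README's examples suggest; the file and model names are placeholders:

```python
import numpy as np
import tabmemcheck

# a seeded generator makes the header split point and the few-shot selection
# reproducible; return_result=True returns the triplet instead of only printing it
header_prompt, header_completion, llm_completion = tabmemcheck.header_test(
    "adult-test.csv",   # placeholder csv file, as in the README examples
    "gpt-3.5-turbo",    # placeholder model name
    rng=np.random.default_rng(0),
    return_result=True,
)
```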
