Commit 46ca895

dataset name test, config.csv_max_rows and other fixes for very large csv files

1 parent e7a8011 commit 46ca895

File tree: 6 files changed, +1050 −39 lines

examples/MLE-bench-contamination.ipynb

Lines changed: 861 additions & 0 deletions
Large diffs are not rendered by default.

tabmemcheck/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -15,6 +15,8 @@
     run_all_tests,
     header_test,
     feature_names_test,
+    feature_values_test,
+    dataset_name_test,
     row_completion_test,
     feature_completion_test,
     first_token_test,
@@ -59,9 +61,12 @@ def __delattr__(self, key):

 # default llm options
 config.temperature = 0
-config.max_tokens = 500
+config.max_tokens = 1000
 config.sleep = 0.0  # amount of time to sleep after each query to the llm

+# csv file loading options
+config.csv_max_rows = 100000  # maximum number of rows to load from a csv file
+
 # default: no prompt/response logging
 config.current_logging_task = None
 config.current_logging_folder = None
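
The new config.csv_max_rows option caps how many rows tabmemcheck reads from any csv file. A minimal sketch of how a user might adjust the new defaults from their own code (the values below are illustrative, not recommendations):

    import tabmemcheck

    # this commit raises the default completion budget from 500 to 1000 tokens
    tabmemcheck.config.max_tokens = 1000

    # by default only the first 100000 rows of a csv file are loaded;
    # a lower cap speeds up experiments on very large files
    tabmemcheck.config.csv_max_rows = 10000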

tabmemcheck/analysis.py

Lines changed: 1 addition & 3 deletions
@@ -63,9 +63,7 @@ def find_matches(
 ):
     """Find the closest matches between a row x and all rows in the dataframe df. By default, we use the levenshtein distance as the distance metric.

-    This function can handle a variety of formatting differences between the values in the original data
-    and LLM responses that should still be counted as equal.
-
+    This function can handle some formatting differences between the values in the original data and LLM responses that should still be counted as equal.

     :param df: a pandas dataframe.
     :param x: a string, a pandas dataframe or a pandas Series.
tabmemcheck/functions.py

Lines changed: 159 additions & 31 deletions
@@ -38,9 +38,9 @@
 ]


-def __difflib_similar(csv_file_1, csv_file_2):
+def __difflib_similar(csv_file_1, csv_file_2, max_length=5000):
     sm = SequenceMatcher(
-        None, utils.load_csv_string(csv_file_1), utils.load_csv_string(csv_file_2)
+        None, utils.load_csv_string(csv_file_1, size=max_length)[:max_length], utils.load_csv_string(csv_file_2, size=max_length)[:max_length]
     )
     if sm.quick_ratio() > 0.9:
         return sm.ratio() > 0.9
@@ -49,33 +49,19 @@ def __difflib_similar(csv_file_1, csv_file_2):

 def __validate_few_shot_files(csv_file, few_shot_csv_files):
     """check if the csv_file is contained in the few_shot_csv_files."""
-    dataset_name = utils.get_dataset_name(csv_file)
-    few_shot_names = [utils.get_dataset_name(x) for x in few_shot_csv_files]
-    if dataset_name in few_shot_names:
-        # replace the dataset with iris or adult
-        few_shot_csv_files = [
-            x for x in few_shot_csv_files if utils.get_dataset_name(x) != dataset_name
-        ]
-        if 'iris' in dataset_name:
-            few_shot_csv_files.append("adult-train.csv")
-        else:
-            few_shot_csv_files.append("iris.csv")
-        print(
-            bcolors.BOLD
-            + "Info: "
-            + bcolors.ENDC
-            + f"Exchanged a few-shot datasets because its name is similar to the dataset being tested."
-        )
-    # now test with difflib if the dataset contents are very similar
+    validated_few_shot_files = []
+    # test with difflib if the dataset contents are very similar
     for fs_file in few_shot_csv_files:
         if __difflib_similar(csv_file, fs_file):
             print(
                 bcolors.BOLD
-                + "Warning: "
+                + "Info: "
                 + bcolors.ENDC
-                + f"The dataset is very similar to the few-shot dataset {utils.get_dataset_name(fs_file)}."
+                + f"Removed the few-shot dataset {fs_file} because it is similar to the dataset being tested."
             )
-    return few_shot_csv_files
+        else:
+            validated_few_shot_files.append(fs_file)
+    return validated_few_shot_files


 def __llm_setup(llm: Union[LLM_Interface, str]):
@@ -193,9 +179,6 @@ def feature_names_test(
     if num_prefix_features is None:
         num_prefix_features = max(1, len(feature_names) // 4)

-    # remove the current csv file from the few-shot csv files should it be present there
-    few_shot_csv_files = [x for x in few_shot_csv_files if not dataset_name in x]
-
     # setup for the few-shot examples
     fs_dataset_names = [utils.get_dataset_name(x) for x in few_shot_csv_files]
     fs_feature_names = [
@@ -265,13 +248,17 @@ def feature_names_test(

     print(
         bcolors.BOLD
-        + "Feature Names Test\nFeature Names: "
+        + "Dataset: "
         + bcolors.ENDC
-        + ", ".join(feature_names[num_prefix_features:])
+        + os.path.basename(csv_file)
         + bcolors.BOLD
-        + "\nModel Generation: "
+        + "\nFeature Names: "
         + bcolors.ENDC
-        + response
+        + ", ".join(feature_names)
+        + bcolors.BOLD
+        + "\nFeature Names Test: "
+        + bcolors.ENDC
+        + utils.levenshtein_cmd(", ".join(feature_names[num_prefix_features:]), response)
     )


@@ -280,6 +267,128 @@ def feature_names_test(
 ####################################################################################


+def feature_values_test(
+    csv_file: str,
+    llm: Union[LLM_Interface, str],
+    few_shot_csv_files=DEFAULT_FEW_SHOT_CSV_FILES,
+    system_prompt: str = "default",
+):
+    """Test if the model knows valid feature values for the features in a csv file. Asks the model to provide samples, then compares the sampled feature values to the values in the csv file.
+
+    :param csv_file: The path to the csv file.
+    :param llm: The language model to be tested.
+    :param few_shot_csv_files: A list of other csv files to be used as few-shot examples.
+    :param system_prompt: The system prompt to be used.
+    """
+
+    # first, sample 3 observations at temperature zero
+    samples_df = sample(csv_file, llm, num_queries=3, temperature=0.0, few_shot_csv_files=few_shot_csv_files, system_prompt=system_prompt)
+
+    # check that there is at least one valid sample
+    if samples_df.empty:
+        print("Error: The LLM was not able to provide valid samples.")
+        return
+
+    # choose the first sample
+    sample_row = samples_df.iloc[0]
+    _, row = analysis.find_matches(utils.load_csv_df(csv_file), sample_row)
+
+    # set pandas display options for better formatting
+    pd.set_option('display.max_columns', None)  # show all columns
+    pd.set_option('display.width', 1000)  # set the width to avoid wrapping
+
+    print(
+        bcolors.BOLD
+        + "Feature Values Test"
+        + "\nDataset: "
+        + bcolors.ENDC
+        + os.path.basename(csv_file)
+    )
+    print_df = pd.concat([pd.DataFrame(sample_row).T.head(1), pd.DataFrame(row).head(1)])
+    print_df.reset_index(drop=True, inplace=True)
+    print_df.rename(index={0: bcolors.BOLD + "Model Sample" + bcolors.ENDC, 1: bcolors.BOLD + "Dataset Match" + bcolors.ENDC}, inplace=True)
+    print(print_df)
+
+
+####################################################################################
+# Dataset Name (from the first rows of the csv file)
+####################################################################################
+
+
+def dataset_name_test(
+    csv_file: str,
+    llm: Union[LLM_Interface, str],
+    few_shot_csv_files=DEFAULT_FEW_SHOT_CSV_FILES,
+    few_shot_dataset_names=None,
+    num_rows=5,
+    header=True,
+    system_prompt: str = "default",
+):
+    """Test if the model knows the name of the dataset, given the first rows of the csv file.
+
+    :param csv_file: The path to the csv file.
+    :param llm: The language model to be tested.
+    :param few_shot_csv_files: A list of other csv files to be used as few-shot examples.
+    :param few_shot_dataset_names: A list of dataset names to be used as few-shot examples. If None, the dataset names are the file names of the few-shot csv files.
+    :param num_rows: The number of dataset rows to be given to the model as part of the prompt.
+    :param header: If True, the first row of the csv file is included in the prompt (it usually contains the feature names).
+    :param system_prompt: The system prompt to be used.
+    """
+
+    llm = __llm_setup(llm)
+    few_shot_csv_files = __validate_few_shot_files(csv_file, few_shot_csv_files)
+
+    # default system prompt?
+    if system_prompt == "default":
+        system_prompt = tabmem.config.system_prompts["dataset-name"]
+
+    if few_shot_dataset_names is None:
+        few_shot_dataset_names = [utils.get_dataset_name(x) for x in few_shot_csv_files]
+
+    if llm.chat_mode:
+        # construct the prompt
+        prefixes = [
+            "\n".join(utils.load_csv_rows(csv_file, header=header)[:num_rows])
+        ]
+        suffixes = [utils.get_dataset_name(csv_file)]
+
+        few_shot = []
+        for fs_csv_file, dataset_name in zip(few_shot_csv_files, few_shot_dataset_names):
+            few_shot.append(
+                (
+                    [
+                        "\n".join(utils.load_csv_rows(fs_csv_file, header=header)[:num_rows])
+                    ],
+                    [dataset_name],
+                )
+            )
+
+        # execute the prompt
+        _, _, responses = prefix_suffix_chat_completion(
+            llm,
+            prefixes,
+            suffixes,
+            system_prompt,
+            few_shot=few_shot,
+            num_queries=1,
+        )
+        response = responses[0]
+    else:
+        raise NotImplementedError  # TODO
+
+    print(
+        bcolors.BOLD
+        + "Dataset: "
+        + bcolors.ENDC
+        + os.path.basename(csv_file)
+        + bcolors.BOLD
+        + "\nGenerated Dataset Name: "
+        + bcolors.ENDC
+        + response
+    )
+
+
 ####################################################################################
 # Header Test
 ####################################################################################
@@ -366,7 +475,11 @@ def header_test(
     if verbose:  # print test result to console
         print(
             bcolors.BOLD
-            + "Header Test: "
+            + "Dataset: "
+            + bcolors.ENDC
+            + os.path.basename(csv_file)
+            + bcolors.BOLD
+            + "\nHeader Test: "
             + bcolors.ENDC
             + bcolors.Black
             + header_prompt
@@ -422,6 +535,13 @@ def row_completion_test(
     if system_prompt == "default":  # default system prompt?
         system_prompt = tabmem.config.system_prompts["row-completion"]

+    print(
+        bcolors.BOLD
+        + "Dataset: "
+        + bcolors.ENDC
+        + os.path.basename(csv_file)
+    )
+
     # what fraction of the rows are duplicates?
     rows = utils.load_csv_rows(csv_file)
     frac_duplicates = 1 - len(set(rows)) / len(rows)
@@ -717,6 +837,7 @@ def sample(
     csv_file: str,
     llm: Union[LLM_Interface, str],
     num_queries: int,
+    temperature: float = 0.7,
     few_shot_csv_files: list[str] = DEFAULT_FEW_SHOT_CSV_FILES,
     cond_feature_names: list[str] = [],
     drop_invalid_responses: bool = True,
@@ -742,6 +863,10 @@ def sample(
     if not llm.chat_mode:  # wrap base model to take chat queries
         llm = ChatWrappedLLM(llm, build_sample_prompt, ends_with="\n\n")

+    # store the temperature
+    temp = tabmem.config.temperature
+    tabmem.config.temperature = temperature
+
     # run the test
     _, _, responses = feature_values_chat_completion(
         llm,
@@ -754,6 +879,9 @@ def sample(
         out_file=None,
     )

+    # reset the temperature
+    tabmem.config.temperature = temp
+
     if len(cond_feature_names) > 0:
         raise NotImplementedError("Conditional sampling not yet supported.")
         # TODO handle the conditional case!
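
A sketch of how the two new tests added above might be called; the csv path and model name are placeholders (per their signatures, both functions also accept an LLM_Interface object):

    import tabmemcheck

    csv_file = "adult-train.csv"  # placeholder path
    llm = "gpt-3.5-turbo"  # placeholder model name

    # ask the model to name the dataset, given its first rows
    tabmemcheck.dataset_name_test(csv_file, llm, num_rows=5)

    # sample at temperature zero and compare against the closest row in the csv file
    tabmemcheck.feature_values_test(csv_file, llm)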

tabmemcheck/resources/config/system-prompts.yaml

Lines changed: 3 additions & 0 deletions
@@ -21,5 +21,8 @@ generic-csv-format: |
 feature-names: |
   You are an expert assistant for tabular datasets. Your task is to list the names of the features of different datasets. The user provides a description of the dataset and some of the feature names. You then provide the names of the remaining features.

+dataset-name: |
+  You are an expert assistant for tabular datasets. Your task is to provide the name of the dataset. The user provides the initial rows of the csv file, including the feature names. You then provide the name of the dataset.
+
 predict: |
   You are an expert assistant for tabular datasets. You provide predictions on different datasets. The user provides the name of the dataset, the names of the features, as well the values of all the features except one. You then provide a prediction for the missing feature (the target).
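
For reference, the new dataset-name entry is the prompt that dataset_name_test (see functions.py above) resolves when called with system_prompt="default"; tabmem is the package alias used in functions.py:

    system_prompt = tabmem.config.system_prompts["dataset-name"]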

tabmemcheck/utils.py

Lines changed: 20 additions & 4 deletions
@@ -6,6 +6,7 @@
 import jellyfish
 import difflib
 import tempfile
+import itertools

 import csv

@@ -80,14 +81,23 @@ def get_feature_names(csv_file):
     return df.columns.tolist()


+CSV_MAX_ROWS_WARNING_PRINTED = False
 def load_csv_df(csv_file, header=True, delimiter="auto", **kwargs):
     """Load a csv file as a pandas data frame."""
+    global CSV_MAX_ROWS_WARNING_PRINTED
     with _csv_file(csv_file) as csv_file:
         # auto detect the delimiter from the csv file
         if delimiter == "auto":
             delimiter = get_delimiter(csv_file)
         # load the csv file
-        df = pd.read_csv(csv_file, delimiter=delimiter, **kwargs)
+        max_rows = tabmem.config.csv_max_rows
+        df = pd.read_csv(csv_file, delimiter=delimiter, nrows=max_rows + 1, **kwargs)
+        # check if the file has more rows than max_rows
+        if len(df) > max_rows and not CSV_MAX_ROWS_WARNING_PRINTED:
+            print(f'Info: Found a CSV file with more than {max_rows} rows. Note that tabmemcheck is configured to use only the first {max_rows} rows. Set tabmemcheck.config.csv_max_rows to change this behavior.')
+            CSV_MAX_ROWS_WARNING_PRINTED = True
+        # truncate the dataframe to the first max_rows rows
+        df = df.head(max_rows)
     # optionally, remove the header
     if not header:
         df = df.iloc[1:]
@@ -96,9 +106,15 @@ def load_csv_df(csv_file, header=True, delimiter="auto", **kwargs):


 def load_csv_rows(csv_file, header=True):
     """Load a csv file as a list of strings, with one string per row."""
+    global CSV_MAX_ROWS_WARNING_PRINTED
     with _csv_file(csv_file) as csv_file:
         with open(csv_file, "r") as f:
-            data = f.readlines()
+            data = list(itertools.islice(f, tabmem.config.csv_max_rows + 1))
+        # check if the file has more rows than csv_max_rows; if yes, print a warning and truncate
+        if len(data) > tabmem.config.csv_max_rows and not CSV_MAX_ROWS_WARNING_PRINTED:
+            print(f'Info: Found a CSV file with more than {tabmem.config.csv_max_rows} rows. Note that tabmemcheck is configured to use only the first {tabmem.config.csv_max_rows} rows. Set tabmemcheck.config.csv_max_rows to change this behavior.')
+            CSV_MAX_ROWS_WARNING_PRINTED = True
+        data = data[:tabmem.config.csv_max_rows]
         # remove all trailing newlines
         data = [line.rstrip("\n") for line in data]
         # remove all empty rows
@@ -109,12 +125,12 @@ def load_csv_rows(csv_file, header=True):
     return data


-def load_csv_string(csv_file, header=True):
+def load_csv_string(csv_file, header=True, size=10000000):
     """Load a csv file as a single string."""
     with _csv_file(csv_file) as csv_file:
         # load the csv file into a single string
         with open(csv_file, "r") as f:
-            data = f.read()
+            data = f.read(size)
         # remove header TODO, this currently only works if header does not contain "\n"
         if not header:
             data = data.split("\n")[1:]
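
Both loaders use the same read-one-extra trick: request csv_max_rows + 1 rows, and if the extra row actually materializes, the file exceeds the cap and is truncated with a one-time warning. A standalone sketch of the pattern (an illustrative helper, not part of tabmemcheck):

    import itertools

    def read_capped(path, max_rows):
        # read at most max_rows lines, plus one sentinel line
        with open(path, "r") as f:
            lines = list(itertools.islice(f, max_rows + 1))
        truncated = len(lines) > max_rows  # the sentinel line proves the file was larger
        return lines[:max_rows], truncated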
