forked from aghyad-deeb/reward_seeker
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathload_dataset.py
More file actions
71 lines (57 loc) · 2 KB
/
load_dataset.py
File metadata and controls
71 lines (57 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python3
"""
Script to load the sycophancy dataset using the datasets package.
"""
from datasets import load_dataset
import json
def load_sycophancy_dataset(data_file='data/sycophancy_fact.jsonl'):
    """
    Load the sycophancy dataset from a JSONL file.

    Args:
        data_file: Path to the JSONL file to load. Defaults to
            'data/sycophancy_fact.jsonl', which is resolved relative to the
            current working directory.

    Returns:
        The object returned by ``datasets.load_dataset``. No ``split`` is
        requested, so callers in this file access the examples through the
        'train' split (``dataset['train']``).
    """
    # The 'json' builder handles JSON Lines files passed through data_files.
    return load_dataset('json', data_files=data_file)
def print_dataset_info(dataset):
    """
    Print summary information about the loaded dataset.

    Prints the number of examples and the feature names of the 'train'
    split, followed by every field of the first example. If the split is
    empty, a short notice is printed instead of the first example.

    Args:
        dataset: Mapping with a 'train' entry that supports ``len()``,
            integer indexing, and a ``features`` mapping.
    """
    train = dataset['train']
    num_examples = len(train)

    print("Dataset Information:")
    print(f"Number of examples: {num_examples}")
    print(f"Features: {list(train.features.keys())}")

    # Indexing [0] on an empty split raises IndexError, so bail out early.
    if num_examples == 0:
        print("\nNo examples to display (dataset is empty).")
        return

    print("\nFirst example:")
    for key, value in train[0].items():
        print(f" {key}: {value}")
def main():
    """
    Load the sycophancy dataset and print a summary plus a few samples.

    Returns:
        bool: True if the dataset was loaded and displayed; False if either
        step raised an exception (the error is printed, not re-raised).
    """
    print("Loading sycophancy dataset...")

    # Loading and displaying are guarded separately so that the printed
    # error names the step that actually failed.
    try:
        dataset = load_sycophancy_dataset()
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return False
    print("Dataset loaded successfully!")

    try:
        # Print dataset information
        print_dataset_info(dataset)

        # Show a few examples (at most three)
        print("\n" + "=" * 50)
        print("Sample Examples:")
        print("=" * 50)
        train = dataset['train']
        for i in range(min(3, len(train))):
            example = train[i]
            print(f"\nExample {i+1}:")
            # Only the first 200 characters of the first prompt are shown.
            print(f"Prompt: {example['prompt_list'][0][:200]}...")
            print(f"High reward answer: {example['high_reward_answer']}")
            print(f"Other answers: {example['other_answers']}")
    except Exception as e:
        print(f"Error displaying dataset: {e}")
        return False

    return True
if __name__ == "__main__":
    success = main()
    if success:
        print("\nDataset loading completed successfully!")
    else:
        print("\nDataset loading failed!")
    # Propagate the outcome as the process exit status (0 on success,
    # 1 on failure) so shells and CI can detect a failed run.
    raise SystemExit(0 if success else 1)