import numpy as np
import urllib.request
from io import StringIO

def load_dataset(dataset_name):
    '''
    Load real-world datasets. Processed datasets are from https://github.com/cmu-phil/example-causal-datasets/tree/main

    Parameters
    ----------
    dataset_name : str
        One of 'sachs', 'boston_housing', 'airfoil'

    Returns
    -------
    data : np.ndarray
    labels : list of str
    '''

    url_mapping = {
        'sachs': 'https://raw.githubusercontent.com/cmu-phil/example-causal-datasets/main/real/sachs/data/sachs.2005.continuous.txt',
        'boston_housing': 'https://raw.githubusercontent.com/cmu-phil/example-causal-datasets/main/real/boston-housing/data/boston-housing.continuous.txt',
        'airfoil': 'https://raw.githubusercontent.com/cmu-phil/example-causal-datasets/main/real/airfoil-self-noise/data/airfoil-self-noise.continuous.txt'
    }

    if dataset_name not in url_mapping:
        raise ValueError("Invalid dataset name")

    url = url_mapping[dataset_name]
    with urllib.request.urlopen(url) as response:
        content = response.read().decode('utf-8')  # Read content and decode to string

    # Use StringIO to turn the string content into a file-like object so numpy can read from it
    labels_array = np.genfromtxt(StringIO(content), delimiter="\t", dtype=str, max_rows=1)
    data = np.loadtxt(StringIO(content), skiprows=1)

    # Convert labels_array to a list of strings
    labels_list = labels_array.tolist()
    if isinstance(labels_list, str):  # handle the case where there's only one label
        labels_list = [labels_list]

    return data, labels_list
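

# Illustrative usage sketch (an assumption, not part of the original file): fetch one of the
# datasets listed in url_mapping and inspect its shape and column labels. Requires network
# access to the GitHub-hosted files above.
if __name__ == '__main__':
    data, labels = load_dataset('sachs')
    print(data.shape)  # (number of samples, number of variables)
    print(labels)      # variable names read from the header row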