-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget_dataset.py
More file actions
75 lines (62 loc) · 2.32 KB
/
get_dataset.py
File metadata and controls
75 lines (62 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python3
#
# The script is used for downloading the three datasets used by the examples.
#
# All three datasets come from LIBSVM website, and are stored in LIBSVM format.
# For more details, please refer to the following link:
# https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html
#
# All three datasets are binary classification tasks, because most
# state-of-the-art active learning algorithms (query strategies) are only
# suitable for binary classification.
#
# The following table summarizes the three datasets:
# australian, diabetes, and heart.
#
# +------------+-----+----+
# | dataset    | N   | D  |
# +============+=====+====+
# | australian | 690 | 14 |
# +------------+-----+----+
# | diabetes   | 768 | 8  |
# +------------+-----+----+
# | heart      | 270 | 13 |
# +------------+-----+----+
#
# N is the number of samples and D is the dimension of the input features.
# Labels are y \in {-1, +1}.
import os
import random
import urllib
import urllib.request
# Every dataset is fetched from the same LIBSVM "binary" directory; each URL
# names the `*_scale` file of one dataset.
_LIBSVM_BINARY = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/'

# Download URL and expected number of samples (N in the header table) for
# each dataset.
AUS_URL = _LIBSVM_BINARY + 'australian_scale'
AUS_SIZE = 690

DB_URL = _LIBSVM_BINARY + 'diabetes_scale'
DB_SIZE = 768

HT_URL = _LIBSVM_BINARY + 'heart_scale'
HT_SIZE = 270

# Output directory: the directory containing this script, symlinks resolved.
TARGET_PATH = os.path.dirname(os.path.realpath(__file__))
def _fetch_rows(url):
    """Download *url* and return its non-blank lines as newline-terminated bytes.

    Blank lines (e.g. a trailing empty line) are dropped so they are neither
    counted as samples nor shuffled into the middle of the output. A final
    line without a newline gets one, so it cannot be glued onto its neighbour
    after the rows are shuffled.
    """
    # The response is a context manager; closing it releases the connection.
    with urllib.request.urlopen(url) as response:
        raw_rows = list(response)  # one bytes object per line
    return [row if row.endswith(b'\n') else row + b'\n'
            for row in raw_rows if row.strip()]


def _download_dataset(name, url, size, target_path):
    """Download one dataset, shuffle its rows and save it as ``<name>.txt``.

    Args:
        name: Dataset name, used in progress messages and the output file name.
        url: URL of the dataset file (one sample per line).
        size: Number of rows to keep. ``random.sample`` draws this many rows
            in random order, so when it equals the row count the whole
            dataset is simply shuffled.
        target_path: Directory the ``<name>.txt`` file is written into.

    Raises:
        ValueError: if the download contains fewer than ``size`` rows.
    """
    print('downloading {} ...'.format(name))
    rows = _fetch_rows(url)
    if len(rows) < size:
        raise ValueError('{}: expected at least {} rows from {}, got {}'.format(
            name, size, url, len(rows)))
    selected = random.sample(rows, size)
    # Rows are raw bytes from the response, so the file is written in binary.
    with open(os.path.join(target_path, name + '.txt'), 'wb') as f:
        f.writelines(selected)
    print('{} downloaded successfully !'.format(name))


def main(*, target_path=None, datasets=None):
    """Download, shuffle and save each dataset.

    Args:
        target_path: Output directory; defaults to ``TARGET_PATH``.
        datasets: Iterable of ``(name, url, size)`` tuples; defaults to the
            australian, diabetes and heart datasets defined at module level.
    """
    if target_path is None:
        target_path = TARGET_PATH
    if datasets is None:
        datasets = (
            ('australian', AUS_URL, AUS_SIZE),
            ('diabetes', DB_URL, DB_SIZE),
            ('heart', HT_URL, HT_SIZE),
        )
    for index, (name, url, size) in enumerate(datasets):
        if index:
            print()  # blank line separating consecutive datasets
        _download_dataset(name, url, size, target_path)


if __name__ == '__main__':
    main()