Skip to content

Commit 8a5d125

Browse files
authored
Dynamic dataset names (#5)
* remove data/ * list datasets() * Increment version number to 0.3.0
1 parent 244f2e0 commit 8a5d125

29 files changed

+368
-202
lines changed

DESCRIPTION

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Description: Check available classification and regression data sets from the PM
55
These data sets cover a range of applications, and include binary/multi-class classification problems and
66
regression problems, as well as combinations of categorical, ordinal, and continuous features.
77
There are currently over 150 datasets included in the PMLB repository.
8-
Version: 0.2.3
8+
Version: 0.3.0
99
Authors@R: c(
1010
person("Trang", "Le", email = "[email protected]", role = c("aut", "cre"), comment = "https://trang.page/"),
1111
person("makeyourownmaker", email = "[email protected]", role = "aut", comment = "https://github.com/makeyourownmaker"),
@@ -22,3 +22,6 @@ URL: https://github.com/EpistasisLab/pmlbr
2222
Encoding: UTF-8
2323
LazyData: true
2424
RoxygenNote: 7.3.2
25+
Suggests:
26+
testthat (>= 3.0.0)
27+
Config/testthat/edition: 3

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,10 @@
33
S3method(nearest_datasets,character)
44
S3method(nearest_datasets,data.frame)
55
S3method(nearest_datasets,default)
6+
export(classification_datasets)
7+
export(dataset_names)
68
export(fetch_data)
79
export(nearest_datasets)
10+
export(pmlb_metadata)
11+
export(regression_datasets)
12+
export(summary_stats)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pmlbr 0.3.0
2+
13
# pmlbr 0.2.3
24

35
* Use interactive()

R/data.R

Lines changed: 0 additions & 39 deletions
This file was deleted.

R/globals.R

Lines changed: 0 additions & 1 deletion
This file was deleted.

R/list_datasets.R

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#' Get metadata for all datasets in PMLB.
2+
#'
3+
#' Metadata like summary statistics and names of available datasets
4+
#' on the PMLB repository.
5+
#'
6+
#' @return A list containing summary_stats, dataset_names, classification_datasets, and regression_datasets
7+
#' @export
8+
#' @examples
9+
#' if (interactive()) {
10+
#' sample(pmlb_metadata()$dataset_names, 10)
11+
#' }
12+
pmlb_metadata <- function() {
13+
if (!exists("summary_stats", envir = .pmlbr_env)) {
14+
links_to_stats <- 'https://github.com/EpistasisLab/pmlb/raw/master/pmlb/all_summary_stats.tsv'
15+
summary_stats <- utils::read.csv(links_to_stats, sep = '\t')
16+
colnames(summary_stats) <- tolower(gsub(
17+
'X.',
18+
'n_',
19+
colnames(summary_stats)
20+
))
21+
assign(
22+
"summary_stats",
23+
summary_stats,
24+
envir = .pmlbr_env
25+
)
26+
assign(
27+
"dataset_names",
28+
summary_stats$dataset,
29+
envir = .pmlbr_env
30+
)
31+
assign(
32+
"regression_datasets",
33+
sort(summary_stats[summary_stats$task == "regression", "dataset"]),
34+
envir = .pmlbr_env
35+
)
36+
assign(
37+
"classification_datasets",
38+
sort(summary_stats[summary_stats$task == "classification", "dataset"]),
39+
envir = .pmlbr_env
40+
)
41+
}
42+
43+
list(
44+
summary_stats = .pmlbr_env$summary_stats,
45+
dataset_names = .pmlbr_env$dataset_names,
46+
classification_datasets = .pmlbr_env$classification_datasets,
47+
regression_datasets = .pmlbr_env$regression_datasets
48+
)
49+
}
50+
51+
52+
#' All available datasets
53+
#'
54+
#' @return A character vector of all dataset names.
55+
#' @export
56+
#' @examples
57+
#' if (interactive()) {
58+
#' sample(dataset_names(), 10)
59+
#' }
60+
dataset_names <- function() {
61+
pmlb_metadata()$dataset_names
62+
}
63+
64+
#' Classification datasets
65+
#'
66+
#' @return A character vector of classification dataset names.
67+
#' @export
68+
#' @examples
69+
#' if (interactive()) {
70+
#' sample(classification_datasets(), 10)
71+
#' }
72+
classification_datasets <- function() {
73+
pmlb_metadata()$classification_datasets
74+
}
75+
76+
#' Regression datasets
77+
#'
78+
#' @return A character vector of regression dataset names.
79+
#' @export
80+
#' @examples
81+
#' if (interactive()) {
82+
#' sample(regression_datasets(), 10)
83+
#' }
84+
regression_datasets <- function() {
85+
pmlb_metadata()$regression_datasets
86+
}
87+
88+
#' Summary statistics
89+
#'
90+
#' @return A dataframe of summary statistics of all available datasets,
91+
#' including number of instances/rows, number of columns/features, task, etc.
92+
#'
93+
#' @export
94+
#' @examples
95+
#' if (interactive()) {
96+
#' head(summary_stats())
97+
#' }
98+
summary_stats <- function() {
99+
pmlb_metadata()$summary_stats
100+
}

R/nearest.R

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#' @param n_neighbors Integer. The number of dataset names to return as neighbors.
1111
#' @param dimensions Character vector specifying dataset characteristics to include in similarity calculation.
1212
#' Dimensions must correspond to numeric columns of
13-
#' [all_summary_stats.tsv](https://github.com/EpistasisLab/pmlb/blob/master/pmlb/all_summary_stats.tsv).
13+
#' [all_summary_stats.tsv](https://github.com/EpistasisLab/pmlb/blob/master/pmlb/all_summarystats.tsv).
1414
#' If 'all' (default), uses all numeric columns.
1515
#' @param task Character string specifying classification or regression for summary stat generation.
1616
#' @param target_name Character string specifying column of target/dependent variable.
@@ -26,34 +26,45 @@
2626
#' nearest_datasets('penguins')
2727
#' nearest_datasets(fetch_data('penguins'))
2828
#' }
29-
nearest_datasets <- function(x, ...){
29+
nearest_datasets <- function(x, ...) {
3030
UseMethod('nearest_datasets', x)
3131
}
3232

3333

3434
#' @rdname nearest_datasets-methods
3535
#' @export
36-
nearest_datasets.default <- function(x, ...){
36+
nearest_datasets.default <- function(x, ...) {
3737
stop('`x` must be of class `data.frame` or `character`.')
3838
}
3939

4040

4141
#' @rdname nearest_datasets-methods
4242
#' @export
4343
nearest_datasets.character <- function(
44-
x, n_neighbors = 5,
44+
x,
45+
n_neighbors = 5,
4546
dimensions = c('n_instances', 'n_features'),
46-
target_name = 'target', ...) {
47-
48-
if (!(x %in% dataset_names))
49-
stop("'dataset_name' ", x, " not found in PMLB.\n * Check spelling, capitalisation etc.", call.=FALSE)
50-
dataset_stats <- summary_stats[summary_stats$dataset == x, ]
51-
52-
num_cols <- unlist(lapply(summary_stats, function(x) is.numeric(x)||is.integer(x)))
53-
summary_task <- summary_stats[summary_stats$task == dataset_stats$task, ] # restrict to same task
47+
target_name = 'target',
48+
...
49+
) {
50+
if (!(x %in% dataset_names()))
51+
stop(
52+
"'dataset_name' ",
53+
x,
54+
" not found in PMLB.\n * Check spelling, capitalisation etc.",
55+
call. = FALSE
56+
)
57+
sum_stats <- summary_stats()
58+
dataset_stats <- sum_stats[sum_stats$dataset == x, ]
59+
60+
num_cols <- unlist(lapply(
61+
sum_stats,
62+
function(x) is.numeric(x) || is.integer(x)
63+
))
64+
summary_task <- sum_stats[sum_stats$task == dataset_stats$task, ] # restrict to same task
5465
summary_i <- summary_task[, num_cols]
5566

56-
if (length(dimensions) == 1 && dimensions == 'all'){
67+
if (length(dimensions) == 1 && dimensions == 'all') {
5768
dimensions <- colnames(summary_i)
5869
} else {
5970
stopifnot(dimensions %in% colnames(summary_i))
@@ -70,28 +81,36 @@ nearest_datasets.character <- function(
7081
#' @rdname nearest_datasets-methods
7182
#' @export
7283
nearest_datasets.data.frame <- function(
73-
x, y = NULL, n_neighbors = 5,
84+
x,
85+
y = NULL,
86+
n_neighbors = 5,
7487
dimensions = c('n_instances', 'n_features'),
7588
task = c('classification', 'regression'),
76-
target_name = 'target', ...) {
77-
89+
target_name = 'target',
90+
...
91+
) {
7892
df <- if (is.null(y)) x else data.frame(x, target = y)
7993

8094
# get summary stats for dataset
81-
if (is.null(task)){
82-
task <- if (length(unique(df$target)) < 5) 'classification' else 'regression'
95+
if (is.null(task)) {
96+
task <- if (length(unique(df$target)) < 5) 'classification' else
97+
'regression'
8398
} else {
8499
task <- match.arg(task)
85100
}
86101

87102
if (!(target_name %in% colnames(df)))
88103
stop(paste('Either x or y must contain', target_name))
89104

90-
num_cols <- unlist(lapply(summary_stats, function(x) is.numeric(x)||is.integer(x)))
91-
summary_task <- summary_stats[summary_stats$task == task, ] # restrict to same task
105+
sum_stats <- summary_stats()
106+
num_cols <- unlist(lapply(
107+
sum_stats,
108+
function(x) is.numeric(x) || is.integer(x)
109+
))
110+
summary_task <- sum_stats[sum_stats$task == task, ] # restrict to same task
92111
summary_i <- summary_task[, num_cols]
93112

94-
if (length(dimensions) == 1 && dimensions == 'all'){
113+
if (length(dimensions) == 1 && dimensions == 'all') {
95114
dimensions <- colnames(summary_i)
96115
} else {
97116
stopifnot(dimensions %in% colnames(summary_i))
@@ -100,22 +119,22 @@ nearest_datasets.data.frame <- function(
100119

101120
feat_names <- setdiff(colnames(df), target_name)
102121
types <- vector('character')
103-
for (i in feat_names){
104-
types[i] <- get_type(df[,i], include_binary = TRUE)
122+
for (i in feat_names) {
123+
types[i] <- get_type(df[, i], include_binary = TRUE)
105124
}
106125

107126
feat <- table(types)
108-
for (type in c('binary', 'categorical', 'continuous')){
127+
for (type in c('binary', 'categorical', 'continuous')) {
109128
if (!type %in% names(feat)) feat[type] <- 0
110129
}
111130
imb <- compute_imbalance(df[, target_name])
112131

113132
dataset_stats <- data.frame(
114133
n_instances = nrow(df),
115134
n_features = length(feat_names),
116-
n_binary_features = feat['binary'],
117-
n_categorical_features = feat['categorical'],
118-
n_continuous_features = feat['continuous'],
135+
n_binary_features = feat[['binary']],
136+
n_categorical_features = feat[['categorical']],
137+
n_continuous_features = feat[['continuous']],
119138
endpoint_type = get_type(df[, target_name]),
120139
n_classes = imb[['num_classes']],
121140
imbalance = imb[['imbalance']],
@@ -136,23 +155,25 @@ nearest_datasets.data.frame <- function(
136155
#' where zero means that the dataset is perfectly balanced
137156
#' and the higher the value, the more imbalanced the dataset.
138157
#'
139-
compute_imbalance <- function(target_col){
158+
compute_imbalance <- function(target_col) {
140159
imb <- 0
141160
classes_count <- table(target_col)
142161
num_classes <- length(classes_count)
143-
for (x in classes_count){
144-
p_x = x/length(target_col)
162+
for (x in classes_count) {
163+
p_x = x / length(target_col)
145164
}
146165

147-
if (p_x > 0){
148-
imb = imb + (p_x - 1/num_classes)*(p_x - 1/num_classes)
166+
if (p_x > 0) {
167+
imb = imb + (p_x - 1 / num_classes) * (p_x - 1 / num_classes)
149168
}
150169

151170
# worst case scenario: all but 1 examplars in 1st class
152171
# the remaining one in 2nd class
153-
worst_case <- (num_classes-1)*(1/num_classes)^2 + (1-1/num_classes)^2
172+
worst_case <- (num_classes - 1) *
173+
(1 / num_classes)^2 +
174+
(1 - 1 / num_classes)^2
154175

155-
list(num_classes = num_classes, imbalance = imb/worst_case)
176+
list(num_classes = num_classes, imbalance = imb / worst_case)
156177
}
157178

158179
#' Get type/class of given vector.
@@ -163,14 +184,17 @@ compute_imbalance <- function(target_col){
163184
#'
164185
#' @return Type/class of `x`.
165186
#'
166-
get_type <- function(x, include_binary = FALSE){
187+
get_type <- function(x, include_binary = FALSE) {
167188
x <- stats::na.omit(x)
168189

169-
if (inherits(x, 'numeric')){
190+
if (inherits(x, 'numeric')) {
170191
return('continuous')
171-
} else if (inherits(x, 'integer') || inherits(x, 'factor')){
172-
if (include_binary){
173-
if (length(unique(x)) == 2) return('binary')}
192+
} else if (inherits(x, 'integer') || inherits(x, 'factor')) {
193+
if (include_binary) {
194+
if (length(unique(x)) == 2) return('binary')
195+
}
174196
return('categorical')
175-
} else {stop("Cannot get types for dataset columns")}
197+
} else {
198+
stop("Cannot get types for dataset columns")
199+
}
176200
}

0 commit comments

Comments
 (0)