1010# ' @param n_neighbors Integer. The number of dataset names to return as neighbors.
1111# ' @param dimensions Character vector specifying dataset characteristics to include in similarity calculation.
1212# ' Dimensions must correspond to numeric columns of
13- # ' [all_summary_stats.tsv](https://github.com/EpistasisLab/pmlb/blob/master/pmlb/all_summary_stats .tsv).
13+ # ' [all_summary_stats.tsv](https://github.com/EpistasisLab/pmlb/blob/master/pmlb/all_summarystats .tsv).
1414# ' If 'all' (default), uses all numeric columns.
1515# ' @param task Character string specifying classification or regression for summary stat generation.
1616# ' @param target_name Character string specifying column of target/dependent variable.
2626# ' nearest_datasets('penguins')
2727# ' nearest_datasets(fetch_data('penguins'))
2828# ' }
29- nearest_datasets <- function (x , ... ){
29+ nearest_datasets <- function (x , ... ) {
3030 UseMethod(' nearest_datasets' , x )
3131}
3232
3333
3434# ' @rdname nearest_datasets-methods
3535# ' @export
36- nearest_datasets.default <- function (x , ... ){
36+ nearest_datasets.default <- function (x , ... ) {
3737 stop(' `x` must be of class `data.frame` or `character`.' )
3838}
3939
4040
4141# ' @rdname nearest_datasets-methods
4242# ' @export
4343nearest_datasets.character <- function (
44- x , n_neighbors = 5 ,
44+ x ,
45+ n_neighbors = 5 ,
4546 dimensions = c(' n_instances' , ' n_features' ),
46- target_name = ' target' , ... ) {
47-
48- if (! (x %in% dataset_names ))
49- stop(" 'dataset_name' " , x , " not found in PMLB.\n * Check spelling, capitalisation etc." , call. = FALSE )
50- dataset_stats <- summary_stats [summary_stats $ dataset == x , ]
51-
52- num_cols <- unlist(lapply(summary_stats , function (x ) is.numeric(x )|| is.integer(x )))
53- summary_task <- summary_stats [summary_stats $ task == dataset_stats $ task , ] # restrict to same task
47+ target_name = ' target' ,
48+ ...
49+ ) {
50+ if (! (x %in% dataset_names()))
51+ stop(
52+ " 'dataset_name' " ,
53+ x ,
54+ " not found in PMLB.\n * Check spelling, capitalisation etc." ,
55+ call. = FALSE
56+ )
57+ sum_stats <- summary_stats()
58+ dataset_stats <- sum_stats [sum_stats $ dataset == x , ]
59+
60+ num_cols <- unlist(lapply(
61+ sum_stats ,
62+ function (x ) is.numeric(x ) || is.integer(x )
63+ ))
64+ summary_task <- sum_stats [sum_stats $ task == dataset_stats $ task , ] # restrict to same task
5465 summary_i <- summary_task [, num_cols ]
5566
56- if (length(dimensions ) == 1 && dimensions == ' all' ){
67+ if (length(dimensions ) == 1 && dimensions == ' all' ) {
5768 dimensions <- colnames(summary_i )
5869 } else {
5970 stopifnot(dimensions %in% colnames(summary_i ))
@@ -70,28 +81,36 @@ nearest_datasets.character <- function(
7081# ' @rdname nearest_datasets-methods
7182# ' @export
7283nearest_datasets.data.frame <- function (
73- x , y = NULL , n_neighbors = 5 ,
84+ x ,
85+ y = NULL ,
86+ n_neighbors = 5 ,
7487 dimensions = c(' n_instances' , ' n_features' ),
7588 task = c(' classification' , ' regression' ),
76- target_name = ' target' , ... ) {
77-
89+ target_name = ' target' ,
90+ ...
91+ ) {
7892 df <- if (is.null(y )) x else data.frame (x , target = y )
7993
8094 # get summary stats for dataset
81- if (is.null(task )){
82- task <- if (length(unique(df $ target )) < 5 ) ' classification' else ' regression'
95+ if (is.null(task )) {
96+ task <- if (length(unique(df $ target )) < 5 ) ' classification' else
97+ ' regression'
8398 } else {
8499 task <- match.arg(task )
85100 }
86101
87102 if (! (target_name %in% colnames(df )))
88103 stop(paste(' Either x or y must contain' , target_name ))
89104
90- num_cols <- unlist(lapply(summary_stats , function (x ) is.numeric(x )|| is.integer(x )))
91- summary_task <- summary_stats [summary_stats $ task == task , ] # restrict to same task
105+ sum_stats <- summary_stats()
106+ num_cols <- unlist(lapply(
107+ sum_stats ,
108+ function (x ) is.numeric(x ) || is.integer(x )
109+ ))
110+ summary_task <- sum_stats [sum_stats $ task == task , ] # restrict to same task
92111 summary_i <- summary_task [, num_cols ]
93112
94- if (length(dimensions ) == 1 && dimensions == ' all' ){
113+ if (length(dimensions ) == 1 && dimensions == ' all' ) {
95114 dimensions <- colnames(summary_i )
96115 } else {
97116 stopifnot(dimensions %in% colnames(summary_i ))
@@ -100,22 +119,22 @@ nearest_datasets.data.frame <- function(
100119
101120 feat_names <- setdiff(colnames(df ), target_name )
102121 types <- vector(' character' )
103- for (i in feat_names ){
104- types [i ] <- get_type(df [,i ], include_binary = TRUE )
122+ for (i in feat_names ) {
123+ types [i ] <- get_type(df [, i ], include_binary = TRUE )
105124 }
106125
107126 feat <- table(types )
108- for (type in c(' binary' , ' categorical' , ' continuous' )){
127+ for (type in c(' binary' , ' categorical' , ' continuous' )) {
109128 if (! type %in% names(feat )) feat [type ] <- 0
110129 }
111130 imb <- compute_imbalance(df [, target_name ])
112131
113132 dataset_stats <- data.frame (
114133 n_instances = nrow(df ),
115134 n_features = length(feat_names ),
116- n_binary_features = feat [' binary' ],
117- n_categorical_features = feat [' categorical' ],
118- n_continuous_features = feat [' continuous' ],
135+ n_binary_features = feat [[ ' binary' ] ],
136+ n_categorical_features = feat [[ ' categorical' ] ],
137+ n_continuous_features = feat [[ ' continuous' ] ],
119138 endpoint_type = get_type(df [, target_name ]),
120139 n_classes = imb [[' num_classes' ]],
121140 imbalance = imb [[' imbalance' ]],
@@ -136,23 +155,25 @@ nearest_datasets.data.frame <- function(
136155# ' where zero means that the dataset is perfectly balanced
137156# ' and the higher the value, the more imbalanced the dataset.
138157# '
139- compute_imbalance <- function (target_col ){
158+ compute_imbalance <- function (target_col ) {
140159 imb <- 0
141160 classes_count <- table(target_col )
142161 num_classes <- length(classes_count )
143- for (x in classes_count ){
144- p_x = x / length(target_col )
162+ for (x in classes_count ) {
163+ p_x = x / length(target_col )
145164 }
146165
147- if (p_x > 0 ){
148- imb = imb + (p_x - 1 / num_classes )* (p_x - 1 / num_classes )
166+ if (p_x > 0 ) {
167+ imb = imb + (p_x - 1 / num_classes ) * (p_x - 1 / num_classes )
149168 }
150169
151170 # worst case scenario: all but 1 examplars in 1st class
152171 # the remaining one in 2nd class
153- worst_case <- (num_classes - 1 )* (1 / num_classes )^ 2 + (1 - 1 / num_classes )^ 2
172+ worst_case <- (num_classes - 1 ) *
173+ (1 / num_classes )^ 2 +
174+ (1 - 1 / num_classes )^ 2
154175
155- list (num_classes = num_classes , imbalance = imb / worst_case )
176+ list (num_classes = num_classes , imbalance = imb / worst_case )
156177}
157178
158179# ' Get type/class of given vector.
@@ -163,14 +184,17 @@ compute_imbalance <- function(target_col){
163184# '
164185# ' @return Type/class of `x`.
165186# '
166- get_type <- function (x , include_binary = FALSE ){
187+ get_type <- function (x , include_binary = FALSE ) {
167188 x <- stats :: na.omit(x )
168189
169- if (inherits(x , ' numeric' )){
190+ if (inherits(x , ' numeric' )) {
170191 return (' continuous' )
171- } else if (inherits(x , ' integer' ) || inherits(x , ' factor' )){
172- if (include_binary ){
173- if (length(unique(x )) == 2 ) return (' binary' )}
192+ } else if (inherits(x , ' integer' ) || inherits(x , ' factor' )) {
193+ if (include_binary ) {
194+ if (length(unique(x )) == 2 ) return (' binary' )
195+ }
174196 return (' categorical' )
175- } else {stop(" Cannot get types for dataset columns" )}
197+ } else {
198+ stop(" Cannot get types for dataset columns" )
199+ }
176200}
0 commit comments