---
title: "Regression and Classification with Tidymodels"
author: "Filippo"
date: "`r Sys.Date()`"
output: html_document
---
```{r, warning=FALSE, message=FALSE}
library("knitr")
library("glmnet")
library("ggplot2")
library("flextable")
library("tidyverse")
library("tidymodels")
library("data.table")
```
## Regression
Here we use the `bats` dataset to build a regression model using the programming framework provided by the R package `tidymodels`.
```{r label="bat_data", echo=TRUE}
ch4 <- readxl::read_excel("../data/DNA methylation data.xlsm", sheet = 1)
ch4 <- na.omit(ch4[,-c(1,3)])
head(ch4)
```
#### 1. partitioning the data
We now use `k-fold` cross-validation to estimate the performance of our linear regression model, which predicts the age of bats from epigenetic data. The dataset is small, so we apply 5-fold cross-validation.
The objective is to get a **robust estimate** of the model **predictive ability**.
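
To make the idea concrete, here is a minimal base-R sketch of how a 5-fold partition can be built by hand. The sample size `n_obs` and the seed are made up for illustration, and `vfold_cv()` may differ in its internal details.

```r
set.seed(42)                       # arbitrary seed, for reproducibility only
n_obs <- 23                        # hypothetical number of samples
k <- 5

## shuffle the fold labels 1..k so that fold sizes differ by at most one
cv_fold_id <- sample(rep(seq_len(k), length.out = n_obs))

## for each fold: the rows used for validation (all other rows are for training)
cv_valid_rows <- split(seq_len(n_obs), cv_fold_id)
lengths(cv_valid_rows)

## every observation is used for validation exactly once
cv_times_validated <- table(unlist(cv_valid_rows))
all(cv_times_validated == 1)
```

Each sample therefore contributes to the performance estimate once, as unseen data.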
```{r pressure, echo=FALSE}
folds <- vfold_cv(ch4, v = 5)
folds
```
#### 2. specifying the model and workflow
We use `tidymodels` syntax to specify our linear regression model, and put everything into a *workflow*:
```{r}
lm_mod <- linear_reg() %>%
  set_engine("lm")

lm_wf <-
  workflow() %>%
  add_model(lm_mod) %>%
  add_formula(Age ~ .)
```
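
The formula `Age ~ .` means "model `Age` using all the other columns as predictors". A small base-R sketch with `lm()` (the function behind the `"lm"` engine) shows the expansion; the data frame `toy` and its columns `cpg1` and `cpg2` are invented for illustration and are not the bats data.

```r
## invented toy data: an outcome and two hypothetical methylation predictors
toy <- data.frame(
  Age  = c(1, 2, 3, 4, 5, 6),
  cpg1 = c(0.10, 0.20, 0.25, 0.40, 0.50, 0.55),
  cpg2 = c(0.90, 0.70, 0.80, 0.50, 0.40, 0.45)
)

## the dot expands to every column except the outcome
toy_fit <- lm(Age ~ ., data = toy)
names(coef(toy_fit))
```

This is why the unwanted columns are dropped from `ch4` before modelling: whatever remains in the data frame ends up in the model.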
#### 3. fitting the model on the resampled data
Now we apply the workflow to the partitions of the data that we prepared before. In each of the 5 partitions, 4 folds are used for training and 1 fold for validation, in turn, so that every fold serves as the validation set exactly once.
```{r}
lm_fit <-
  lm_wf %>%
  fit_resamples(
    folds,
    control = control_resamples(save_pred = TRUE))
```
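
Conceptually, fitting on resamples amounts to a loop over the folds: fit on the training part, predict the held-out part, and store the error. The sketch below does this by hand in base R. It uses the built-in `mtcars` data as a stand-in (not the bats data), and it is only an illustration of the idea, not of how `fit_resamples()` is implemented.

```r
set.seed(123)                                   # arbitrary seed
cv_dat <- mtcars[, c("mpg", "wt", "hp")]        # stand-in dataset
cv_id  <- sample(rep(1:5, length.out = nrow(cv_dat)))

cv_rmse <- sapply(1:5, function(k) {
  train <- cv_dat[cv_id != k, ]                 # 4 folds for training
  valid <- cv_dat[cv_id == k, ]                 # 1 fold for validation
  fit   <- lm(mpg ~ ., data = train)
  pred  <- predict(fit, newdata = valid)
  sqrt(mean((valid$mpg - pred)^2))              # RMSE on the held-out fold
})

cv_rmse                                         # one RMSE per fold
mean(cv_rmse)                                   # cross-validated estimate
```

The model never sees the validation rows during fitting, which is what makes the per-fold errors honest estimates of predictive ability.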
#### 4. evaluating the performance of the model
To evaluate the performance of the model we first collect the error metrics: RMSE and $R^2$. These are averages over the 5 validation folds, together with the standard error:
```{r}
collect_metrics(lm_fit) |> regulartable()
```
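
For reference, the quantities in the table above can be written as follows. This assumes the default `yardstick` definitions (in particular, that `rsq` is the squared correlation between observed and predicted values) and that the standard error is computed from the $K$ per-fold estimates:

$$
\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2},
\qquad
R^2 = \mathrm{cor}\left(y, \hat{y}\right)^2,
\qquad
\mathrm{SE} = \frac{s}{\sqrt{K}}
$$

where $y_i$ is the observed age, $\hat{y}_i$ the predicted age, $n$ the number of samples in a validation fold, and $s$ the standard deviation of the $K$ per-fold values of the metric.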
```{r}
metrs <- collect_metrics(lm_fit)
## select the RMSE row by name rather than by position
rmse <- metrs$mean[metrs$.metric == "rmse"]
error_pct <- rmse / mean(ch4$Age)
print(error_pct)
```
We see that the RMSE is **`r round(error_pct*100)`%** of the average age of bats in the dataset. We then look at the predictions in the validation folds, and measure the Pearson linear correlation between predicted and observed age values in each fold:
```{r}
collect_predictions(lm_fit) %>%
  group_by(id) %>%
  summarise(pearson = cor(.pred,Age)) |>
  regulartable()
```
### Repeated cross-validation
This was a single 5-fold cross-validation partition. To take full advantage of resampling and get more robust estimates of model performance, it is common practice to repeat the k-fold cross-validation **n times**, drawing a different random 5-fold partition at each of the *n* replicates:
```{r 'repeated_cv'}
gc()
folds = vfold_cv(ch4, v = 5, repeats = 10)
```
```{r}
lm_fit <-
  lm_wf %>%
  fit_resamples(
    folds,
    control = control_resamples(save_pred = TRUE))
```
```{r}
collect_metrics(lm_fit) |> regulartable()
```
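We see that the standard errors of the estimated RMSE and $R^2$ are now smaller (averages of 50 values instead of 5 values!). Finally, we compute the Pearson correlation between predicted and observed ages within each repeat (grouping by `id` pools the out-of-fold predictions of the 5 folds of that repeat), and report its average and standard deviation over the 10 repeats: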
```{r}
collect_predictions(lm_fit) %>%
  group_by(id) %>%
  summarise(pearson = cor(.pred,Age)) %>%
  summarise(avg = mean(pearson), std = sd(pearson)) |>
  regulartable()
```
## Classification
We now see how we can apply `tidymodels` to a classification problem. We use our `dogs` dataset (reduced: only 45 SNPs):
```{r label="reading_data"}
dogs <- fread("../data/dogs_imputed_reduced.raw")
dogs <- dogs %>%
  select(-c(IID,FID,PAT,MAT,SEX))
## logistic regression with glm() expects a 0/1 outcome,
## so we shift the phenotype codes down by one
dogs$PHENOTYPE = dogs$PHENOTYPE-1
head(dogs)
print(nrow(dogs))
print(ncol(dogs)-1)
```
This dataset is unbalanced: we have far fewer cases than controls:
```{r}
dogs %>%
  group_by(PHENOTYPE) %>%
  summarise(N=n())
```
When data are unbalanced, one possible approach to improve the performance of the predictive model is to **subsample** the data:

- **downsampling** (or **undersampling**): *sample down* the *majority* class until it has the same frequency as the *minority* class
- **oversampling**: generate new samples of the *minority* class based on its *neighbourhood* (determined by a similarity matrix)

Subsampling often improves the ability of the model to detect the minority class (see for instance [Menardi and Torelli, 2012](https://link.springer.com/article/10.1007/s10618-012-0295-5)).
It is important that subsampling is **applied only to the training data** (e.g. when doing k-fold cross-validation, under- or over-sampling must happen only in the k-1 folds used for training).
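
As an illustration of the simpler of the two strategies, the base-R sketch below downsamples the majority class in a training set. The label vector `ds_y` (40 controls, 10 cases) is invented for the example; the validation fold would be left untouched.

```r
set.seed(1)                                   # arbitrary seed
## invented training labels: 0 = control (majority), 1 = case (minority)
ds_y <- c(rep(0, 40), rep(1, 10))

ds_major <- which(ds_y == 0)
ds_minor <- which(ds_y == 1)

## keep all minority rows and an equally sized random subset of majority rows
ds_keep <- c(sample(ds_major, length(ds_minor)), ds_minor)

table(ds_y[ds_keep])                          # balanced training set
```

Downsampling throws data away, which is costly in a small dataset; oversampling keeps all observations at the price of adding synthetic ones.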
#### 1. subsampling the data
In this case, data will be oversampled (the dataset is small):
```{r}
## ROSE: Random Over-Sampling Examples
library("themis")
imbal_rec <-
  recipe(PHENOTYPE ~ ., data = dogs) %>%
  step_bin2factor(PHENOTYPE) %>%
  ## skip = TRUE ensures that subsampling is applied only to the training data
  step_rose(PHENOTYPE, skip = TRUE)
```
Let's look at the upsampled data:
```{r}
temp <- prep(imbal_rec) |>
  juice()
head(temp)
```
[**Question for you: what do you notice in the upsampled data?**]{style="color:red"}
```{r}
prep(imbal_rec) %>%
  juice() %>%
  group_by(PHENOTYPE) %>%
  summarise(N=n()) |>
  kable(format = "html", table.attr = "style = \"color: white;\"")
```
#### 2. partitioning the data
```{r}
folds <- vfold_cv(dogs, v = 5, strata = "PHENOTYPE", repeats = 5)
```
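
The `strata` argument asks for folds that preserve the case/control ratio of the full dataset, so that no validation fold ends up with very few (or no) cases. A minimal base-R sketch of the idea, using invented labels (40 controls, 10 cases), assigns folds separately within each class; the actual algorithm in `vfold_cv()` may differ in its details.

```r
set.seed(7)                                    # arbitrary seed
st_y <- c(rep(0, 40), rep(1, 10))              # invented labels
st_fold <- integer(length(st_y))

## assign fold labels 1..5 at random, separately within each class
for (cls in unique(st_y)) {
  idx <- which(st_y == cls)
  st_fold[idx] <- sample(rep(1:5, length.out = length(idx)))
}

st_tab <- table(fold = st_fold, class = st_y)
st_tab                                         # 8 controls and 2 cases per fold
```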
#### 3. specifying the logistic regression model and workflow
```{r}
log_mod <- logistic_reg(
  mode = "classification",
  engine = "glm"
)

log_wf <-
  workflow() %>%
  add_model(log_mod) %>%
  add_recipe(imbal_rec)
```
#### 4. defining performance metrics
```{r}
cls_metrics <- metric_set(roc_auc,sens,spec)
```
#### 5. fitting the model on the resampled data
```{r, warning=FALSE, message=FALSE}
log_res <- fit_resamples(
  log_wf,
  resamples = folds,
  metrics = cls_metrics,
  control = control_resamples(save_pred = TRUE)
)
```
#### 6. evaluating the performance of the model
We look at the following metrics, averaged over all resamples:

- `roc_auc`: AUC, the area under the ROC curve
- `sens`: sensitivity, or TPR (true positive rate)
- `spec`: specificity, or TNR (true negative rate)

```{r}
collect_metrics(log_res)
```
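
As a reminder of what `sens` and `spec` measure, the base-R sketch below computes them from a handful of invented predictions. It assumes that `"yes"` denotes the case class and is the first factor level (the class treated as the event); the vectors `ex_truth` and `ex_pred` are made up for illustration.

```r
lv <- c("yes", "no")
## invented labels: 3 cases and 5 controls
ex_truth <- factor(c("yes", "yes", "yes", "no", "no", "no", "no", "no"), levels = lv)
ex_pred  <- factor(c("yes", "no",  "yes", "no", "no", "yes", "no", "no"), levels = lv)

tp <- sum(ex_truth == "yes" & ex_pred == "yes")   # cases correctly detected
fn <- sum(ex_truth == "yes" & ex_pred == "no")    # cases missed
tn <- sum(ex_truth == "no"  & ex_pred == "no")    # controls correctly rejected
fp <- sum(ex_truth == "no"  & ex_pred == "yes")   # controls wrongly flagged

ex_sens <- tp / (tp + fn)                         # true positive rate
ex_spec <- tn / (tn + fp)                         # true negative rate
c(sensitivity = ex_sens, specificity = ex_spec)
```

With unbalanced data, looking at the two rates separately is more informative than a single accuracy figure, which can be high even when most cases are missed.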
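Then we can look at the predictions and calculate the error rate in each validation fold, followed by its average and standard deviation over all 25 resamples (5 folds x 5 repeats):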
```{r}
collect_predictions(log_res) %>%
  group_by(id,id2) %>%
  summarise(errors = sum(.pred_class != PHENOTYPE), tot = length(PHENOTYPE), error_rate = errors/tot) |>
  regulartable()
```
```{r}
collect_predictions(log_res) %>%
  group_by(id,id2) %>%
  summarise(errors = sum(.pred_class != PHENOTYPE), tot = length(PHENOTYPE), error_rate = errors/tot) %>%
  ungroup() %>%
  summarise(avg_er = mean(error_rate), std = sd(error_rate)) |>
  regulartable() |>
  autofit()
```
### Exercise 5.1
- Can you calculate the decomposition of the above error rate into the false positive and false negative rates?
- Can you add the MCC (Matthews correlation coefficient) to the metrics used for model evaluation? [look here](https://yardstick.tidymodels.org/reference/mcc.html)