Skip to content

Commit 46c7973

Browse files
committed
hotfix PLSDA predicted groups
- predicted group now correctly assigned based on ingroup probability or yhat value
1 parent 710dd2a commit 46c7973

File tree

6 files changed

+117
-91
lines changed

6 files changed

+117
-91
lines changed

R/PLSDA_class.R

Lines changed: 105 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
#' @include PLSR_class.R
44
#' @examples
55
#' M = PLSDA('number_components'=2,factor_name='Species')
6-
PLSDA = function(number_components=2,factor_name,...) {
6+
PLSDA = function(number_components=2,factor_name,pred_method='max_prob',...) {
77
out=struct::new_struct('PLSDA',
8-
number_components=number_components,
9-
factor_name=factor_name,
10-
...)
8+
number_components=number_components,
9+
factor_name=factor_name,
10+
pred_method=pred_method,
11+
...)
1112
return(out)
1213
}
1314

@@ -29,19 +30,24 @@ PLSDA = function(number_components=2,factor_name,...) {
2930
pred='data.frame',
3031
threshold='numeric',
3132
sr = 'entity',
32-
sr_pvalue='entity'
33+
sr_pvalue='entity',
34+
pred_method='entity'
3335

3436
),
35-
prototype = list(name='Partial least squares discriminant analysis',
37+
prototype = list(
38+
name='Partial least squares discriminant analysis',
3639
type="classification",
3740
predicted='pred',
3841
libraries='pls',
39-
description=paste0('PLS is a multivariate regression technique that ',
42+
description=paste0(
43+
'PLS is a multivariate regression technique that ',
4044
'extracts latent variables maximising covariance between the input ',
4145
'data and the response. The Discriminant Analysis variant uses group ',
42-
'labels in the response variable and applies a threshold to the ',
43-
'predicted values in order to predict group membership for new samples.'),
44-
.params=c('number_components','factor_name'),
46+
'labels in the response variable. For >2 groups a 1-vs-all ',
47+
'approach is used. Group membership can be predicted for test ',
48+
'samples based on a probability estimate of group membership, ',
49+
'or the estimated y-value.'),
50+
.params=c('number_components','factor_name','pred_method'),
4551
.outputs=c(
4652
'scores',
4753
'loadings',
@@ -57,12 +63,28 @@ PLSDA = function(number_components=2,factor_name,...) {
5763
'sr',
5864
'sr_pvalue'),
5965

60-
number_components=entity(value = 2,
66+
number_components=entity(
67+
value = 2,
6168
name = 'Number of components',
6269
description = 'The number of PLS components',
6370
type = c('numeric','integer')
6471
),
6572
factor_name=ents$factor_name,
73+
pred_method=enum(
74+
name='Prediction method',
75+
description=c(
76+
'max_yhat'=
77+
paste0('The predicted group is selected based on the ',
78+
'largest value of y_hat.'),
79+
'max_prob'=
80+
paste0('The predicted group is selected based on the ',
81+
'largest probability of group membership.')
82+
),
83+
value='max_prob',
84+
allowed=c('max_yhat','max_prob'),
85+
type='character',
86+
max_length=1
87+
),
6688
sr = entity(
6789
name = 'Selectivity ratio',
6890
description = paste0(
@@ -92,8 +114,8 @@ PLSDA = function(number_components=2,factor_name,...) {
92114
pages = '122-128',
93115
author = as.person("Nestor F. Perez and Joan Ferre and Ricard Boque"),
94116
title = paste0('Calculation of the reliability of ',
95-
'classification in discriminant partial least-squares ',
96-
'binary classification'),
117+
'classification in discriminant partial least-squares ',
118+
'binary classification'),
97119
journal = "Chemometrics and Intelligent Laboratory Systems"
98120
),
99121
bibentry(
@@ -113,80 +135,83 @@ PLSDA = function(number_components=2,factor_name,...) {
113135
#' @export
114136
#' @template model_train
115137
setMethod(f="model_train",
116-
signature=c("PLSDA",'DatasetExperiment'),
117-
definition=function(M,D)
118-
{
119-
SM=D$sample_meta
120-
y=SM[[M$factor_name]]
121-
# convert the factor to a design matrix
122-
z=model.matrix(~y+0)
123-
z[z==0]=-1 # +/-1 for PLS
124-
125-
X=as.matrix(D$data) # convert X to matrix
126-
127-
Z=as.data.frame(z)
128-
colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_'))
129-
130-
D$sample_meta=cbind(D$sample_meta,Z)
131-
132-
# PLSR model
133-
N = PLSR(number_components=M$number_components,factor_name=colnames(Z))
134-
N = model_apply(N,D)
135-
136-
# copy outputs across
137-
output_list(M) = output_list(N)
138-
139-
# some specific outputs for PLSDA
140-
output_value(M,'design_matrix')=Z
141-
output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE]
142-
143-
# for PLSDA compute probabilities
144-
probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]])
145-
output_value(M,'probability')=as.data.frame(probs$ingroup)
146-
output_value(M,'threshold')=probs$threshold
147-
148-
# update column names for outputs
149-
colnames(M$reg_coeff)=levels(y)
150-
colnames(M$sr)=levels(y)
151-
colnames(M$vip)=levels(y)
152-
colnames(M$yhat)=levels(y)
153-
colnames(M$design_matrix)=levels(y)
154-
colnames(M$probability)=levels(y)
155-
names(M$threshold)=levels(y)
156-
colnames(M$sr_pvalue)=levels(y)
157-
158-
return(M)
159-
}
138+
signature=c("PLSDA",'DatasetExperiment'),
139+
definition=function(M,D)
140+
{
141+
SM=D$sample_meta
142+
y=SM[[M$factor_name]]
143+
# convert the factor to a design matrix
144+
z=model.matrix(~y+0)
145+
z[z==0]=-1 # +/-1 for PLS
146+
147+
X=as.matrix(D$data) # convert X to matrix
148+
149+
Z=as.data.frame(z)
150+
colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_'))
151+
152+
D$sample_meta=cbind(D$sample_meta,Z)
153+
154+
# PLSR model
155+
N = PLSR(number_components=M$number_components,factor_name=colnames(Z))
156+
N = model_apply(N,D)
157+
158+
# copy outputs across
159+
output_list(M) = output_list(N)
160+
161+
# some specific outputs for PLSDA
162+
output_value(M,'design_matrix')=Z
163+
output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE]
164+
165+
# for PLSDA compute probabilities
166+
probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]])
167+
output_value(M,'probability')=as.data.frame(probs$ingroup)
168+
output_value(M,'threshold')=probs$threshold
169+
170+
# update column names for outputs
171+
colnames(M$reg_coeff)=levels(y)
172+
colnames(M$sr)=levels(y)
173+
colnames(M$vip)=levels(y)
174+
colnames(M$yhat)=levels(y)
175+
colnames(M$design_matrix)=levels(y)
176+
colnames(M$probability)=levels(y)
177+
names(M$threshold)=levels(y)
178+
colnames(M$sr_pvalue)=levels(y)
179+
180+
return(M)
181+
}
160182
)
161183

162184
#' @export
163185
#' @template model_predict
164186
setMethod(f="model_predict",
165-
signature=c("PLSDA",'DatasetExperiment'),
166-
definition=function(M,D)
167-
{
168-
# call PLSR predict
169-
N=callNextMethod(M,D)
170-
SM=N$y
171-
172-
## probability estimate
173-
# http://www.eigenvector.com/faq/index.php?id=38%7C
174-
p=as.matrix(N$pred)
175-
d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=SM[[M$factor_name]])
176-
pred=(p>d$threshold)*1
177-
pred=apply(pred,MARGIN=1,FUN=which.max)
178-
hi=apply(d$ingroup,MARGIN=1,FUN=which.max) # max probability
179-
if (sum(is.na(pred)>0)) {
180-
pred[is.na(pred)]=hi[is.na(pred)] # if none above threshold, use group with highest probability
181-
}
182-
pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y
183-
q=data.frame("pred"=pred)
184-
output_value(M,'pred')=q
185-
return(M)
186-
}
187+
signature=c("PLSDA",'DatasetExperiment'),
188+
definition=function(M,D)
189+
{
190+
# call PLSR predict
191+
N=callNextMethod(M,D)
192+
SM=N$y
193+
194+
## probability estimate
195+
# http://www.eigenvector.com/faq/index.php?id=38%7C
196+
p=as.matrix(N$pred)
197+
d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=M$y[[M$factor_name]])
198+
199+
# predictions
200+
if (M$pred_method=='max_yhat') {
201+
pred=apply(p,MARGIN=1,FUN=which.max)
202+
} else if (M$pred_method=='max_prob') {
203+
pred=apply(d$ingroup,MARGIN=1,FUN=which.max)
204+
}
205+
pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y
206+
q=data.frame("pred"=pred)
207+
output_value(M,'pred')=q
208+
return(M)
209+
}
187210
)
188211

189212

213+
214+
190215
prob=function(x,yhat,ytrue)
191216
{
192217
# x is predicted values
@@ -250,8 +275,7 @@ prob=function(x,yhat,ytrue)
250275
}
251276

252277

253-
gauss_intersect=function(m1,m2,s1,s2)
254-
{
278+
gauss_intersect=function(m1,m2,s1,s2) {
255279
#https://stackoverflow.com/questions/22579434/python-finding-the-intersection-point-of-two-gaussian-curves
256280
a=(1/(2*s1*s1))-(1/(2*s2*s2))
257281
b=(m2/(s2*s2)) - (m1/(s1*s1))

man/PLSDA.Rd

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-gridsearch1d.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ test_that('grid_search iterator',{
1616
# run
1717
I=run(I,D,B)
1818
# calculate metric
19-
expect_equal(I$metric$value,0.3,tolerance=0.05)
19+
expect_equal(I$metric$value,0.045,tolerance=0.0005)
2020
})
2121

2222
# test grid search
@@ -36,7 +36,7 @@ test_that('grid_search wf',{
3636
# run
3737
I=run(I,D,B)
3838
# calculate metric
39-
expect_equal(I$metric$value[1],0.3,tolerance=0.05)
39+
expect_equal(I$metric$value[1],0.04,tolerance=0.005)
4040
})
4141

4242
# test grid search

tests/testthat/test-kfold-xval.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ test_that('kfold xval venetian',{
1010
# run
1111
I=run(I,D,B)
1212
# calculate metric
13-
expect_equal(I$metric$mean,0.23,tolerance=0.05)
13+
expect_equal(I$metric$mean,0.11,tolerance=0.005)
1414
})
1515

1616
test_that('kfold xval blocks',{
@@ -26,7 +26,7 @@ test_that('kfold xval blocks',{
2626
# run
2727
I=run(I,D,B)
2828
# calculate metric
29-
expect_equal(I$metric$mean,0.23,tolerance=0.05)
29+
expect_equal(I$metric$mean,0.115,tolerance=0.005)
3030
})
3131

3232
test_that('kfold xval random',{
@@ -40,7 +40,7 @@ test_that('kfold xval random',{
4040
# run
4141
I=run(I,D,B)
4242
# calculate metric
43-
expect_equal(I$metric$mean,0.23,tolerance=0.05)
43+
expect_equal(I$metric$mean,0.105,tolerance=0.0005)
4444
})
4545

4646
test_that('kfold xval metric plot',{

tests/testthat/test-permutation_test.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ test_that('permutation test',{
1313
# calculate metric
1414
B=calculate(B,Yhat=output_value(I,'results.unpermuted')$predicted,
1515
Y=output_value(I,'results.unpermuted')$actual)
16-
expect_equal(value(B),expected=0.211,tolerance=0.004)
16+
expect_equal(value(B),expected=0.105,tolerance=0.0005)
1717
})
1818

1919
# permutation test box plot

tests/testthat/test-permute-sample-order.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ test_that('permute sample order model_seq',{
99
B=balanced_accuracy()
1010
# run
1111
I=run(I,D,B)
12-
expect_equal(I$metric$mean,expected=0.335,tolerance=0.05)
12+
expect_equal(I$metric$mean,expected=0.04,tolerance=0.005)
1313
})
1414

1515
# permute sample order
@@ -23,5 +23,5 @@ test_that('permute sample order iterator',{
2323
B=balanced_accuracy()
2424
# run
2525
I=run(I,D,B)
26-
expect_equal(I$metric$mean,expected=0.339,tolerance=0.05)
26+
expect_equal(I$metric$mean,expected=0.048,tolerance=0.0005)
2727
})

0 commit comments

Comments
 (0)