Skip to content

Commit b2cff39

Browse files
authored
Merge pull request #35 from P2GX/revert-30-lc/tempered_posterior_experiment
Revert "tempered posterior experiment (Temperature) with log scores"
2 parents 122bb59 + a1b0fd5 commit b2cff39

File tree

5 files changed

+34
-75
lines changed

5 files changed

+34
-75
lines changed

boqa-cli/src/main/java/org/p2gx/boqa/cli/cmd/BoqaBenchmarkCommand.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,6 @@ public class BoqaBenchmarkCommand implements Callable<Integer> {
9898
split = ",")
9999
private Set<String> diseaseDatabases;
100100

101-
@CommandLine.Option(
102-
names={"-T","--temperature"},
103-
description = "Float value such that temperature>1 (this is not an inverse temperature) to stabilize distribution.")
104-
private Double temperature;
105-
106101
@Override
107102
public Integer call() throws Exception {
108103
LOGGER.info("Starting up BOQA analysis, loading ontology file {} ...", ontologyFile);
@@ -134,8 +129,8 @@ public Integer call() throws Exception {
134129

135130
LOGGER.debug("Disease data parsed from {}", phenotypeAnnotationFile);
136131

137-
AlgorithmParameters params = AlgorithmParameters.create(alpha, beta, temperature);
138-
LOGGER.info("Using alpha={}, beta={}, temperature={}", params.getAlpha(), params.getBeta(), params.getTemperature());
132+
AlgorithmParameters params = AlgorithmParameters.create(alpha, beta);
133+
LOGGER.info("Using alpha={}, beta={}", params.getAlpha(), params.getBeta());
139134

140135
// Initialize Counter
141136
Counter counter = new BoqaSetCounter(diseaseData, hpo);
@@ -177,7 +172,7 @@ public Integer call() throws Exception {
177172
Paths.get(ontologyFile),
178173
phenotypeAnnotationFile,
179174
cliArgs,
180-
Map.of("alpha", params.getAlpha(), "beta", params.getBeta(), "temperature", params.getTemperature()),
175+
Map.of("alpha", params.getAlpha(), "beta", params.getBeta()),
181176
outPath
182177
);
183178
LOGGER.info("BOQA analysis completed successfully.");

boqa-core/src/main/java/org/p2gx/boqa/core/algorithm/AlgorithmParameters.java

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,27 @@
55
* Alpha and beta are fixed (across all diseases and patients) for a single run.
66
*
77
* {@code ALPHA} represents the probability of a false positive, {@code BETA} that of a false negative.
8-
* {@code TEMPERATURE} (or T) can be used to temper the posterior distribution. The posterior is effectively raised
9-
* to the 1/T-th power, or, equivalently, this can be seen as taking the T-th root.
108
*/
119
public final class AlgorithmParameters {
1210
private static final double DEFAULT_ALPHA = 1.0 / 19077;
1311
private static final double DEFAULT_BETA = 0.9;
14-
private static final double DEFAULT_TEMPERATURE = 1.0;
1512

1613
private final double alpha;
1714
private final double beta;
18-
private final double temperature;
15+
1916
private final double logAlpha;
2017
private final double logBeta;
2118
private final double logOneMinusAlpha;
2219
private final double logOneMinusBeta;
2320

24-
private AlgorithmParameters(double alpha, double beta, double temperature) {
21+
private AlgorithmParameters(double alpha, double beta) {
2522
this.alpha = alpha;
2623
this.beta = beta;
27-
this.temperature = temperature;
24+
2825
this.logAlpha = Math.log(alpha);
2926
this.logBeta = Math.log(beta);
3027
this.logOneMinusAlpha = Math.log(1-alpha);
3128
this.logOneMinusBeta = Math.log(1-beta);
32-
}
33-
34-
/**
35-
* Create parameters using the default values.
36-
*/
37-
public static AlgorithmParameters create() {
38-
return create(DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_TEMPERATURE);
3929

4030
}
4131

@@ -45,44 +35,37 @@ public static AlgorithmParameters create() {
4535
*
4636
* @param alpha the alpha parameter (false positive probability), or null for default
4737
* @param beta the beta parameter (false negative probability), or null for default
48-
* @param temperature the inverse of the exponent for the posterior, or null for default
4938
* @return AlgorithmParameters instance
50-
* @throws IllegalArgumentException if alpha or beta is not in the range (0, 1) and temperature is less than 1
39+
* @throws IllegalArgumentException if alpha or beta is not in the range (0, 1)
5140
*/
52-
public static AlgorithmParameters create(Double alpha, Double beta, Double temperature) {
41+
public static AlgorithmParameters create(Double alpha, Double beta) {
5342
double a = (alpha != null) ? alpha : DEFAULT_ALPHA;
5443
double b = (beta != null) ? beta : DEFAULT_BETA;
55-
double t = (temperature != null) ? temperature : DEFAULT_TEMPERATURE;
5644
// Validate alpha
5745
if (a <= 0.0 || a >= 1.0) {
5846
throw new IllegalArgumentException(
59-
String.format("Alpha must be in the range (0, 1), exclusive. Got: %f", a)
47+
String.format("Alpha must be in the range (0, 1), exclusive. Got: %f", alpha)
6048
);
6149
}
50+
6251
// Validate beta
6352
if (b <= 0.0 || b >= 1.0) {
6453
throw new IllegalArgumentException(
65-
String.format("Beta must be in the range (0, 1), exclusive. Got: %f", b)
54+
String.format("Beta must be in the range (0, 1), exclusive. Got: %f", beta)
6655
);
6756
}
68-
// Validate temperature
69-
if (t < 1.0) {
70-
throw new IllegalArgumentException(
71-
String.format("Temperature must be in the range [1, infinity). Got: %f", t)
72-
);
73-
}
74-
return new AlgorithmParameters(a, b, t);
57+
58+
return new AlgorithmParameters(a, b);
7559
}
7660

7761
public double getAlpha() {
7862
return alpha;
7963
}
64+
8065
public double getBeta() {
8166
return beta;
8267
}
83-
public double getTemperature() {
84-
return temperature;
85-
}
68+
8669
public double getLogAlpha() {
8770
return logAlpha;
8871
}

boqa-core/src/main/java/org/p2gx/boqa/core/analysis/BoqaPatientAnalyzer.java

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import org.slf4j.LoggerFactory;
99
import java.util.*;
1010

11-
import static org.junit.Assert.assertThat;
12-
1311
/**
1412
* Performs BOQA analysis for a given query set of HPO terms (patient's data).
1513
* <p>
@@ -76,7 +74,6 @@ public static BoqaAnalysisResult computeBoqaResultsRawLog(
7674
public static BoqaAnalysisResult computeBoqaResults(
7775
PatientData patientData, Counter counter, int resultsLimit, AlgorithmParameters params) {
7876

79-
8077
// Get BoqaResults with raw log scores
8178
List<BoqaResult> rawLogBoqaResults = new ArrayList<>(computeBoqaResultsRawLog(patientData, counter, params).boqaResults());
8279

@@ -89,25 +86,15 @@ public static BoqaAnalysisResult computeBoqaResults(
8986
.max()
9087
.orElse(Double.NEGATIVE_INFINITY);
9188

92-
double normalizationFactor;
93-
double epsilon = 0.000001d;
94-
95-
// Standard BOQA
96-
if((Math.abs(1.0 - params.getTemperature()) < epsilon)){
97-
// Compute sum of exp(logP - maxLogP)
98-
normalizationFactor = rawLogBoqaResults.stream()
99-
.mapToDouble(r -> Math.exp(r.boqaScore() - maxLogP))
100-
.sum();
101-
} else {
102-
// Alternative to Standard BOQA
103-
// Non-unit temperature normalizes trivially by largest score, which is already present below
104-
normalizationFactor = 1.0;
105-
}
89+
// Compute sum of exp(logP - maxLogP)
90+
double sum = rawLogBoqaResults.stream()
91+
.mapToDouble(r -> Math.exp(r.boqaScore() - maxLogP))
92+
.sum();
10693

10794
// Normalize
10895
List<BoqaResult> allResults = new ArrayList<>();
10996
rawLogBoqaResults.forEach(r -> {
110-
double normProb = Math.exp(r.boqaScore() - maxLogP) / normalizationFactor;
97+
double normProb = Math.exp(r.boqaScore() - maxLogP) / sum;
11198
allResults.add(new BoqaResult(r.counts(), normProb));
11299
});
113100

@@ -117,18 +104,17 @@ public static BoqaAnalysisResult computeBoqaResults(
117104
/**
118105
* Computes the un-normalized BOQA log probability for given BoqaCounts and parameters:
119106
* <p>
120-
* log(P) = [fp × log(α) + fn × log(β) + tn × log(1-α) + tp × log(1-β)] / T
107+
* log(P) = fp × log(α) + fn × log(β) + tn × log(1-α) + tp × log(1-β)
121108
* </p>
122-
* @param params alpha, beta, log(alpha), log(beta), termperature T etc.
109+
* @param params alpha, beta, log(alpha), log(beta) etc.
123110
* @param counts The {@link BoqaCounts} for a query and a disease.
124111
* @return The un-normalized BOQA log probability score.
125112
*/
126113
static double computeUnnormalizedLogProbability(AlgorithmParameters params, BoqaCounts counts){
127-
return (counts.fpBoqaCount() * params.getLogAlpha() +
114+
return counts.fpBoqaCount() * params.getLogAlpha() +
128115
counts.fnBoqaCount() * params.getLogBeta() +
129116
counts.tnBoqaCount() * params.getLogOneMinusAlpha() +
130-
counts.tpBoqaCount() * params.getLogOneMinusBeta()
131-
)/params.getTemperature();
117+
counts.tpBoqaCount() * params.getLogOneMinusBeta();
132118
}
133119

134120
/**
@@ -139,18 +125,13 @@ static double computeUnnormalizedLogProbability(AlgorithmParameters params, Boqa
139125
* </pre>
140126
* @param alpha False positive rate parameter.
141127
* @param beta False negative rate parameter.
142-
* @param temperature Use to make distributions more robust.
143128
* @param counts The {@link BoqaCounts} for a disease.
144129
* @return The un-normalized probability score.
145130
*/
146-
static double computeUnnormalizedProbability(double alpha, double beta, double temperature, BoqaCounts counts){
147-
return Math.exp(
148-
(
149-
counts.fpBoqaCount()*Math.log(alpha) +
150-
counts.fnBoqaCount()*Math.log(beta) +
151-
counts.tnBoqaCount()*Math.log(1-alpha) +
152-
counts.tpBoqaCount()*Math.log(1-beta)
153-
) / temperature
154-
);
131+
static double computeUnnormalizedProbability(double alpha, double beta, BoqaCounts counts){
132+
return Math.pow(alpha, counts.fpBoqaCount())*
133+
Math.pow(beta, counts.fnBoqaCount())*
134+
Math.pow(1-alpha, counts.tnBoqaCount())*
135+
Math.pow(1-beta, counts.tpBoqaCount());
155136
}
156137
}

boqa-core/src/test/java/org/p2gx/boqa/core/algorithm/BoqaSetCounterTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void testComputeBoqaCountsAgainstPyboqa(
103103
}
104104
Path ppkt = Path.of(resourceUrl.toURI());
105105
int limit = Integer.MAX_VALUE;
106-
AlgorithmParameters params = AlgorithmParameters.create(0.2,0.3, 1.0); // numbers don't matter
106+
AlgorithmParameters params = AlgorithmParameters.create(0.2,0.3); // numbers don't matter
107107
BoqaAnalysisResult boqaAnalysisResult = computeBoqaResults(
108108
new PhenopacketData(ppkt), counter, limit, params);
109109

boqa-core/src/test/java/org/p2gx/boqa/core/analysis/BoqaPatientAnalyzerTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ void testComputeUnnormalizedProbability(
6565
){
6666
// Initialize BoqaCounts
6767
BoqaCounts counts = new BoqaCounts("idIsUnimportant", "labelIsUnimportant", count1mb, countA, count1ma, countB);
68-
double actualScore = computeUnnormalizedProbability(alpha, beta, 1.0, counts);
68+
double actualScore = computeUnnormalizedProbability(alpha, beta, counts);
6969

7070
// Assert with small delta for floating-point comparison
7171
assertEquals(expectedScore, actualScore, 1e-9);
@@ -97,23 +97,23 @@ void testComputeBoqaResults(){
9797
int limit = counter.getDiseaseIds().size();
9898
double alpha = 0.01;
9999
double beta = 0.9;
100-
AlgorithmParameters params = AlgorithmParameters.create(alpha, beta, 1.0);
100+
AlgorithmParameters params = AlgorithmParameters.create(alpha, beta);
101101

102102
// Run 'computeBoqaResults'
103103
BoqaAnalysisResult boqaAnalysisResult = BoqaPatientAnalyzer.computeBoqaResults(
104104
patientData, counter, limit, params);
105105

106106
// Recompute un-normalized probabilities in the conventional way
107107
List<Double> rawProbs = boqaAnalysisResult.boqaResults().stream()
108-
.map(result -> computeUnnormalizedProbability(params.getAlpha(), params.getBeta(), params.getTemperature(), result.counts()))
108+
.map(result -> computeUnnormalizedProbability(params.getAlpha(), params.getBeta(), result.counts()))
109109
.toList();
110110

111111
// Get the sum for normalization
112112
double rawProbsSum = rawProbs.stream().mapToDouble(Double::doubleValue).sum();
113113

114114
// Compare normalized probabilities from BoqaResults with those recalculated from counts
115115
boqaAnalysisResult.boqaResults().forEach(br-> {
116-
double expectedNormProb = computeUnnormalizedProbability(params.getAlpha(), params.getBeta(), params.getTemperature(), br.counts()) / rawProbsSum;
116+
double expectedNormProb = computeUnnormalizedProbability(params.getAlpha(), params.getBeta(), br.counts()) / rawProbsSum;
117117
double actualNormProb = br.boqaScore();
118118
assertEquals(expectedNormProb, actualNormProb, 1e-9);
119119
});

0 commit comments

Comments
 (0)