Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ public class BoqaBenchmarkCommand implements Callable<Integer> {
split = ",")
private Set<String> diseaseDatabases;

/**
 * Temperature T for the posterior. Stays {@code null} when the option is not given, in which
 * case {@code AlgorithmParameters.create} falls back to its default temperature.
 */
@CommandLine.Option(
        names={"-T","--temperature"},
        description = "Float value such that temperature>=1 (this is not an inverse temperature) to stabilize distribution; temperature=1 is standard BOQA.")
private Double temperature;

@Override
public Integer call() throws Exception {
LOGGER.info("Starting up BOQA analysis, loading ontology file {} ...", ontologyFile);
Expand Down Expand Up @@ -129,8 +134,8 @@ public Integer call() throws Exception {

LOGGER.debug("Disease data parsed from {}", phenotypeAnnotationFile);

AlgorithmParameters params = AlgorithmParameters.create(alpha, beta);
LOGGER.info("Using alpha={}, beta={}", params.getAlpha(), params.getBeta());
AlgorithmParameters params = AlgorithmParameters.create(alpha, beta, temperature);
LOGGER.info("Using alpha={}, beta={}, temperature={}", params.getAlpha(), params.getBeta(), params.getTemperature());

// Initialize Counter
Counter counter = new BoqaSetCounter(diseaseData, hpo);
Expand Down Expand Up @@ -172,7 +177,7 @@ public Integer call() throws Exception {
Paths.get(ontologyFile),
phenotypeAnnotationFile,
cliArgs,
Map.of("alpha", params.getAlpha(), "beta", params.getBeta()),
Map.of("alpha", params.getAlpha(), "beta", params.getBeta(), "temperature", params.getTemperature()),
outPath
);
LOGGER.info("BOQA analysis completed successfully.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,37 @@
* Alpha and beta are fixed (across all diseases and patients) for a single run.
*
* {@code ALPHA} represents the probability of a false positive, {@code BETA} that of a false negative.
* {@code TEMPERATURE} (or T) can be used to temper the posterior distribution. The posterior is effectively raised
* to the 1/T-th power, or, equivalently, this can be seen as taking the T-th root.
*/
public final class AlgorithmParameters {
private static final double DEFAULT_ALPHA = 1.0 / 19077;
private static final double DEFAULT_BETA = 0.9;
private static final double DEFAULT_TEMPERATURE = 1.0;

private final double alpha;
private final double beta;

private final double temperature;
private final double logAlpha;
private final double logBeta;
private final double logOneMinusAlpha;
private final double logOneMinusBeta;

private AlgorithmParameters(double alpha, double beta) {
private AlgorithmParameters(double alpha, double beta, double temperature) {
// Plain assignment only: no range checks are performed in this constructor.
this.alpha = alpha;
this.beta = beta;

this.temperature = temperature;
// Cache the natural logarithms once; they are exposed through the getLog* accessors.
this.logAlpha = Math.log(alpha);
this.logBeta = Math.log(beta);
this.logOneMinusAlpha = Math.log(1-alpha);
this.logOneMinusBeta = Math.log(1-beta);
}

/**
 * Create parameters using the default values, i.e. {@code DEFAULT_ALPHA}, {@code DEFAULT_BETA}
 * and {@code DEFAULT_TEMPERATURE}.
 *
 * @return AlgorithmParameters instance built from the defaults
 */
public static AlgorithmParameters create() {
return create(DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_TEMPERATURE);

}

Expand All @@ -35,37 +45,44 @@ private AlgorithmParameters(double alpha, double beta) {
*
* @param alpha the alpha parameter (false positive probability), or null for default
* @param beta the beta parameter (false negative probability), or null for default
* @param temperature the inverse of the exponent for the posterior, or null for default
* @return AlgorithmParameters instance
* @throws IllegalArgumentException if alpha or beta is not in the range (0, 1)
* @throws IllegalArgumentException if alpha or beta is not in the range (0, 1), or if temperature is less than 1
*/
public static AlgorithmParameters create(Double alpha, Double beta) {
public static AlgorithmParameters create(Double alpha, Double beta, Double temperature) {
double a = (alpha != null) ? alpha : DEFAULT_ALPHA;
double b = (beta != null) ? beta : DEFAULT_BETA;
double t = (temperature != null) ? temperature : DEFAULT_TEMPERATURE;
// Validate alpha
if (a <= 0.0 || a >= 1.0) {
throw new IllegalArgumentException(
String.format("Alpha must be in the range (0, 1), exclusive. Got: %f", alpha)
String.format("Alpha must be in the range (0, 1), exclusive. Got: %f", a)
);
}

// Validate beta
if (b <= 0.0 || b >= 1.0) {
throw new IllegalArgumentException(
String.format("Beta must be in the range (0, 1), exclusive. Got: %f", beta)
String.format("Beta must be in the range (0, 1), exclusive. Got: %f", b)
);
}

return new AlgorithmParameters(a, b);
// Validate temperature
if (t < 1.0) {
throw new IllegalArgumentException(
String.format("Temperature must be in the range [1, infinity). Got: %f", t)
);
}
return new AlgorithmParameters(a, b, t);
}

/** @return alpha, the false positive probability */
public double getAlpha() {
return alpha;
}

/** @return beta, the false negative probability */
public double getBeta() {
return beta;
}

/** @return the temperature T used to temper the posterior */
public double getTemperature() {
return temperature;
}
/** @return the natural logarithm of alpha, computed once at construction */
public double getLogAlpha() {
return logAlpha;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import org.slf4j.LoggerFactory;
import java.util.*;

import static org.junit.Assert.assertThat;

/**
* Performs BOQA analysis for a given query set of HPO terms (patient's data).
* <p>
Expand Down Expand Up @@ -74,6 +76,7 @@ public static BoqaAnalysisResult computeBoqaResultsRawLog(
public static BoqaAnalysisResult computeBoqaResults(
PatientData patientData, Counter counter, int resultsLimit, AlgorithmParameters params) {


// Get BoqaResults with raw log scores
List<BoqaResult> rawLogBoqaResults = new ArrayList<>(computeBoqaResultsRawLog(patientData, counter, params).boqaResults());

Expand All @@ -86,15 +89,25 @@ public static BoqaAnalysisResult computeBoqaResults(
.max()
.orElse(Double.NEGATIVE_INFINITY);

// Compute sum of exp(logP - maxLogP)
double sum = rawLogBoqaResults.stream()
.mapToDouble(r -> Math.exp(r.boqaScore() - maxLogP))
.sum();
double normalizationFactor;
double epsilon = 0.000001d;

// Standard BOQA
if((Math.abs(1.0 - params.getTemperature()) < epsilon)){
// Compute sum of exp(logP - maxLogP)
normalizationFactor = rawLogBoqaResults.stream()
.mapToDouble(r -> Math.exp(r.boqaScore() - maxLogP))
.sum();
} else {
// Tempered BOQA (temperature != 1): no sum is computed here.
// Scores are only rescaled by the largest score via exp(logP - maxLogP) below, so the top score is 1 and the scores need not sum to 1.
normalizationFactor = 1.0;
}

// Normalize
List<BoqaResult> allResults = new ArrayList<>();
rawLogBoqaResults.forEach(r -> {
double normProb = Math.exp(r.boqaScore() - maxLogP) / sum;
double normProb = Math.exp(r.boqaScore() - maxLogP) / normalizationFactor;
allResults.add(new BoqaResult(r.counts(), normProb));
});

Expand All @@ -104,17 +117,18 @@ public static BoqaAnalysisResult computeBoqaResults(
/**
* Computes the un-normalized BOQA log probability for given BoqaCounts and parameters:
* <p>
* log(P) = fp × log(α) + fn × log(β) + tn × log(1-α) + tp × log(1-β)
* log(P) = [fp × log(α) + fn × log(β) + tn × log(1-α) + tp × log(1-β)] / T
* </p>
* @param params alpha, beta, log(alpha), log(beta) etc.
* @param params alpha, beta, log(alpha), log(beta), temperature T etc.
* @param counts The {@link BoqaCounts} for a query and a disease.
* @return The un-normalized BOQA log probability score.
*/
static double computeUnnormalizedLogProbability(AlgorithmParameters params, BoqaCounts counts){
return counts.fpBoqaCount() * params.getLogAlpha() +
return (counts.fpBoqaCount() * params.getLogAlpha() +
counts.fnBoqaCount() * params.getLogBeta() +
counts.tnBoqaCount() * params.getLogOneMinusAlpha() +
counts.tpBoqaCount() * params.getLogOneMinusBeta();
counts.tpBoqaCount() * params.getLogOneMinusBeta()
)/params.getTemperature();
}

/**
Expand All @@ -125,13 +139,18 @@ static double computeUnnormalizedLogProbability(AlgorithmParameters params, Boqa
* </pre>
* @param alpha False positive rate parameter.
* @param beta False negative rate parameter.
* @param temperature Temperature T; the un-normalized score is raised to the power 1/T (a value of 1.0 leaves it unchanged).
* @param counts The {@link BoqaCounts} for a disease.
* @return The un-normalized probability score.
*/
static double computeUnnormalizedProbability(double alpha, double beta, BoqaCounts counts){
return Math.pow(alpha, counts.fpBoqaCount())*
Math.pow(beta, counts.fnBoqaCount())*
Math.pow(1-alpha, counts.tnBoqaCount())*
Math.pow(1-beta, counts.tpBoqaCount());
static double computeUnnormalizedProbability(double alpha, double beta, double temperature, BoqaCounts counts){
return Math.exp(
(
counts.fpBoqaCount()*Math.log(alpha) +
counts.fnBoqaCount()*Math.log(beta) +
counts.tnBoqaCount()*Math.log(1-alpha) +
counts.tpBoqaCount()*Math.log(1-beta)
) / temperature
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ void testComputeBoqaCountsAgainstPyboqa(
}
Path ppkt = Path.of(resourceUrl.toURI());
int limit = Integer.MAX_VALUE;
AlgorithmParameters params = AlgorithmParameters.create(0.2,0.3); // numbers don't matter
AlgorithmParameters params = AlgorithmParameters.create(0.2,0.3, 1.0); // numbers don't matter
BoqaAnalysisResult boqaAnalysisResult = computeBoqaResults(
new PhenopacketData(ppkt), counter, limit, params);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void testComputeUnnormalizedProbability(
){
// Initialize BoqaCounts
BoqaCounts counts = new BoqaCounts("idIsUnimportant", "labelIsUnimportant", count1mb, countA, count1ma, countB);
double actualScore = computeUnnormalizedProbability(alpha, beta, counts);
double actualScore = computeUnnormalizedProbability(alpha, beta, 1.0, counts);

// Assert with small delta for floating-point comparison
assertEquals(expectedScore, actualScore, 1e-9);
Expand Down Expand Up @@ -97,23 +97,23 @@ void testComputeBoqaResults(){
int limit = counter.getDiseaseIds().size();
double alpha = 0.01;
double beta = 0.9;
AlgorithmParameters params = AlgorithmParameters.create(alpha, beta);
AlgorithmParameters params = AlgorithmParameters.create(alpha, beta, 1.0);

// Run 'computeBoqaResults'
BoqaAnalysisResult boqaAnalysisResult = BoqaPatientAnalyzer.computeBoqaResults(
patientData, counter, limit, params);

// Recompute un-normalized probabilities in the conventional way
List<Double> rawProbs = boqaAnalysisResult.boqaResults().stream()
.map(result -> computeUnnormalizedProbability(params.getAlpha(), params.getBeta(), result.counts()))
.map(result -> computeUnnormalizedProbability(params.getAlpha(), params.getBeta(), params.getTemperature(), result.counts()))
.toList();

// Get the sum for normalization
double rawProbsSum = rawProbs.stream().mapToDouble(Double::doubleValue).sum();

// Compare normalized probabilities from BoqaResults with those recalculated from counts
boqaAnalysisResult.boqaResults().forEach(br-> {
double expectedNormProb = computeUnnormalizedProbability(params.getAlpha(), params.getBeta(), br.counts()) / rawProbsSum;
double expectedNormProb = computeUnnormalizedProbability(params.getAlpha(), params.getBeta(), params.getTemperature(), br.counts()) / rawProbsSum;
double actualNormProb = br.boqaScore();
assertEquals(expectedNormProb, actualNormProb, 1e-9);
});
Expand Down