Skip to content

Commit 014b831

Browse files
Add fisher_exact_test pvalue_only
1 parent 7118484 commit 014b831

File tree

4 files changed

+153
-56
lines changed

4 files changed

+153
-56
lines changed

hail/hail/src/is/hail/expr/ir/functions/MathFunctions.scala

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,16 +425,15 @@ object MathFunctions extends RegistryFunctions {
425425
fetStruct.virtualType,
426426
(_, _, _, _, _) => fetStruct.sType,
427427
) { case (r, cb, _, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) =>
428-
val res = cb.newLocal[Array[Double]](
429-
"fisher_exact_test_res",
428+
val res = cb.memoize[Array[Double]](
430429
Code.invokeScalaObject4[Int, Int, Int, Int, Array[Double]](
431430
statsPackageClass,
432431
"fisherExactTest",
433432
a.value,
434433
b.value,
435434
c.value,
436435
d.value,
437-
),
436+
)
438437
)
439438

440439
fetStruct.constructFromFields(
@@ -450,6 +449,29 @@ object MathFunctions extends RegistryFunctions {
450449
)
451450
}
452451

452+
// FIXME: delete when PruneDeadField can optimize fisher_exact_test when only
// the pvalue is used from the result struct
//
// Registers a four-int32 -> float64 function that forwards its arguments to
// `fisherExactTestPValueOnly` on the stats package object and returns the
// resulting double as a primitive value.
registerSCode4(
  "fisher_exact_test_pvalue_only",
  TInt32,
  TInt32,
  TInt32,
  TInt32,
  TFloat64,
  (_, _, _, _, _) => SFloat64,
) { case (_, cb, _, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) =>
  // Build the static call first, then memoize it so the p-value is computed once.
  val pValueCall = Code.invokeScalaObject4[Int, Int, Int, Int, Double](
    statsPackageClass,
    "fisherExactTestPValueOnly",
    a.value,
    b.value,
    c.value,
    d.value,
  )
  val pValue = cb.memoize[Double](pValueCall)
  primitive(pValue)
}
474+
453475
registerSCode4(
454476
"chi_squared_test",
455477
TInt32,

hail/hail/src/is/hail/stats/package.scala

Lines changed: 68 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package is.hail
33
import is.hail.types.physical.{PCanonicalStruct, PFloat64}
44
import is.hail.utils._
55

6+
import scala.annotation.tailrec
7+
68
import net.sourceforge.jdistlib.{Beta, ChiSquare, Gamma, NonCentralChiSquare, Normal, Poisson}
79
import net.sourceforge.jdistlib.disttest.{DistributionTest, TestKind}
810
import org.apache.commons.math3.distribution.HypergeometricDistribution
@@ -162,45 +164,28 @@ package object stats {
162164
)
163165

164166
def fisherExactTest(a: Int, b: Int, c: Int, d: Int): Array[Double] =
165-
fisherExactTest(a, b, c, d, 1.0, 0.95, "two.sided")
166-
167-
def fisherExactTest(
168-
a: Int,
169-
b: Int,
170-
c: Int,
171-
d: Int,
172-
oddsRatio: Double = 1d,
173-
confidenceLevel: Double = 0.95,
174-
alternative: String = "two.sided",
175-
): Array[Double] = {
167+
fisherExactTest(a, b, c, d, 0.95)
176168

169+
def fisherExactTest(a: Int, b: Int, c: Int, d: Int, confidenceLevel: Double): Array[Double] = {
177170
if (!(a >= 0 && b >= 0 && c >= 0 && d >= 0))
178171
fatal(s"fisher_exact_test: all arguments must be non-negative, got $a, $b, $c, $d")
179172

180173
if (confidenceLevel < 0d || confidenceLevel > 1d)
181174
fatal("Confidence level must be between 0 and 1")
182175

183-
if (oddsRatio < 0d)
184-
fatal("Odds ratio must be non-negative")
185-
186-
if (alternative != "greater" && alternative != "less" && alternative != "two.sided")
187-
fatal("Did not recognize test type string. Use one of greater, less, two.sided")
188-
189176
val popSize = a + b + c + d
190-
val numSuccessPopulation = a + c
191-
val sampleSize = a + b
177+
val nGood = a + c
178+
val nSample = a + b
192179
val numSuccessSample = a
193180

194-
if (
195-
!(popSize > 0 && sampleSize > 0 && sampleSize < popSize && numSuccessPopulation > 0 && numSuccessPopulation < popSize)
196-
)
181+
if (!(popSize > 0 && nSample > 0 && nSample < popSize && nGood > 0 && nGood < popSize))
197182
return Array(Double.NaN, Double.NaN, Double.NaN, Double.NaN)
198183

199184
val low = math.max(0, (a + b) - (b + d))
200185
val high = math.min(a + b, a + c)
201186
val support = (low to high).toArray
202187

203-
val hgd = new HypergeometricDistribution(null, popSize, numSuccessPopulation, sampleSize)
188+
val hgd = new HypergeometricDistribution(null, popSize, nGood, nSample)
204189
val epsilon = 2.220446e-16
205190

206191
def dhyper(k: Int, logProb: Boolean): Double =
@@ -320,36 +305,74 @@ package object stats {
320305
}
321306
}
322307

323-
val pvalue: Double = (alternative: @unchecked) match {
324-
case "less" => pnhyper(numSuccessSample, oddsRatio)
325-
case "greater" => pnhyper(numSuccessSample, oddsRatio, upper_tail = true)
326-
case "two.sided" =>
327-
if (oddsRatio == 0)
328-
if (low == numSuccessSample) 1d else 0d
329-
else if (oddsRatio == Double.PositiveInfinity)
330-
if (high == numSuccessSample) 1d else 0d
331-
else {
332-
val relErr = 1d + 1e-7
333-
val d = dnhyper(oddsRatio)
334-
d.filter(_ <= d(numSuccessSample - low) * relErr).sum
335-
}
336-
}
337-
338-
assert(pvalue >= 0d && pvalue <= 1.000000000002)
308+
val pvalue = fisherExactTestPValueOnly(a, b, c, d)
339309

340310
val oddsRatioEstimate = mle(numSuccessSample.toDouble)
341311

342-
val confInterval = alternative match {
343-
case "less" => (0d, ncpUpper(numSuccessSample, 1 - confidenceLevel))
344-
case "greater" => (ncpLower(numSuccessSample, 1 - confidenceLevel), Double.PositiveInfinity)
345-
case "two.sided" =>
346-
val alpha = (1 - confidenceLevel) / 2d
347-
(ncpLower(numSuccessSample, alpha), ncpUpper(numSuccessSample, alpha))
312+
val confInterval = {
313+
val alpha = (1 - confidenceLevel) / 2d
314+
(ncpLower(numSuccessSample, alpha), ncpUpper(numSuccessSample, alpha))
348315
}
349316

350317
Array(pvalue, oddsRatioEstimate, confInterval._1, confInterval._2)
351318
}
352319

320+
/** Computes only the p-value of Fisher's exact test for a 2x2 contingency table
  * with first row `(a, b)` and second row `(c, d)`, skipping the odds-ratio
  * estimate and confidence interval.
  *
  * The p-value is the total hypergeometric probability of every outcome whose
  * probability is at most that of the observed count `a`, up to a relative
  * tolerance of `1e-14`. The near tail is taken from the cumulative distribution
  * directly; the start of the far tail is located by binary search on the
  * probability mass function, which is monotone on either side of its mode.
  *
  * Fails via `fatal` if any argument is negative or if the table total does not
  * fit in an `Int`.
  *
  * @param a
  *   count in row 1, column 1 (the observed number of successes in the sample)
  * @param b
  *   count in row 1, column 2
  * @param c
  *   count in row 2, column 1
  * @param d
  *   count in row 2, column 2
  * @return
  *   the p-value, or `NaN` when any row total or column total is zero
  */
def fisherExactTestPValueOnly(a: Int, b: Int, c: Int, d: Int): Double = {
  // Validate before doing any arithmetic on the inputs.
  if (!(a >= 0 && b >= 0 && c >= 0 && d >= 0))
    fatal(s"fisher_exact_test: all arguments must be non-negative, got $a, $b, $c, $d")

  // Sum in Long: four non-negative Ints can exceed Int.MaxValue, and a wrapped
  // Int total would either look like a degenerate table or parameterize the
  // distribution with a meaningless population size.
  val total = a.toLong + b.toLong + c.toLong + d.toLong
  if (total > Int.MaxValue)
    fatal(
      s"fisher_exact_test: sum of arguments must not exceed ${Int.MaxValue}, got $a, $b, $c, $d"
    )

  val popSize = total.toInt
  // Both partial sums are bounded by `total`, so they cannot overflow here.
  val nGood = a + c
  val nSample = a + b
  val numSuccessSample = a

  if (!(popSize > 0 && nSample > 0 && nSample < popSize && nGood > 0 && nGood < popSize))
    Double.NaN
  else {
    val hgd = new HypergeometricDistribution(null, popSize, nGood, nSample)

    // Returns i in [start, end] such that f is <= bound on [start, i) and > bound on [i, end).
    // Requires f to be non-decreasing on [start, end).
    @tailrec def upperBoundIncreasing(f: Int => Double, bound: Double, start: Int, end: Int): Int =
      if (start >= end) start
      else {
        val mid = (start + end) >>> 1
        if (f(mid) <= bound) upperBoundIncreasing(f, bound, mid + 1, end)
        else upperBoundIncreasing(f, bound, start, mid)
      }

    // Returns i in [start, end] such that f is > bound on [start, i) and <= bound on [i, end).
    // Requires f to be non-increasing on [start, end).
    @tailrec def lowerBoundDecreasing(f: Int => Double, bound: Double, start: Int, end: Int): Int =
      if (start >= end) start
      else {
        val mid = (start + end) >>> 1
        if (f(mid) > bound) lowerBoundDecreasing(f, bound, mid + 1, end)
        else lowerBoundDecreasing(f, bound, start, mid)
      }

    // Relative slack so outcomes that tie with the observed one up to rounding are included.
    val epsilon = 1e-14
    val gamma = 1 + epsilon

    val mode = ((nSample + 1.0) * (nGood + 1.0) / (popSize + 2.0)).toInt
    val pexact = hgd.probability(numSuccessSample)
    val pmode = hgd.probability(mode)

    val pvalue =
      if (math.abs(pexact - pmode) / math.max(pexact, pmode) <= epsilon) {
        // The observed outcome is as likely as the mode, so every outcome counts.
        1.0
      } else if (numSuccessSample < mode) {
        // Observed count is left of the mode: take the whole lower tail, then find
        // where the upper side drops to the observed probability.
        val plower = hgd.cumulativeProbability(numSuccessSample)
        val bound = lowerBoundDecreasing(hgd.probability, pexact * gamma, mode + 1, nSample + 1)
        plower + hgd.upperCumulativeProbability(bound)
      } else {
        // Observed count is right of the mode: mirror image of the case above.
        val pupper = hgd.upperCumulativeProbability(numSuccessSample)
        val bound = upperBoundIncreasing(hgd.probability, pexact * gamma, 0, mode)
        pupper + hgd.cumulativeProbability(bound - 1)
      }

    // Allow a little floating-point overshoot above 1 from summing the two tails.
    assert(pvalue >= 0d && pvalue <= 1.000000000002)

    pvalue
  }
}
375+
353376
def dnorm(x: Double, mu: Double, sigma: Double, logP: Boolean): Double =
354377
Normal.density(x, mu, sigma, logP)
355378

hail/hail/test/src/is/hail/stats/FisherExactTestSuite.scala

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package is.hail.stats
22

33
import is.hail.HailSuite
4+
import is.hail.utils.D_==
45

56
import org.testng.annotations.Test
67

@@ -14,9 +15,57 @@ class FisherExactTestSuite extends HailSuite {
1415

1516
val result = fisherExactTest(a, b, c, d)
1617

17-
assert(math.abs(result(0) - 0.2828) < 1e-4)
18-
assert(math.abs(result(1) - 0.4754059) < 1e-4)
19-
assert(math.abs(result(2) - 0.122593) < 1e-4)
20-
assert(math.abs(result(3) - 1.597972) < 1e-4)
18+
assert(D_==(result(0), 0.2828, 1e-3))
19+
assert(D_==(result(1), 0.4754059, 1e-4))
20+
assert(D_==(result(2), 0.122593, 1e-4))
21+
assert(D_==(result(3), 1.597972, 1e-4))
22+
}
23+
24+
// Checks only the p-value (index 0) of the full four-element result for the
// table with rows (10, 5) and (90, 95).
@Test def testPvalue2(): Unit = {
  val pValue = fisherExactTest(10, 5, 90, 95).head

  assert(D_==(pValue, 0.2828, 1e-3))
}
34+
35+
// Exercises fisherExactTestPValueOnly against reference p-values.
@Test def test_basic(): Unit = {
  // test cases taken from scipy/stats/tests/test_stats.py
  assert(D_==(fisherExactTestPValueOnly(14500, 20000, 30000, 40000), 0.01106, 1e-3))
  assert(D_==(fisherExactTestPValueOnly(100, 2, 1000, 5), 0.1301, 1e-3))
  assert(D_==(fisherExactTestPValueOnly(2, 7, 8, 2), 0.0230141, 1e-5))
  assert(D_==(fisherExactTestPValueOnly(5, 1, 10, 10), 0.1973244, 1e-6))
  assert(D_==(fisherExactTestPValueOnly(5, 15, 20, 20), 0.0958044, 1e-6))
  assert(D_==(fisherExactTestPValueOnly(5, 16, 20, 25), 0.1725862, 1e-5))
  assert(D_==(fisherExactTestPValueOnly(10, 5, 10, 1), 0.1973244, 1e-6))
  assert(D_==(fisherExactTestPValueOnly(5, 0, 1, 4), 0.04761904, 1e-6))
  assert(fisherExactTestPValueOnly(0, 1, 3, 2) == 1.0)
  assert(D_==(fisherExactTestPValueOnly(0, 2, 6, 4), 0.4545454545))
  assert(D_==(fisherExactTestPValueOnly(2, 7, 8, 2), 0.0230141, 1e-5))

  // Larger tables and very small p-values, compared with the default tolerance.
  assert(D_==(fisherExactTestPValueOnly(6, 37, 108, 200), 0.005092697748126))
  assert(D_==(fisherExactTestPValueOnly(22, 0, 0, 102), 7.175066786244549e-25))
  assert(D_==(fisherExactTestPValueOnly(94, 48, 3577, 16988), 2.069356340993818e-37))
  assert(fisherExactTestPValueOnly(5829225, 5692693, 5760959, 5760959) <= 1e-170)

  // Tables with a zero row or column total have no defined p-value.
  val degenerateTables = Array((0, 0, 5, 10), (5, 10, 0, 0), (0, 5, 0, 10), (5, 0, 10, 0))
  degenerateTables.foreach { case (a, b, c, d) =>
    assert(fisherExactTestPValueOnly(a, b, c, d).isNaN)
  }
}
2271
}

hail/python/hail/expr/functions.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,8 +1133,8 @@ def exp(x) -> Float64Expression:
11331133
return _func("exp", tfloat64, x)
11341134

11351135

1136-
@typecheck(c1=expr_int32, c2=expr_int32, c3=expr_int32, c4=expr_int32)
1137-
def fisher_exact_test(c1, c2, c3, c4) -> StructExpression:
1136+
@typecheck(c1=expr_int32, c2=expr_int32, c3=expr_int32, c4=expr_int32, _pvalue_only=bool)
1137+
def fisher_exact_test(c1, c2, c3, c4, *, _pvalue_only=False) -> StructExpression:
11381138
"""Calculates the p-value, odds ratio, and 95% confidence interval using
11391139
Fisher's exact test for a 2x2 table.
11401140
@@ -1176,8 +1176,11 @@ def fisher_exact_test(c1, c2, c3, c4) -> StructExpression:
11761176
`ci_95_lower` (:py:data:`.tfloat64`), and `ci_95_upper`
11771177
(:py:data:`.tfloat64`).
11781178
"""
1179-
ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64)
1180-
return _func("fisher_exact_test", ret_type, c1, c2, c3, c4)
1179+
if _pvalue_only:
1180+
return struct(p_value=_func("fisher_exact_test_pvalue_only", tfloat64, c1, c2, c3, c4))
1181+
else:
1182+
ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64)
1183+
return _func("fisher_exact_test", ret_type, c1, c2, c3, c4)
11811184

11821185

11831186
@typecheck(x=expr_oneof(expr_float32, expr_float64, expr_ndarray(expr_float64)))

0 commit comments

Comments
 (0)