Skip to content

Commit 8521f6b

Browse files
committed
CNAM-154 Added support for multiple bucket sizes and lag counts
CNAM-154 Added parameter for including/excluding the death bucket
1 parent 5ed7fb1 commit 8521f6b

File tree

7 files changed

+170
-28
lines changed

7 files changed

+170
-28
lines changed

src/main/resources/config/filtering-default.conf

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ default = {
4141

4242
}
4343
mlpp_parameters = {
44-
bucket_size = 30 # in days
45-
lag_count = 10
44+
bucket_size = [30] # in days
45+
lag_count = [10]
4646
min_timestamp = ${default.dates.study_start}
4747
max_timestamp = ${default.dates.study_end}
48+
include_death_bucket = false
4849

4950
exposures = {
5051
min_purchases = 1

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPConfig.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ object MLPPConfig {
2222

2323
private lazy val conf: Config = FilteringConfig.modelConfig("mlpp_parameters")
2424

25-
lazy val bucketSize: Int = conf.getInt("bucket_size")
26-
lazy val lagCount: Int = conf.getInt("lag_count")
25+
lazy val bucketSizes: List[Int] = conf.getIntList("bucket_size").asScala.toList.map(_.toInt)
26+
lazy val lagCounts: List[Int] = conf.getIntList("lag_count").asScala.toList.map(_.toInt)
2727
lazy val minTimestamp: Timestamp = makeTS(conf.getIntList("min_timestamp").asScala.toList)
2828
lazy val maxTimestamp: Timestamp = makeTS(conf.getIntList("max_timestamp").asScala.toList)
29+
lazy val includeDeathBucket: Boolean = conf.getBoolean("include_death_bucket")
2930

3031
lazy val exposureDefinition = MLPPExposureDefinition(
3132
minPurchases = conf.getInt("exposures.min_purchases"),

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPMain.scala

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ object MLPPMain extends Main {
2828
val patients: Dataset[Patient] = flatEvents.map(
2929
e => Patient(e.patientID, e.gender, e.birthDate, e.deathDate)
3030
).distinct
31+
// todo: test if filter_lost_patients is true
3132
val tracklossEvents: Dataset[Event] = TrackLossTransformer.transform(
3233
Sources(dcir=Some(dcirFlat))
3334
)
@@ -41,15 +42,21 @@ object MLPPMain extends Main {
4142

4243
val exposures: Dataset[FlatEvent] = MLPPExposuresTransformer.transform(allEvents)
4344

44-
val mlppParams = MLPPWriter.Params(
45-
bucketSize = MLPPConfig.bucketSize,
46-
lagCount = MLPPConfig.lagCount,
47-
minTimestamp = MLPPConfig.minTimestamp,
48-
maxTimestamp = MLPPConfig.maxTimestamp
49-
)
50-
val mlppWriter = MLPPWriter(mlppParams)
51-
val result = MLPPWriter(mlppParams).write(diseaseEvents.union(exposures), outputPath)
52-
53-
Some(result)
45+
val results: List[Dataset[MLPPFeature]] = for {
46+
bucketSize <- MLPPConfig.bucketSizes
47+
lagCount <- MLPPConfig.lagCounts
48+
} yield {
49+
val mlppParams = MLPPWriter.Params(
50+
bucketSize = bucketSize,
51+
lagCount = lagCount,
52+
minTimestamp = MLPPConfig.minTimestamp,
53+
maxTimestamp = MLPPConfig.maxTimestamp,
54+
includeDeathBucket = MLPPConfig.includeDeathBucket
55+
)
56+
val mlppWriter = MLPPWriter(mlppParams)
57+
val path = s"$outputPath/${bucketSize}B-${lagCount}L"
58+
MLPPWriter(mlppParams).write(diseaseEvents.union(exposures), path)
59+
}
60+
Some(results.head)
5461
}
5562
}

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPWriter.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ object MLPPWriter {
1717
bucketSize: Int = 30,
1818
lagCount: Int = 10,
1919
minTimestamp: Timestamp = makeTS(2006, 1, 1),
20-
maxTimestamp: Timestamp = makeTS(2009, 12, 31, 23, 59, 59)
20+
maxTimestamp: Timestamp = makeTS(2009, 12, 31, 23, 59, 59),
21+
includeDeathBucket: Boolean = false
2122
)
2223

2324
def apply(params: Params = Params()) = new MLPPWriter(params)
@@ -76,9 +77,9 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
7677
// We are no longer using trackloss and disease information for calculating the end bucket.
7778
def withEndBucket: DataFrame = {
7879

79-
val endBucket: Column = minColumn(
80-
col("deathBucket"), lit(bucketCount)
81-
)
80+
val deathBucketRule = if (params.includeDeathBucket) col("deathBucket") + 1 else col("deathBucket")
81+
82+
val endBucket: Column = minColumn(deathBucketRule, lit(bucketCount))
8283
data.withColumn("endBucket", endBucket)
8384
}
8485

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# This is needed because otherwise in the current dummy data all patients would be filtered
22
mlpp_parameters.exposures.filter_diagnosed_patients = false
3-
mlpp_parameters.bucket_size = 20 # days
3+
mlpp_parameters.bucket_size = [20] # days

src/test/resources/config/mlpp-new-exposure.conf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This is needed because otherwise in the current dummy data all patients would be filtered
22
mlpp_parameters.exposures.filter_diagnosed_patients = false
3-
mlpp_parameters.bucket_size = 20 # days
3+
mlpp_parameters.bucket_size = [20] # days
44

55
# Changing exposure definition to a "cox-like" one.
66
mlpp_parameters.exposures = {

src/test/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPWriterSuite.scala

Lines changed: 140 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ class MLPPWriterSuite extends SharedContext {
193193
assert(result === expected)
194194
}
195195

196-
"withEndBucket" should "add a column with the minimum among deathBucket, diseaseBucket and the max number of buckets" in {
196+
"withEndBucket" should "add a column with the minimum among deathBucket and the max number of buckets" in {
197197
val sqlCtx = sqlContext
198198
import sqlCtx.implicits._
199199

@@ -244,6 +244,53 @@ class MLPPWriterSuite extends SharedContext {
244244
assert(result === expected)
245245
}
246246

247+
it should "add a column with the minimum among deathBucket + 1, and the max number of buckets if " +
248+
"includeDeathBucket is true" in {
249+
val sqlCtx = sqlContext
250+
import sqlCtx.implicits._
251+
252+
// Given
253+
val params = MLPPWriter.Params(
254+
minTimestamp = makeTS(2006, 1, 1),
255+
maxTimestamp = makeTS(2006, 2, 2),
256+
bucketSize = 2,
257+
includeDeathBucket = true
258+
)
259+
260+
val input = Seq(
261+
("PA", Some(16)),
262+
("PA", Some(16)),
263+
("PB", Some( 0)),
264+
("PB", Some( 0)),
265+
("PC", Some( 5)),
266+
("PC", Some( 5)),
267+
("PD", None),
268+
("PD", None)
269+
).toDF("patientID", "deathBucket")
270+
271+
val expected = Seq(
272+
("PA", Some(16)),
273+
("PA", Some(16)),
274+
("PB", Some( 1)),
275+
("PB", Some( 1)),
276+
("PC", Some( 6)),
277+
("PC", Some( 6)),
278+
("PD", Some(16)),
279+
("PD", Some(16))
280+
).toDF("patientID", "endBucket")
281+
282+
// When
283+
val writer = MLPPWriter(params)
284+
import writer.MLPPDataFrame
285+
val result = input.withEndBucket.select("patientID", "endBucket")
286+
287+
// Then
288+
import RichDataFrames._
289+
result.show
290+
expected.show
291+
assert(result === expected)
292+
}
293+
247294
"makeDiscreteExposures" should "return a Dataset containing the 0-lag exposures in the sparse format" in {
248295
val sqlCtx = sqlContext
249296
import sqlCtx.implicits._
@@ -592,9 +639,9 @@ class MLPPWriterSuite extends SharedContext {
592639
)
593640
val input: Dataset[FlatEvent] = Seq(
594641
FlatEvent("PC", 2, makeTS(1970, 1, 1), None, "exposure", "Mol1", 1.0, makeTS(2006, 5, 15), None),
595-
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 6, 15)), "exposure", "Mol1", 1.0, makeTS(2006, 1, 15), None),
596-
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 6, 15)), "exposure", "Mol2", 1.0, makeTS(2006, 3, 15), None),
597-
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 6, 15)), "exposure", "Mol2", 1.0, makeTS(2006, 5, 15), None),
642+
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 4, 15)), "exposure", "Mol1", 1.0, makeTS(2006, 1, 15), None),
643+
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 4, 15)), "exposure", "Mol1", 1.0, makeTS(2006, 3, 15), None),
644+
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 4, 15)), "disease", "targetDisease", 1.0, makeTS(2006, 3, 15), None),
598645
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "exposure", "Mol1", 1.0, makeTS(2006, 1, 15), None),
599646
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "exposure", "Mol1", 1.0, makeTS(2006, 3, 15), None),
600647
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "exposure", "Mol1", 1.0, makeTS(2006, 4, 15), None),
@@ -624,12 +671,17 @@ class MLPPWriterSuite extends SharedContext {
624671
MLPPFeature("PA", 0, "Mol3", 2, 3, 0, 3, 8, 1.0),
625672
MLPPFeature("PA", 0, "Mol3", 2, 4, 1, 4, 9, 1.0),
626673
MLPPFeature("PA", 0, "Mol3", 2, 5, 2, 5, 10, 1.0),
627-
MLPPFeature("PA", 0, "Mol3", 2, 6, 3, 6, 11, 1.0)
674+
MLPPFeature("PA", 0, "Mol3", 2, 6, 3, 6, 11, 1.0),
675+
// Patient B
676+
MLPPFeature("PB", 1, "Mol1", 0, 0, 0, 7, 0, 1.0),
677+
MLPPFeature("PB", 1, "Mol1", 0, 1, 1, 8, 1, 1.0),
678+
MLPPFeature("PB", 1, "Mol1", 0, 2, 2, 9, 2, 1.0),
679+
MLPPFeature("PB", 1, "Mol1", 0, 2, 0, 9, 0, 1.0)
628680
).toDF
629681

630682
val expectedZMatrix = Seq(
631683
(3D, 1D, 1D, 46, 1, "PA", 0),
632-
(1D, 2D, 0D, 56, 1, "PB", 1),
684+
(2D, 0D, 0D, 56, 1, "PB", 1),
633685
(1D, 0D, 0D, 36, 2, "PC", 2)
634686
).toDF("MOL0000_Mol1", "MOL0001_Mol2", "MOL0002_Mol3", "age", "gender", "patientID", "patientIDIndex")
635687

@@ -640,8 +692,88 @@ class MLPPWriterSuite extends SharedContext {
640692

641693
// Then
642694
import RichDataFrames._
643-
result.show
644-
expectedFeatures.show
695+
result.show(100)
696+
expectedFeatures.show(100)
697+
StaticExposures.show
698+
expectedZMatrix.show
699+
assert(result === expectedFeatures)
700+
assert(writtenResult === expectedFeatures)
701+
assert(StaticExposures === expectedZMatrix)
702+
}
703+
704+
705+
it should "create the final matrices and write them as parquet files (removing death bucket)" in {
706+
val sqlCtx = sqlContext
707+
import sqlCtx.implicits._
708+
709+
// Given
710+
val rootDir = "target/test/output"
711+
val params = MLPPWriter.Params(
712+
minTimestamp = makeTS(2006, 1, 1),
713+
maxTimestamp = makeTS(2006, 8, 1), // 7 total buckets
714+
bucketSize = 30,
715+
lagCount = 4,
716+
includeDeathBucket = true
717+
)
718+
val input: Dataset[FlatEvent] = Seq(
719+
FlatEvent("PC", 2, makeTS(1970, 1, 1), None, "exposure", "Mol1", 1.0, makeTS(2006, 5, 15), None),
720+
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 4, 15)), "exposure", "Mol1", 1.0, makeTS(2006, 1, 15), None),
721+
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 4, 15)), "exposure", "Mol1", 1.0, makeTS(2006, 3, 15), None),
722+
FlatEvent("PB", 1, makeTS(1950, 1, 1), Some(makeTS(2006, 4, 15)), "disease", "targetDisease", 1.0, makeTS(2006, 3, 15), None),
723+
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "exposure", "Mol1", 1.0, makeTS(2006, 1, 15), None),
724+
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "exposure", "Mol1", 1.0, makeTS(2006, 3, 15), None),
725+
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "exposure", "Mol1", 1.0, makeTS(2006, 4, 15), None),
726+
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "exposure", "Mol2", 1.0, makeTS(2006, 3, 15), None),
727+
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "exposure", "Mol3", 1.0, makeTS(2006, 4, 15), None),
728+
FlatEvent("PA", 1, makeTS(1960, 1, 1), None, "disease", "targetDisease", 1.0, makeTS(2006, 5, 15), None)
729+
).toDS
730+
731+
val expectedFeatures = Seq(
732+
// Patient A
733+
MLPPFeature("PA", 0, "Mol1", 0, 0, 0, 0, 0, 1.0),
734+
MLPPFeature("PA", 0, "Mol1", 0, 1, 1, 1, 1, 1.0),
735+
MLPPFeature("PA", 0, "Mol1", 0, 2, 2, 2, 2, 1.0),
736+
MLPPFeature("PA", 0, "Mol1", 0, 3, 3, 3, 3, 1.0),
737+
MLPPFeature("PA", 0, "Mol1", 0, 2, 0, 2, 0, 1.0),
738+
MLPPFeature("PA", 0, "Mol1", 0, 3, 1, 3, 1, 1.0),
739+
MLPPFeature("PA", 0, "Mol1", 0, 4, 2, 4, 2, 1.0),
740+
MLPPFeature("PA", 0, "Mol1", 0, 5, 3, 5, 3, 1.0),
741+
MLPPFeature("PA", 0, "Mol1", 0, 3, 0, 3, 0, 1.0),
742+
MLPPFeature("PA", 0, "Mol1", 0, 4, 1, 4, 1, 1.0),
743+
MLPPFeature("PA", 0, "Mol1", 0, 5, 2, 5, 2, 1.0),
744+
MLPPFeature("PA", 0, "Mol1", 0, 6, 3, 6, 3, 1.0),
745+
MLPPFeature("PA", 0, "Mol2", 1, 2, 0, 2, 4, 1.0),
746+
MLPPFeature("PA", 0, "Mol2", 1, 3, 1, 3, 5, 1.0),
747+
MLPPFeature("PA", 0, "Mol2", 1, 4, 2, 4, 6, 1.0),
748+
MLPPFeature("PA", 0, "Mol2", 1, 5, 3, 5, 7, 1.0),
749+
MLPPFeature("PA", 0, "Mol3", 2, 3, 0, 3, 8, 1.0),
750+
MLPPFeature("PA", 0, "Mol3", 2, 4, 1, 4, 9, 1.0),
751+
MLPPFeature("PA", 0, "Mol3", 2, 5, 2, 5, 10, 1.0),
752+
MLPPFeature("PA", 0, "Mol3", 2, 6, 3, 6, 11, 1.0),
753+
// Patient B
754+
MLPPFeature("PB", 1, "Mol1", 0, 0, 0, 7, 0, 1.0),
755+
MLPPFeature("PB", 1, "Mol1", 0, 1, 1, 8, 1, 1.0),
756+
MLPPFeature("PB", 1, "Mol1", 0, 2, 2, 9, 2, 1.0),
757+
MLPPFeature("PB", 1, "Mol1", 0, 3, 3, 10, 3, 1.0),
758+
MLPPFeature("PB", 1, "Mol1", 0, 2, 0, 9, 0, 1.0),
759+
MLPPFeature("PB", 1, "Mol1", 0, 3, 1, 10, 1, 1.0)
760+
).toDF
761+
762+
val expectedZMatrix = Seq(
763+
(3D, 1D, 1D, 46, 1, "PA", 0),
764+
(2D, 0D, 0D, 56, 1, "PB", 1),
765+
(1D, 0D, 0D, 36, 2, "PC", 2)
766+
).toDF("MOL0000_Mol1", "MOL0001_Mol2", "MOL0002_Mol3", "age", "gender", "patientID", "patientIDIndex")
767+
768+
// When
769+
val result = MLPPWriter(params).write(input, rootDir).toDF
770+
val writtenResult = sqlContext.read.parquet(s"$rootDir/parquet/SparseFeatures")
771+
val StaticExposures = sqlContext.read.parquet(s"$rootDir/parquet/StaticExposures")
772+
773+
// Then
774+
import RichDataFrames._
775+
result.show(100)
776+
expectedFeatures.show(100)
645777
StaticExposures.show
646778
expectedZMatrix.show
647779
assert(result === expectedFeatures)

0 commit comments

Comments
 (0)