Skip to content

Commit b7469a1

Browse files
committed
CNAM-143 Fixes for running at CNAM
1 parent a2b15be commit b7469a1

File tree

5 files changed

+121
-26
lines changed

5 files changed

+121
-26
lines changed
src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPProvisoryMain.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package fr.polytechnique.cmap.cnam.filtering.mlpp

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._
import com.typesafe.config.{Config, ConfigFactory}
import fr.polytechnique.cmap.cnam.Main
import fr.polytechnique.cmap.cnam.filtering.FlatEvent

/**
 * Provisional entry point used to run the MLPP featuring at CNAM on 09/11/2016.
 *
 * Reads the already-filtered Cox cohorts ("broad" and "narrow") from the shared
 * file system, restricts the flat events to the Cox population, and writes the
 * MLPP feature matrices for each cohort.
 */
object MLPPProvisoryMain extends Main {

  override def appName: String = "MLPPMain"

  /**
   * Runs the MLPP featuring for both the "broad" and "narrow" cohorts.
   *
   * @param sqlContext the Spark SQL context.
   * @param config     currently unused; kept for interface stability.
   *                   TODO: read the input/output roots from this config.
   * @param inputRoot  root directory of the filtered data (default: CNAM shared path).
   * @param outputRoot root directory where MLPP features are written.
   */
  def runMLPPFeaturing(
      sqlContext: SQLContext,
      config: Config,
      inputRoot: String = "/shared/burq/filtered_data",
      outputRoot: String = "/shared/mlpp_features"): Unit = {
    import sqlContext.implicits._

    Seq("broad", "narrow").foreach { cohort =>
      // Patients retained by the Cox featuring: used to restrict the MLPP input
      // to the same population.
      val coxPatients = sqlContext.read
        .parquet(s"$inputRoot/$cohort/cox")
        .select("patientID")
        .distinct

      // Keep only the event categories MLPP consumes; "molecule" events are
      // relabeled as "exposure" (the category MLPPWriter expects).
      val flatEventsDF = sqlContext.read.parquet(s"$inputRoot/$cohort/events")
        .where(col("category").isin("trackloss", "disease", "molecule"))
        .join(coxPatients, "patientID")
        .withColumn("category",
          when(col("category") === "molecule", lit("exposure")).otherwise(col("category")))

      val flatEvents = flatEventsDF.as[FlatEvent].persist

      MLPPWriter().write(flatEvents, s"$outputRoot/$cohort/")

      // Release the cached dataset before processing the next cohort,
      // otherwise both cohorts stay pinned in executor memory.
      flatEvents.unpersist()
    }
  }

  override def main(args: Array[String]): Unit = {
    startContext()
    // Default to the "test" environment when no argument is supplied.
    val environment = if (args.nonEmpty) args(0) else "test"
    val config: Config = ConfigFactory.parseResources("filtering.conf").getConfig(environment)
    runMLPPFeaturing(sqlContext, config)
    stopContext()
  }
}

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPWriter.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,23 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
5050
)
5151
}
5252

53+
/** Adds a "tracklossBucket" column: the time bucket of the patient's first
  * trackloss event, or null when the patient has no qualifying trackloss.
  */
def withTracklossBucket: DataFrame = {
  val byPatient = Window.partitionBy("patientId")

  // A trackloss only qualifies when it happens strictly before the patient's
  // death bucket and before the end of the study period (bucketCount).
  val isQualifyingTrackloss: Column =
    (col("category") === "trackloss") &&
      (col("startBucket") < minColumn(col("deathBucket"), lit(bucketCount)))

  // Earliest qualifying trackloss bucket, broadcast to every row of the patient.
  val firstTracklossBucket: Column =
    min(when(isQualifyingTrackloss, col("startBucket"))).over(byPatient)

  data.withColumn("tracklossBucket", firstTracklossBucket)
}
63+
5364
def withDiseaseBucket: DataFrame = {
5465
val window = Window.partitionBy("patientId")
5566

5667
val hadDisease: Column = (col("category") === "disease") &&
57-
(col("eventId") === "targetDisease") &&
58-
(col("startBucket") < minColumn(col("deathBucket"), lit(bucketCount)))
68+
(col("eventId") === "targetDisease") &&
69+
(col("startBucket") < minColumn(col("tracklossBucket"), col("deathBucket"), lit(bucketCount)))
5970

6071
val diseaseBucket: Column = min(when(hadDisease, col("startBucket"))).over(window)
6172

@@ -65,7 +76,7 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
6576
/** Adds an "endBucket" column: the bucket where the patient's follow-up ends,
  * i.e. the earliest of trackloss, disease, death, or the end of the study.
  */
def withEndBucket: DataFrame = {
  // Follow-up stops at the first censoring/outcome event, capped at bucketCount.
  val followUpEnd: Column = minColumn(
    col("tracklossBucket"), col("diseaseBucket"), col("deathBucket"), lit(bucketCount)
  )
  data.withColumn("endBucket", followUpEnd)
}
@@ -257,6 +268,7 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
257268
.withAge(AgeReferenceDate)
258269
.withStartBucket
259270
.withDeathBucket
271+
.withTracklossBucket
260272
.withDiseaseBucket
261273
.withEndBucket
262274
.where(col("category") === "exposure")

src/main/scala/fr/polytechnique/cmap/cnam/utilities/ColumnUtilities.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ object ColumnUtilities {
4141
val lastBucket = if (bucketCount > 0) bucketCount - 1 else 0
4242

4343
val bucketId: Column = floor(datediff(column, lit(minTimestamp)) / lengthDays).cast(IntegerType)
44-
when(bucketId <= lastBucket || bucketId.isNull, bucketId)
45-
.otherwise(lastBucket)
44+
when(bucketId.isNull || bucketId.between(0, lastBucket), bucketId)
45+
//.otherwise(lastBucket)
4646
}
4747
}
4848
}

src/test/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPWriterSuite.scala

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,34 +111,75 @@ class MLPPWriterSuite extends SharedContext {
111111
assert(result === expected)
112112
}
113113

114-
"withDiseaseBucket" should "add a column with the timeBucket of the first targetDisease of each patient" in {
114+
// Verifies that withTracklossBucket tags every row of a patient with the bucket
// of that patient's first qualifying trackloss (strictly before death), and
// leaves it null when no trackloss qualifies (PC: no trackloss; PD: trackloss
// at bucket 4 is not before death at bucket 3).
"withTracklossBucket" should "add a column with the timeBucket of the first trackloss of each patient" in {
  val sqlCtx = sqlContext
  import sqlCtx.implicits._

  // Given
  val input = Seq(
    ("PA", "molecule", "PIOGLITAZONE", 0, Some(4)),
    ("PA", "molecule", "PIOGLITAZONE", 5, Some(4)),
    ("PA", "trackloss", "trackloss", 3, Some(4)),
    ("PB", "molecule", "PIOGLITAZONE", 2, None),
    ("PB", "trackloss", "trackloss", 4, None),
    ("PC", "molecule", "PIOGLITAZONE", 0, Some(6)),
    ("PD", "molecule", "PIOGLITAZONE", 2, Some(3)),
    ("PD", "molecule", "PIOGLITAZONE", 3, Some(3)),
    ("PD", "trackloss", "trackloss", 4, Some(3))
  ).toDF("patientID", "category", "eventId", "startBucket", "deathBucket")

  val expected = Seq(
    ("PA", "molecule", "PIOGLITAZONE", 0, Some(4), Some(3)),
    ("PA", "molecule", "PIOGLITAZONE", 5, Some(4), Some(3)),
    ("PA", "trackloss", "trackloss", 3, Some(4), Some(3)),
    ("PB", "molecule", "PIOGLITAZONE", 2, None, Some(4)),
    ("PB", "trackloss", "trackloss", 4, None, Some(4)),
    ("PC", "molecule", "PIOGLITAZONE", 0, Some(6), None),
    ("PD", "molecule", "PIOGLITAZONE", 2, Some(3), None),
    ("PD", "molecule", "PIOGLITAZONE", 3, Some(3), None),
    ("PD", "trackloss", "trackloss", 4, Some(3), None)
  ).toDF("patientID", "category", "eventId", "startBucket", "deathBucket", "tracklossBucket")

  // When
  val writer = MLPPWriter()
  import writer.MLPPDataFrame
  val result = input.withTracklossBucket

  // Then
  // NOTE(review): removed the result.show/expected.show debug calls that were
  // left in — they only pollute the test output.
  import RichDataFrames._
  assert(result === expected)
}
154+
155+
"withDiseaseBucket" should "add a column with the timeBucket of the first targetDisease of each patient" in {
156+
val sqlCtx = sqlContext
157+
import sqlCtx.implicits._
158+
159+
// Given
160+
val input = Seq(
161+
("PA", "molecule", "PIOGLITAZONE", 0, Some(4), None),
162+
("PA", "molecule", "PIOGLITAZONE", 5, Some(4), None),
163+
("PA", "disease", "targetDisease", 3, Some(4), None),
164+
("PB", "molecule", "PIOGLITAZONE", 2, None, Some(5)),
165+
("PB", "disease", "targetDisease", 4, None, Some(5)),
137166
("PC", "molecule", "PIOGLITAZONE", 0, Some(6), None),
138167
("PD", "molecule", "PIOGLITAZONE", 2, Some(3), None),
139168
("PD", "molecule", "PIOGLITAZONE", 3, Some(3), None),
140169
("PD", "disease", "targetDisease", 4, Some(3), None)
141-
).toDF("patientID", "category", "eventId", "startBucket", "deathBucket", "diseaseBucket")
170+
).toDF("patientID", "category", "eventId", "startBucket", "deathBucket", "tracklossBucket")
171+
172+
val expected = Seq(
173+
("PA", "molecule", "PIOGLITAZONE", 0, Some(4), None, Some(3)),
174+
("PA", "molecule", "PIOGLITAZONE", 5, Some(4), None, Some(3)),
175+
("PA", "disease", "targetDisease", 3, Some(4), None, Some(3)),
176+
("PB", "molecule", "PIOGLITAZONE", 2, None, Some(5), Some(4)),
177+
("PB", "disease", "targetDisease", 4, None, Some(5), Some(4)),
178+
("PC", "molecule", "PIOGLITAZONE", 0, Some(6), None, None),
179+
("PD", "molecule", "PIOGLITAZONE", 2, Some(3), None, None),
180+
("PD", "molecule", "PIOGLITAZONE", 3, Some(3), None, None),
181+
("PD", "disease", "targetDisease", 4, Some(3), None, None)
182+
).toDF("patientID", "category", "eventId", "startBucket", "deathBucket", "tracklossBucket", "diseaseBucket")
142183

143184
// When
144185
val writer = MLPPWriter()
@@ -164,16 +205,18 @@ class MLPPWriterSuite extends SharedContext {
164205
)
165206

166207
val input = Seq(
167-
("PA", Some(2), Some(3)),
168-
("PA", Some(2), Some(3)),
169-
("PB", Some(4), Some(3)),
170-
("PB", Some(4), Some(3)),
171-
("PC", None, Some(4)),
172-
("PC", None, Some(4)),
173-
("PD", Some(4), None),
174-
("PD", Some(4), None),
175-
("PE", None, None)
176-
).toDF("patientID", "deathBucket", "diseaseBucket")
208+
("PA", Some(2), None, Some(3)),
209+
("PA", Some(2), None, Some(3)),
210+
("PB", Some(4), Some(5), Some(3)),
211+
("PB", Some(4), Some(5), Some(3)),
212+
("PC", None, Some(5), Some(4)),
213+
("PC", None, Some(5), Some(4)),
214+
("PD", Some(5), None, None),
215+
("PD", Some(5), None, None),
216+
("PE", Some(7), Some(6), None),
217+
("PE", Some(7), Some(6), None),
218+
("PF", None, None, None)
219+
).toDF("patientID", "deathBucket", "tracklossBucket", "diseaseBucket")
177220

178221
val expected = Seq(
179222
("PA", Some(2)),
@@ -182,9 +225,11 @@ class MLPPWriterSuite extends SharedContext {
182225
("PB", Some(3)),
183226
("PC", Some(4)),
184227
("PC", Some(4)),
185-
("PD", Some(4)),
186-
("PD", Some(4)),
187-
("PE", Some(16))
228+
("PD", Some(5)),
229+
("PD", Some(5)),
230+
("PE", Some(6)),
231+
("PE", Some(6)),
232+
("PF", Some(16))
188233
).toDF("patientID", "endBucket")
189234

190235
// When

src/test/scala/fr/polytechnique/cmap/cnam/utilities/ColumnUtilitiesSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class ColumnUtilitiesSuite extends SharedContext{
185185
(Some(makeTS(2006, 1, 3)), Some(1)),
186186
(Some(makeTS(2006, 1, 10)), Some(4)),
187187
(Some(makeTS(2006, 1, 31)), Some(15)),
188-
(Some(makeTS(2006, 2, 2)), Some(15)),
188+
(Some(makeTS(2006, 2, 2)), None),
189189
(None, None)
190190
).toDF("input", "output")
191191

0 commit comments

Comments
 (0)