Skip to content

Commit a2b15be

Browse files
committed
CNAM-143 New version with outcomes, metadata and CSV output

- CNAM-143 Added writes to single-partition CSV files
- CNAM-143 New version with outcome matrices
- CNAM-143 Small changes
- CNAM-143 Added metadata creation
1 parent d367c4e commit a2b15be

File tree

6 files changed

+322
-124
lines changed

6 files changed

+322
-124
lines changed

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/LaggedExposure.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ case class LaggedExposure(
55
patientIDIndex: Int,
66
gender: Int,
77
age: Int,
8+
diseaseBucket: Option[Int],
89
molecule: String,
910
moleculeIndex: Int,
1011
startBucket: Int,

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPFeature.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@ package fr.polytechnique.cmap.cnam.filtering.mlpp
22

33
case class MLPPFeature(
44
patientID: String,
5-
patientIndex: Long,
5+
patientIndex: Int,
66
moleculeName: String,
7-
moleculeIndex: Long,
7+
moleculeIndex: Int,
88
bucketIndex: Int,
99
lagIndex: Int,
10-
rowIndex: Long,
11-
colIndex: Long,
10+
rowIndex: Int,
11+
colIndex: Int,
1212
value: Double)
1313

1414
object MLPPFeature {
1515

1616
def fromLaggedExposure(e: LaggedExposure, bucketCount: Int, lagCount: Int): MLPPFeature = {
1717

18-
val r = e.patientIDIndex * bucketCount + e.startBucket
19-
val c = e.moleculeIndex * lagCount + e.lag
18+
val r: Int = e.patientIDIndex * bucketCount + e.startBucket
19+
val c: Short = (e.moleculeIndex * lagCount + e.lag).toShort
2020

2121
MLPPFeature(
2222
patientID = e.patientID,

src/main/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPWriter.scala

Lines changed: 108 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
3838
data.withColumn("age", age)
3939
}
4040

41-
// TODO: merge this function with withDeathBucket into a generic one (withBucketizedColumn(colName))
4241
def withStartBucket: DataFrame = {
4342
data.withColumn("startBucket",
4443
col("start").bucketize(params.minTimestamp, params.maxTimestamp, params.bucketSize)
@@ -54,9 +53,11 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
5453
def withDiseaseBucket: DataFrame = {
5554
val window = Window.partitionBy("patientId")
5655

57-
val diseaseBucket: Column = min(
58-
when(col("category") === "disease" && col("eventId") === "targetDisease", col("startBucket"))
59-
).over(window)
56+
val hadDisease: Column = (col("category") === "disease") &&
57+
(col("eventId") === "targetDisease") &&
58+
(col("startBucket") < minColumn(col("deathBucket"), lit(bucketCount)))
59+
60+
val diseaseBucket: Column = min(when(hadDisease, col("startBucket"))).over(window)
6061

6162
data.withColumn("diseaseBucket", diseaseBucket)
6263
}
@@ -101,23 +102,54 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
101102
def makeDiscreteExposures: Dataset[LaggedExposure] = {
102103
data
103104
.groupBy(
104-
"patientID", "patientIDIndex", "gender", "age", "molecule", "moleculeIndex", "startBucket",
105-
"endBucket"
105+
"patientID", "patientIDIndex", "gender", "age", "diseaseBucket", "molecule",
106+
"moleculeIndex", "startBucket", "endBucket"
106107
).agg(
107108
lit(0).as("lag"),
108109
lit(1.0).as("weight") // In the future, we might change it to sum("weight").as("weight")
109110
)
110111
.as[LaggedExposure]
111112
}
113+
114+
def writeCSV(path: String): Unit = {
115+
data.coalesce(1).write
116+
.format("com.databricks.spark.csv")
117+
.option("delimiter", ",")
118+
.option("header", "true")
119+
.save(path)
120+
}
112121
}
113122

114123
implicit class DiscreteExposures(exposures: Dataset[LaggedExposure]) {
115124

116125
import exposures.sqlContext.implicits._
117126

127+
def makeMetadata: Dataset[Metadata] = {
128+
// We already have this information inside the "withIndices" function, in the labels object,
129+
// however, I couldn't think of a good readable solution to extract this information to
130+
// the outer scope, so I just compute it again here.
131+
val max = math.max _
132+
val patientCount = exposures.map(_.patientIDIndex).reduce(max) + 1
133+
val moleculeCount = exposures.map(_.moleculeIndex).reduce(max) + 1
134+
val lags = params.lagCount
135+
val buckets = bucketCount
136+
val bucketSize = params.bucketSize
137+
138+
Seq(
139+
Metadata(
140+
rows = patientCount * buckets,
141+
columns = moleculeCount * lags,
142+
patients = patientCount,
143+
buckets = buckets,
144+
bucketSize = bucketSize,
145+
molecules = moleculeCount,
146+
lags = lags
147+
)
148+
).toDS
149+
}
150+
118151
def lagExposures: Dataset[LaggedExposure] = {
119-
val lagCount = params.lagCount
120-
val maxLag = lagCount - 1
152+
val lagCount = params.lagCount // to avoid full class serialization
121153

122154
// The following function transforms a single initial exposure like (Pat1, MolA, 4, 0, 1)
123155
// into a sequence of lagged exposures like:
@@ -131,9 +163,9 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
131163
// when we reach either the last lag or the defined end bucket (min among bucket count,
132164
// death date and target disease date)
133165
val createLags: (LaggedExposure) => Seq[LaggedExposure] = {
134-
exposure => (0 to maxLag).collect {
135-
case newLag if exposure.startBucket + newLag < exposure.endBucket =>
136-
exposure.copy(startBucket = exposure.startBucket + newLag, lag = newLag)
166+
e: LaggedExposure => (0 until lagCount).collect {
167+
case newLag if e.startBucket + newLag < e.endBucket =>
168+
e.copy(startBucket = e.startBucket + newLag, lag = newLag)
137169
}
138170
}
139171

@@ -179,39 +211,90 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
179211
// two columns.
180212
pivoted.select(moleculeColumns ++ referenceColumns: _*)
181213
}
214+
215+
// The following function assumes the data has been filtered and contains only patients with the
216+
// disease
217+
def makeOutcomes: DataFrame = {
218+
val b = bucketCount
219+
exposures.map(
220+
e => e.patientIDIndex * b + e.diseaseBucket.get
221+
).distinct.toDF
222+
}
223+
224+
// The following function assumes the data has been filtered and contains only patients with the
225+
// disease
226+
def makeStaticOutcomes: DataFrame = {
227+
exposures.map(_.patientIDIndex).distinct.toDF
228+
}
229+
230+
def writeLookupFiles(rootDir: String): Unit = {
231+
val inputDF = exposures.toDF
232+
233+
inputDF
234+
.select("patientID", "patientIDIndex", "gender", "age")
235+
.dropDuplicates(Seq("patientID"))
236+
.writeCSV(s"$rootDir/csv/PatientsLookup.csv")
237+
238+
inputDF
239+
.select("molecule", "moleculeIndex")
240+
.dropDuplicates(Seq("molecule"))
241+
.writeCSV(s"$rootDir/csv/MoleculeLookup.csv")
242+
}
182243
}
183244

184245
// Maybe put this in an implicit class of Dataset[FlatEvent]? This would cause the need of the following:
185246
// val writer = MLPPWriter(params)
186-
// import writer.FlatEventDataset
247+
// import writer.FlatEventDataset (or import writer._)
187248
// data.write(path)
188249
//
189-
// Returns the final dataset for convenience
250+
// Returns the features dataset for convenience
190251
def write(data: Dataset[FlatEvent], path: String): Dataset[MLPPFeature] = {
191252

192253
val rootDir = if(path.last == '/') path.dropRight(1) else path
193254
val input = data.toDF
194255

195-
val initialExposures = input
256+
val initialExposures: Dataset[LaggedExposure] = input
196257
.withAge(AgeReferenceDate)
197258
.withStartBucket
198-
.withDiseaseBucket
199259
.withDeathBucket
260+
.withDiseaseBucket
200261
.withEndBucket
201262
.where(col("category") === "exposure")
202263
.withColumnRenamed("eventId", "molecule")
203264
.where(col("startBucket") < col("endBucket"))
204265
.withIndices(Seq("patientID", "molecule"))
205266
.makeDiscreteExposures
206-
207-
val StaticExposures = initialExposures.makeStaticExposures
208-
StaticExposures.write.parquet(s"$rootDir/StaticExposures")
209-
210-
val result = initialExposures
211-
.lagExposures
212-
.toMLPPFeatures
213-
214-
result.toDF.write.parquet(s"$rootDir/SparseFeatures")
215-
result
267+
.persist()
268+
269+
val filteredExposures = initialExposures.filter(_.diseaseBucket.isDefined).persist()
270+
271+
val metadata: DataFrame = initialExposures.makeMetadata.toDF
272+
val staticExposures: DataFrame = initialExposures.makeStaticExposures
273+
val outcomes = filteredExposures.makeOutcomes
274+
val staticOutcomes = filteredExposures.makeStaticOutcomes
275+
val features: Dataset[MLPPFeature] = filteredExposures.lagExposures.toMLPPFeatures
276+
val featuresDF = features.toDF.persist()
277+
278+
// write static exposures ("Z" matrix)
279+
staticExposures.write.parquet(s"$rootDir/parquet/StaticExposures")
280+
staticExposures.writeCSV(s"$rootDir/csv/StaticExposures.csv")
281+
// write outcomes ("Y" matrices)
282+
outcomes.writeCSV(s"$rootDir/csv/Outcomes.csv")
283+
staticOutcomes.writeCSV(s"$rootDir/csv/StaticOutcomes.csv")
284+
// write lookup tables
285+
initialExposures.writeLookupFiles(rootDir)
286+
// write sparse features ("X" matrix)
287+
featuresDF.write.parquet(s"$rootDir/parquet/SparseFeatures")
288+
featuresDF.select("rowIndex", "colIndex", "value").writeCSV(s"$rootDir/csv/SparseFeatures.csv")
289+
290+
// write metadata
291+
metadata.writeCSV(s"$rootDir/csv/metadata.csv")
292+
293+
featuresDF.unpersist()
294+
filteredExposures.unpersist()
295+
initialExposures.unpersist()
296+
297+
// Return the features for convenience
298+
features
216299
}
217300
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package fr.polytechnique.cmap.cnam.filtering.mlpp
2+
3+
case class Metadata(
4+
rows: Int,
5+
columns: Int,
6+
patients: Int,
7+
buckets: Int,
8+
bucketSize: Int,
9+
molecules: Int,
10+
lags: Int
11+
)

src/test/scala/fr/polytechnique/cmap/cnam/filtering/mlpp/MLPPFeatureSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class MLPPFeatureSuite extends FlatSpec {
77
// Given
88
val bucketCount = 10
99
val lagCount = 10
10-
val exposure = LaggedExposure("PE", 4, 1, 40, "Mol3", 2, 3, 9, 5, 1.0)
10+
val exposure = LaggedExposure("PE", 4, 1, 40, None, "Mol3", 2, 3, 9, 5, 1.0)
1111
val expected = MLPPFeature("PE", 4, "Mol3", 2, 3, 5, 43, 25, 1.0)
1212
// When
1313
val result = MLPPFeature.fromLaggedExposure(exposure, bucketCount, lagCount)

0 commit comments

Comments (0)