@@ -38,7 +38,6 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
       data.withColumn("age", age)
     }
 
-    // TODO: merge this function with withDeathBucket into a generic one (withBucketizedColumn(colName))
     def withStartBucket: DataFrame = {
       data.withColumn("startBucket",
         col("start").bucketize(params.minTimestamp, params.maxTimestamp, params.bucketSize)
@@ -54,9 +53,11 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
     def withDiseaseBucket: DataFrame = {
       val window = Window.partitionBy("patientId")
 
-      val diseaseBucket: Column = min(
-        when(col("category") === "disease" && col("eventId") === "targetDisease", col("startBucket"))
-      ).over(window)
+      val hadDisease: Column = (col("category") === "disease") &&
+        (col("eventId") === "targetDisease") &&
+        (col("startBucket") < minColumn(col("deathBucket"), lit(bucketCount)))
+
+      val diseaseBucket: Column = min(when(hadDisease, col("startBucket"))).over(window)
 
       data.withColumn("diseaseBucket", diseaseBucket)
     }
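`minColumn` is not defined in this diff; from its use it looks like an element-wise minimum of two columns, i.e. a thin wrapper over Spark's `least`. A sketch under that assumption:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.least

    // Assumed helper: element-wise minimum of two columns. `least` ignores
    // nulls, so a patient with no deathBucket is capped by bucketCount instead.
    def minColumn(first: Column, second: Column): Column = least(first, second)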
@@ -101,23 +102,54 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
     def makeDiscreteExposures: Dataset[LaggedExposure] = {
       data
         .groupBy(
-          "patientID", "patientIDIndex", "gender", "age", "molecule", "moleculeIndex", "startBucket",
-          "endBucket"
+          "patientID", "patientIDIndex", "gender", "age", "diseaseBucket", "molecule",
+          "moleculeIndex", "startBucket", "endBucket"
         ).agg(
           lit(0).as("lag"),
           lit(1.0).as("weight") // In the future, we might change it to sum("weight").as("weight")
         )
         .as[LaggedExposure]
     }
+
+    def writeCSV(path: String): Unit = {
+      data.coalesce(1).write
+        .format("com.databricks.spark.csv")
+        .option("delimiter", ",")
+        .option("header", "true")
+        .save(path)
+    }
   }
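The `LaggedExposure` case class itself is outside this diff. From the `groupBy`/`agg` above and the field accesses further down (`diseaseBucket.isDefined`, `diseaseBucket.get`), its shape is presumably close to:

    // Inferred, not shown in the diff; field types are assumptions.
    case class LaggedExposure(
      patientID: String,
      patientIDIndex: Int,
      gender: Int, // assumed to be encoded as an integer
      age: Int,
      diseaseBucket: Option[Int], // None for patients who never get the disease
      molecule: String,
      moleculeIndex: Int,
      startBucket: Int,
      endBucket: Int,
      lag: Int,
      weight: Double
    )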
 
   implicit class DiscreteExposures(exposures: Dataset[LaggedExposure]) {
 
     import exposures.sqlContext.implicits._
 
+    def makeMetadata: Dataset[Metadata] = {
+      // We already have this information inside the "withIndices" function, in the labels object;
+      // however, I couldn't find a readable way to extract it to the outer scope, so I just
+      // compute it again here.
+      val max: (Int, Int) => Int = math.max
+      val patientCount = exposures.map(_.patientIDIndex).reduce(max) + 1
+      val moleculeCount = exposures.map(_.moleculeIndex).reduce(max) + 1
+      val lags = params.lagCount
+      val buckets = bucketCount
+      val bucketSize = params.bucketSize
+
+      Seq(
+        Metadata(
+          rows = patientCount * buckets,
+          columns = moleculeCount * lags,
+          patients = patientCount,
+          buckets = buckets,
+          bucketSize = bucketSize,
+          molecules = moleculeCount,
+          lags = lags
+        )
+      ).toDS
+    }
+
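For reference, the `Metadata` case class is not part of this diff either; matching the named arguments above, it would look like:

    // Inferred directly from the call site above.
    case class Metadata(
      rows: Int,
      columns: Int,
      patients: Int,
      buckets: Int,
      bucketSize: Int,
      molecules: Int,
      lags: Int
    )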
     def lagExposures: Dataset[LaggedExposure] = {
-      val lagCount = params.lagCount
-      val maxLag = lagCount - 1
+      val lagCount = params.lagCount // to avoid full class serialization
 
       // The following function transforms a single initial exposure like (Pat1, MolA, 4, 0, 1)
       // into a sequence of lagged exposures like:
@@ -131,9 +163,9 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
       // when we reach either the last lag or the defined end bucket (min among bucket count,
       // death date and target disease date)
       val createLags: (LaggedExposure) => Seq[LaggedExposure] = {
-        exposure => (0 to maxLag).collect {
-          case newLag if exposure.startBucket + newLag < exposure.endBucket =>
-            exposure.copy(startBucket = exposure.startBucket + newLag, lag = newLag)
+        e: LaggedExposure => (0 until lagCount).collect {
+          case newLag if e.startBucket + newLag < e.endBucket =>
+            e.copy(startBucket = e.startBucket + newLag, lag = newLag)
         }
       }
 
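The rest of `lagExposures` falls outside this hunk; presumably it ends by flat-mapping `createLags` over the dataset, along the lines of:

    // Sketch of the elided tail of lagExposures: each exposure expands into
    // one row per lag, and rows that cross their endBucket are dropped.
    exposures.flatMap(createLags)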
@@ -179,39 +211,90 @@ class MLPPWriter(params: MLPPWriter.Params = MLPPWriter.Params()) {
       // two columns.
       pivoted.select(moleculeColumns ++ referenceColumns: _*)
     }
+
+    // The following function assumes the data has been filtered and contains only patients with
+    // the disease.
+    def makeOutcomes: DataFrame = {
+      val b = bucketCount // local copy, to avoid full class serialization
+      exposures.map(
+        e => e.patientIDIndex * b + e.diseaseBucket.get
+      ).distinct.toDF
+    }
+
+    // The following function assumes the data has been filtered and contains only patients with
+    // the disease.
+    def makeStaticOutcomes: DataFrame = {
+      exposures.map(_.patientIDIndex).distinct.toDF
+    }
+
+    def writeLookupFiles(rootDir: String): Unit = {
+      val inputDF = exposures.toDF
+
+      inputDF
+        .select("patientID", "patientIDIndex", "gender", "age")
+        .dropDuplicates(Seq("patientID"))
+        .writeCSV(s"$rootDir/csv/PatientsLookup.csv")
+
+      inputDF
+        .select("molecule", "moleculeIndex")
+        .dropDuplicates(Seq("molecule"))
+        .writeCSV(s"$rootDir/csv/MoleculeLookup.csv")
+    }
   }
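`withIndices`, referenced both in `makeMetadata`'s comment and in the pipeline below, is also outside this diff. A plausible sketch (written here as a standalone function, though the real version is presumably an enrichment method with a "labels" object): it assigns a dense 0-based index to each distinct value of the given columns via `zipWithIndex` and joins it back.

    import org.apache.spark.sql.DataFrame

    // Hypothetical sketch; the real implementation lives elsewhere in the file.
    def withIndices(data: DataFrame, columns: Seq[String]): DataFrame = {
      import data.sqlContext.implicits._
      columns.foldLeft(data) { (df, name) =>
        val labels = df.select(name).distinct
          .rdd.map(_.getString(0)).zipWithIndex
          .map { case (value, index) => (value, index.toInt) }
          .toDF(name, name + "Index")
        df.join(labels, name)
      }
    }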
 
   // Maybe put this in an implicit class of Dataset[FlatEvent]? That would require the following usage:
   // val writer = MLPPWriter(params)
-  // import writer.FlatEventDataset
+  // import writer.FlatEventDataset (or import writer._)
   // data.write(path)
   //
-  // Returns the final dataset for convenience
+  // Returns the features dataset for convenience
   def write(data: Dataset[FlatEvent], path: String): Dataset[MLPPFeature] = {
 
     val rootDir = if (path.last == '/') path.dropRight(1) else path
     val input = data.toDF
 
-    val initialExposures = input
+    val initialExposures: Dataset[LaggedExposure] = input
       .withAge(AgeReferenceDate)
       .withStartBucket
-      .withDiseaseBucket
       .withDeathBucket
+      .withDiseaseBucket
       .withEndBucket
       .where(col("category") === "exposure")
       .withColumnRenamed("eventId", "molecule")
       .where(col("startBucket") < col("endBucket"))
       .withIndices(Seq("patientID", "molecule"))
       .makeDiscreteExposures
-
-    val StaticExposures = initialExposures.makeStaticExposures
-    StaticExposures.write.parquet(s"$rootDir/StaticExposures")
-
-    val result = initialExposures
-      .lagExposures
-      .toMLPPFeatures
-
-    result.toDF.write.parquet(s"$rootDir/SparseFeatures")
-    result
+      .persist()
+
+    val filteredExposures = initialExposures.filter(_.diseaseBucket.isDefined).persist()
+
+    val metadata: DataFrame = initialExposures.makeMetadata.toDF
+    val staticExposures: DataFrame = initialExposures.makeStaticExposures
+    val outcomes = filteredExposures.makeOutcomes
+    val staticOutcomes = filteredExposures.makeStaticOutcomes
+    val features: Dataset[MLPPFeature] = filteredExposures.lagExposures.toMLPPFeatures
+    val featuresDF = features.toDF.persist()
+
+    // write static exposures ("Z" matrix)
+    staticExposures.write.parquet(s"$rootDir/parquet/StaticExposures")
+    staticExposures.writeCSV(s"$rootDir/csv/StaticExposures.csv")
+    // write outcomes ("Y" matrices)
+    outcomes.writeCSV(s"$rootDir/csv/Outcomes.csv")
+    staticOutcomes.writeCSV(s"$rootDir/csv/StaticOutcomes.csv")
+    // write lookup tables
+    initialExposures.writeLookupFiles(rootDir)
+    // write sparse features ("X" matrix)
+    featuresDF.write.parquet(s"$rootDir/parquet/SparseFeatures")
+    featuresDF.select("rowIndex", "colIndex", "value").writeCSV(s"$rootDir/csv/SparseFeatures.csv")
+
+    // write metadata
+    metadata.writeCSV(s"$rootDir/csv/metadata.csv")
+
+    featuresDF.unpersist()
+    filteredExposures.unpersist()
+    initialExposures.unpersist()
+
+    // Return the features for convenience
+    features
   }
 }
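Putting it together, usage from the outside presumably looks like this (the parameter values and the `flatEvents` dataset are illustrative, not taken from the source):

    // Hypothetical usage; Params field values are placeholders.
    val writer = MLPPWriter(MLPPWriter.Params(bucketSize = 30, lagCount = 10))
    val features: Dataset[MLPPFeature] = writer.write(flatEvents, "target/mlpp")
    // After the call, target/mlpp/parquet and target/mlpp/csv hold the X, Y,
    // and Z matrices, the lookup tables, and the metadata file.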