@@ -193,7 +193,7 @@ class MLPPWriterSuite extends SharedContext {
193193 assert(result === expected)
194194 }
195195
196- " withEndBucket" should " add a column with the minimum among deathBucket, diseaseBucket and the max number of buckets" in {
196+ " withEndBucket" should " add a column with the minimum among deathBucket and the max number of buckets" in {
197197 val sqlCtx = sqlContext
198198 import sqlCtx .implicits ._
199199
@@ -244,6 +244,53 @@ class MLPPWriterSuite extends SharedContext {
244244 assert(result === expected)
245245 }
246246
247+ it should " add a column with the minimum among deathBucket + 1, and the max number of buckets if " +
248+ " includeDeathBucket is true" in {
249+ val sqlCtx = sqlContext
250+ import sqlCtx .implicits ._
251+
252+ // Given
253+ val params = MLPPWriter .Params (
254+ minTimestamp = makeTS(2006 , 1 , 1 ),
255+ maxTimestamp = makeTS(2006 , 2 , 2 ),
256+ bucketSize = 2 ,
257+ includeDeathBucket = true
258+ )
259+
260+ val input = Seq (
261+ (" PA" , Some (16 )),
262+ (" PA" , Some (16 )),
263+ (" PB" , Some ( 0 )),
264+ (" PB" , Some ( 0 )),
265+ (" PC" , Some ( 5 )),
266+ (" PC" , Some ( 5 )),
267+ (" PD" , None ),
268+ (" PD" , None )
269+ ).toDF(" patientID" , " deathBucket" )
270+
271+ val expected = Seq (
272+ (" PA" , Some (16 )),
273+ (" PA" , Some (16 )),
274+ (" PB" , Some ( 1 )),
275+ (" PB" , Some ( 1 )),
276+ (" PC" , Some ( 6 )),
277+ (" PC" , Some ( 6 )),
278+ (" PD" , Some (16 )),
279+ (" PD" , Some (16 ))
280+ ).toDF(" patientID" , " endBucket" )
281+
282+ // When
283+ val writer = MLPPWriter (params)
284+ import writer .MLPPDataFrame
285+ val result = input.withEndBucket.select(" patientID" , " endBucket" )
286+
287+ // Then
288+ import RichDataFrames ._
289+ result.show
290+ expected.show
291+ assert(result === expected)
292+ }
293+
247294 " makeDiscreteExposures" should " return a Dataset containing the 0-lag exposures in the sparse format" in {
248295 val sqlCtx = sqlContext
249296 import sqlCtx .implicits ._
@@ -592,9 +639,9 @@ class MLPPWriterSuite extends SharedContext {
592639 )
593640 val input : Dataset [FlatEvent ] = Seq (
594641 FlatEvent (" PC" , 2 , makeTS(1970 , 1 , 1 ), None , " exposure" , " Mol1" , 1.0 , makeTS(2006 , 5 , 15 ), None ),
595- FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 6 , 15 )), " exposure" , " Mol1" , 1.0 , makeTS(2006 , 1 , 15 ), None ),
596- FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 6 , 15 )), " exposure" , " Mol2 " , 1.0 , makeTS(2006 , 3 , 15 ), None ),
597- FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 6 , 15 )), " exposure " , " Mol2 " , 1.0 , makeTS(2006 , 5 , 15 ), None ),
642+ FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 4 , 15 )), " exposure" , " Mol1" , 1.0 , makeTS(2006 , 1 , 15 ), None ),
643+ FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 4 , 15 )), " exposure" , " Mol1 " , 1.0 , makeTS(2006 , 3 , 15 ), None ),
644+ FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 4 , 15 )), " disease " , " targetDisease " , 1.0 , makeTS(2006 , 3 , 15 ), None ),
598645 FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " exposure" , " Mol1" , 1.0 , makeTS(2006 , 1 , 15 ), None ),
599646 FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " exposure" , " Mol1" , 1.0 , makeTS(2006 , 3 , 15 ), None ),
600647 FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " exposure" , " Mol1" , 1.0 , makeTS(2006 , 4 , 15 ), None ),
@@ -624,12 +671,17 @@ class MLPPWriterSuite extends SharedContext {
624671 MLPPFeature (" PA" , 0 , " Mol3" , 2 , 3 , 0 , 3 , 8 , 1.0 ),
625672 MLPPFeature (" PA" , 0 , " Mol3" , 2 , 4 , 1 , 4 , 9 , 1.0 ),
626673 MLPPFeature (" PA" , 0 , " Mol3" , 2 , 5 , 2 , 5 , 10 , 1.0 ),
627- MLPPFeature (" PA" , 0 , " Mol3" , 2 , 6 , 3 , 6 , 11 , 1.0 )
674+ MLPPFeature (" PA" , 0 , " Mol3" , 2 , 6 , 3 , 6 , 11 , 1.0 ),
675+ // Patient B
676+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 0 , 0 , 7 , 0 , 1.0 ),
677+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 1 , 1 , 8 , 1 , 1.0 ),
678+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 2 , 2 , 9 , 2 , 1.0 ),
679+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 2 , 0 , 9 , 0 , 1.0 )
628680 ).toDF
629681
630682 val expectedZMatrix = Seq (
631683 (3D , 1D , 1D , 46 , 1 , " PA" , 0 ),
632- (1D , 2D , 0D , 56 , 1 , " PB" , 1 ),
684+ (2D , 0D , 0D , 56 , 1 , " PB" , 1 ),
633685 (1D , 0D , 0D , 36 , 2 , " PC" , 2 )
634686 ).toDF(" MOL0000_Mol1" , " MOL0001_Mol2" , " MOL0002_Mol3" , " age" , " gender" , " patientID" , " patientIDIndex" )
635687
@@ -640,8 +692,88 @@ class MLPPWriterSuite extends SharedContext {
640692
641693 // Then
642694 import RichDataFrames ._
643- result.show
644- expectedFeatures.show
695+ result.show(100 )
696+ expectedFeatures.show(100 )
697+ StaticExposures .show
698+ expectedZMatrix.show
699+ assert(result === expectedFeatures)
700+ assert(writtenResult === expectedFeatures)
701+ assert(StaticExposures === expectedZMatrix)
702+ }
703+
704+
705+ it should " create the final matrices and write them as parquet files (removing death bucket)" in {
706+ val sqlCtx = sqlContext
707+ import sqlCtx .implicits ._
708+
709+ // Given
710+ val rootDir = " target/test/output"
711+ val params = MLPPWriter .Params (
712+ minTimestamp = makeTS(2006 , 1 , 1 ),
713+ maxTimestamp = makeTS(2006 , 8 , 1 ), // 7 total buckets
714+ bucketSize = 30 ,
715+ lagCount = 4 ,
716+ includeDeathBucket = true
717+ )
718+ val input : Dataset [FlatEvent ] = Seq (
719+ FlatEvent (" PC" , 2 , makeTS(1970 , 1 , 1 ), None , " exposure" , " Mol1" , 1.0 , makeTS(2006 , 5 , 15 ), None ),
720+ FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 4 , 15 )), " exposure" , " Mol1" , 1.0 , makeTS(2006 , 1 , 15 ), None ),
721+ FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 4 , 15 )), " exposure" , " Mol1" , 1.0 , makeTS(2006 , 3 , 15 ), None ),
722+ FlatEvent (" PB" , 1 , makeTS(1950 , 1 , 1 ), Some (makeTS(2006 , 4 , 15 )), " disease" , " targetDisease" , 1.0 , makeTS(2006 , 3 , 15 ), None ),
723+ FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " exposure" , " Mol1" , 1.0 , makeTS(2006 , 1 , 15 ), None ),
724+ FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " exposure" , " Mol1" , 1.0 , makeTS(2006 , 3 , 15 ), None ),
725+ FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " exposure" , " Mol1" , 1.0 , makeTS(2006 , 4 , 15 ), None ),
726+ FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " exposure" , " Mol2" , 1.0 , makeTS(2006 , 3 , 15 ), None ),
727+ FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " exposure" , " Mol3" , 1.0 , makeTS(2006 , 4 , 15 ), None ),
728+ FlatEvent (" PA" , 1 , makeTS(1960 , 1 , 1 ), None , " disease" , " targetDisease" , 1.0 , makeTS(2006 , 5 , 15 ), None )
729+ ).toDS
730+
731+ val expectedFeatures = Seq (
732+ // Patient A
733+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 0 , 0 , 0 , 0 , 1.0 ),
734+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 1 , 1 , 1 , 1 , 1.0 ),
735+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 2 , 2 , 2 , 2 , 1.0 ),
736+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 3 , 3 , 3 , 3 , 1.0 ),
737+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 2 , 0 , 2 , 0 , 1.0 ),
738+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 3 , 1 , 3 , 1 , 1.0 ),
739+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 4 , 2 , 4 , 2 , 1.0 ),
740+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 5 , 3 , 5 , 3 , 1.0 ),
741+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 3 , 0 , 3 , 0 , 1.0 ),
742+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 4 , 1 , 4 , 1 , 1.0 ),
743+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 5 , 2 , 5 , 2 , 1.0 ),
744+ MLPPFeature (" PA" , 0 , " Mol1" , 0 , 6 , 3 , 6 , 3 , 1.0 ),
745+ MLPPFeature (" PA" , 0 , " Mol2" , 1 , 2 , 0 , 2 , 4 , 1.0 ),
746+ MLPPFeature (" PA" , 0 , " Mol2" , 1 , 3 , 1 , 3 , 5 , 1.0 ),
747+ MLPPFeature (" PA" , 0 , " Mol2" , 1 , 4 , 2 , 4 , 6 , 1.0 ),
748+ MLPPFeature (" PA" , 0 , " Mol2" , 1 , 5 , 3 , 5 , 7 , 1.0 ),
749+ MLPPFeature (" PA" , 0 , " Mol3" , 2 , 3 , 0 , 3 , 8 , 1.0 ),
750+ MLPPFeature (" PA" , 0 , " Mol3" , 2 , 4 , 1 , 4 , 9 , 1.0 ),
751+ MLPPFeature (" PA" , 0 , " Mol3" , 2 , 5 , 2 , 5 , 10 , 1.0 ),
752+ MLPPFeature (" PA" , 0 , " Mol3" , 2 , 6 , 3 , 6 , 11 , 1.0 ),
753+ // Patient A,
754+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 0 , 0 , 7 , 0 , 1.0 ),
755+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 1 , 1 , 8 , 1 , 1.0 ),
756+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 2 , 2 , 9 , 2 , 1.0 ),
757+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 3 , 3 , 10 , 3 , 1.0 ),
758+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 2 , 0 , 9 , 0 , 1.0 ),
759+ MLPPFeature (" PB" , 1 , " Mol1" , 0 , 3 , 1 , 10 , 1 , 1.0 )
760+ ).toDF
761+
762+ val expectedZMatrix = Seq (
763+ (3D , 1D , 1D , 46 , 1 , " PA" , 0 ),
764+ (2D , 0D , 0D , 56 , 1 , " PB" , 1 ),
765+ (1D , 0D , 0D , 36 , 2 , " PC" , 2 )
766+ ).toDF(" MOL0000_Mol1" , " MOL0001_Mol2" , " MOL0002_Mol3" , " age" , " gender" , " patientID" , " patientIDIndex" )
767+
768+ // When
769+ val result = MLPPWriter (params).write(input, rootDir).toDF
770+ val writtenResult = sqlContext.read.parquet(s " $rootDir/parquet/SparseFeatures " )
771+ val StaticExposures = sqlContext.read.parquet(s " $rootDir/parquet/StaticExposures " )
772+
773+ // Then
774+ import RichDataFrames ._
775+ result.show(100 )
776+ expectedFeatures.show(100 )
645777 StaticExposures .show
646778 expectedZMatrix.show
647779 assert(result === expectedFeatures)
0 commit comments