@@ -111,34 +111,75 @@ class MLPPWriterSuite extends SharedContext {
111111 assert(result === expected)
112112 }
113113
114- " withDiseaseBucket " should " add a column with the timeBucket of the first targetDisease of each patient" in {
114+ " withTracklossBucket " should " add a column with the timeBucket of the first trackloss of each patient" in {
115115 val sqlCtx = sqlContext
116116 import sqlCtx .implicits ._
117117
118118 // Given
119119 val input = Seq (
120120 (" PA" , " molecule" , " PIOGLITAZONE" , 0 , Some (4 )),
121121 (" PA" , " molecule" , " PIOGLITAZONE" , 5 , Some (4 )),
122- (" PA" , " disease " , " targetDisease " , 3 , Some (4 )),
122+ (" PA" , " trackloss " , " trackloss " , 3 , Some (4 )),
123123 (" PB" , " molecule" , " PIOGLITAZONE" , 2 , None ),
124- (" PB" , " disease " , " targetDisease " , 4 , None ),
124+ (" PB" , " trackloss " , " trackloss " , 4 , None ),
125125 (" PC" , " molecule" , " PIOGLITAZONE" , 0 , Some (6 )),
126126 (" PD" , " molecule" , " PIOGLITAZONE" , 2 , Some (3 )),
127127 (" PD" , " molecule" , " PIOGLITAZONE" , 3 , Some (3 )),
128- (" PD" , " disease " , " targetDisease " , 4 , Some (3 ))
128+ (" PD" , " trackloss " , " trackloss " , 4 , Some (3 ))
129129 ).toDF(" patientID" , " category" , " eventId" , " startBucket" , " deathBucket" )
130130
131131 val expected = Seq (
132132 (" PA" , " molecule" , " PIOGLITAZONE" , 0 , Some (4 ), Some (3 )),
133133 (" PA" , " molecule" , " PIOGLITAZONE" , 5 , Some (4 ), Some (3 )),
134- (" PA" , " disease " , " targetDisease " , 3 , Some (4 ), Some (3 )),
134+ (" PA" , " trackloss " , " trackloss " , 3 , Some (4 ), Some (3 )),
135135 (" PB" , " molecule" , " PIOGLITAZONE" , 2 , None , Some (4 )),
136- (" PB" , " disease" , " targetDisease" , 4 , None , Some (4 )),
136+ (" PB" , " trackloss" , " trackloss" , 4 , None , Some (4 )),
137+ (" PC" , " molecule" , " PIOGLITAZONE" , 0 , Some (6 ), None ),
138+ (" PD" , " molecule" , " PIOGLITAZONE" , 2 , Some (3 ), None ),
139+ (" PD" , " molecule" , " PIOGLITAZONE" , 3 , Some (3 ), None ),
140+ (" PD" , " trackloss" , " trackloss" , 4 , Some (3 ), None )
141+ ).toDF(" patientID" , " category" , " eventId" , " startBucket" , " deathBucket" , " tracklossBucket" )
142+
143+ // When
144+ val writer = MLPPWriter ()
145+ import writer .MLPPDataFrame
146+ val result = input.withTracklossBucket
147+
148+ // Then
149+ import RichDataFrames ._
150+ result.show
151+ expected.show
152+ assert(result === expected)
153+ }
154+
155+ " withDiseaseBucket" should " add a column with the timeBucket of the first targetDisease of each patient" in {
156+ val sqlCtx = sqlContext
157+ import sqlCtx .implicits ._
158+
159+ // Given
160+ val input = Seq (
161+ (" PA" , " molecule" , " PIOGLITAZONE" , 0 , Some (4 ), None ),
162+ (" PA" , " molecule" , " PIOGLITAZONE" , 5 , Some (4 ), None ),
163+ (" PA" , " disease" , " targetDisease" , 3 , Some (4 ), None ),
164+ (" PB" , " molecule" , " PIOGLITAZONE" , 2 , None , Some (5 )),
165+ (" PB" , " disease" , " targetDisease" , 4 , None , Some (5 )),
137166 (" PC" , " molecule" , " PIOGLITAZONE" , 0 , Some (6 ), None ),
138167 (" PD" , " molecule" , " PIOGLITAZONE" , 2 , Some (3 ), None ),
139168 (" PD" , " molecule" , " PIOGLITAZONE" , 3 , Some (3 ), None ),
140169 (" PD" , " disease" , " targetDisease" , 4 , Some (3 ), None )
141- ).toDF(" patientID" , " category" , " eventId" , " startBucket" , " deathBucket" , " diseaseBucket" )
170+ ).toDF(" patientID" , " category" , " eventId" , " startBucket" , " deathBucket" , " tracklossBucket" )
171+
172+ val expected = Seq (
173+ (" PA" , " molecule" , " PIOGLITAZONE" , 0 , Some (4 ), None , Some (3 )),
174+ (" PA" , " molecule" , " PIOGLITAZONE" , 5 , Some (4 ), None , Some (3 )),
175+ (" PA" , " disease" , " targetDisease" , 3 , Some (4 ), None , Some (3 )),
176+ (" PB" , " molecule" , " PIOGLITAZONE" , 2 , None , Some (5 ), Some (4 )),
177+ (" PB" , " disease" , " targetDisease" , 4 , None , Some (5 ), Some (4 )),
178+ (" PC" , " molecule" , " PIOGLITAZONE" , 0 , Some (6 ), None , None ),
179+ (" PD" , " molecule" , " PIOGLITAZONE" , 2 , Some (3 ), None , None ),
180+ (" PD" , " molecule" , " PIOGLITAZONE" , 3 , Some (3 ), None , None ),
181+ (" PD" , " disease" , " targetDisease" , 4 , Some (3 ), None , None )
182+ ).toDF(" patientID" , " category" , " eventId" , " startBucket" , " deathBucket" , " tracklossBucket" , " diseaseBucket" )
142183
143184 // When
144185 val writer = MLPPWriter ()
@@ -164,16 +205,18 @@ class MLPPWriterSuite extends SharedContext {
164205 )
165206
166207 val input = Seq (
167- (" PA" , Some (2 ), Some (3 )),
168- (" PA" , Some (2 ), Some (3 )),
169- (" PB" , Some (4 ), Some (3 )),
170- (" PB" , Some (4 ), Some (3 )),
171- (" PC" , None , Some (4 )),
172- (" PC" , None , Some (4 )),
173- (" PD" , Some (4 ), None ),
174- (" PD" , Some (4 ), None ),
175- (" PE" , None , None )
176- ).toDF(" patientID" , " deathBucket" , " diseaseBucket" )
208+ (" PA" , Some (2 ), None , Some (3 )),
209+ (" PA" , Some (2 ), None , Some (3 )),
210+ (" PB" , Some (4 ), Some (5 ), Some (3 )),
211+ (" PB" , Some (4 ), Some (5 ), Some (3 )),
212+ (" PC" , None , Some (5 ), Some (4 )),
213+ (" PC" , None , Some (5 ), Some (4 )),
214+ (" PD" , Some (5 ), None , None ),
215+ (" PD" , Some (5 ), None , None ),
216+ (" PE" , Some (7 ), Some (6 ), None ),
217+ (" PE" , Some (7 ), Some (6 ), None ),
218+ (" PF" , None , None , None )
219+ ).toDF(" patientID" , " deathBucket" , " tracklossBucket" , " diseaseBucket" )
177220
178221 val expected = Seq (
179222 (" PA" , Some (2 )),
@@ -182,9 +225,11 @@ class MLPPWriterSuite extends SharedContext {
182225 (" PB" , Some (3 )),
183226 (" PC" , Some (4 )),
184227 (" PC" , Some (4 )),
185- (" PD" , Some (4 )),
186- (" PD" , Some (4 )),
187- (" PE" , Some (16 ))
228+ (" PD" , Some (5 )),
229+ (" PD" , Some (5 )),
230+ (" PE" , Some (6 )),
231+ (" PE" , Some (6 )),
232+ (" PF" , Some (16 ))
188233 ).toDF(" patientID" , " endBucket" )
189234
190235 // When
0 commit comments