Skip to content

Commit 657c035

Browse files
committed
CNAM-164 Major architectural refactoring
CNAM-164 Beginning of refactoring into Cake Pattern CNAM-164 Refactored existing implementations into new format CNAM-164 Continued refactoring CNAM-164 Fixed some test cases and bugs CNAM-164 Improved test coverage in new ExposuresTransformer architecture CNAM-164 Further improved test coverage CNAM-164 Finished tests, including TestBasedWeightAgg CNAM-164 Changed "withWeight" method name to "aggregateWeight" CNAM-164 Updated CoxMain and configuration files CNAM-164 Fixed tests CNAM-164 Fixed another test
1 parent 08a8b15 commit 657c035

34 files changed

+1944
-17
lines changed

src/main/resources/config/filtering-default.conf

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,30 @@ default = {
4242
threshold = 4
4343
delay = 2
4444
}
45+
filters = {
46+
delayed_entries = true #Patients who are exposed after certain months of the study start.
47+
delayed_entries_threshold = 12 #Months that signifies the delayed entries.
48+
}
49+
exposures = {
50+
min_purchases = 2 #Minimum number of purchases that have to be made in order to be considered exposed.
51+
purchases_window = 6 #Purchase window, within which the min number of purchases have to be made.
52+
start_delay = 3 #Number of months after which a patient will be considered exposed after the min purchases, window.
53+
period_strategy = "unlimited" # Period strategy. Possible values: "unlimited" | "limited" (multiple exposures with start and end)
54+
weight_strategy = "non-cumulative" # Weight Aggregation strategy. Possible values: "non-cumulative" | "purchase-based" | "dosage-based" | "time-based"
55+
cumulative = {
56+
window = 1 #Number of months to quantile.
57+
start_threshold = 6 #Number of months within which more than one purchase has to be made
58+
end_threshold = 4 #Number of months during which no purchases of the particular molecule have to be made
59+
}
60+
}
4561

4662
# Parameters for the Cox featuring:
4763
cox_parameters = {
4864
follow_up_delay = 6 #Number of months after the observation start that is considered to be followup
4965
filter_delayed_patients = true #Patients who are exposed after certain months of the study start.
5066
delayed_entries_threshold = 12 #Months that signifies the delayed entries.
5167

68+
# @deprecated
5269
exposures = {
5370
min_purchases = 2 #Minimum number of purchases that have to be made in order to be considered exposed.
5471
purchases_window = 6 #Purchase window, within which the min number of purchases have to be made.

src/main/scala/fr/polytechnique/cmap/cnam/filtering/FilteringConfig.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import scala.collection.JavaConverters._
55
import org.apache.spark.SparkContext
66
import org.apache.spark.sql.SQLContext
77
import com.typesafe.config.{Config, ConfigFactory}
8+
import fr.polytechnique.cmap.cnam.filtering.exposures.{ExposurePeriodStrategy, ExposuresConfig, WeightAggStrategy}
89
import fr.polytechnique.cmap.cnam.utilities.functions._
910

1011
object FilteringConfig {
@@ -136,4 +137,21 @@ object FilteringConfig {
136137
)
137138

138139
def modelConfig(modelName: String): Config = conf.getConfig(modelName)
140+
/**
 * Exposure settings assembled from the loaded configuration.
 *
 * Being a `lazy val`, nothing is read until the first access. The individual values are read
 * in the same order as the `ExposuresConfig` constructor arguments, so a missing or malformed
 * key surfaces in that order.
 */
lazy val exposuresConfig: ExposuresConfig = {
  val start = dates.studyStart
  val disease = diseaseCode

  // Exposure period definition.
  val period = ExposurePeriodStrategy.fromString(conf.getString("exposures.period_strategy"))
  val purchases = conf.getInt("exposures.min_purchases")
  val window = conf.getInt("exposures.purchases_window")
  val delay = conf.getInt("exposures.start_delay")

  // Weight aggregation and patient filtering.
  val weightAgg = WeightAggStrategy.fromString(conf.getString("exposures.weight_strategy"))
  val delayedEntriesFilter = conf.getBoolean("filters.delayed_entries")

  // Parameters of the cumulative weight strategies.
  val cumWindow = conf.getInt("exposures.cumulative.window")
  val cumStart = conf.getInt("exposures.cumulative.start_threshold")
  val cumEnd = conf.getInt("exposures.cumulative.end_threshold")

  ExposuresConfig(
    studyStart = start,
    diseaseCode = disease,
    periodStrategy = period,
    minPurchases = purchases,
    purchasesWindow = window,
    startDelay = delay,
    weightAggStrategy = weightAgg,
    filterDelayedPatients = delayedEntriesFilter,
    cumulativeExposureWindow = cumWindow,
    cumulativeStartThreshold = cumStart,
    cumulativeEndThreshold = cumEnd
  )
}
139157
}

src/main/scala/fr/polytechnique/cmap/cnam/filtering/cox/CoxMain.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import org.apache.spark.sql.hive.HiveContext
55
import org.apache.spark.sql.{DataFrame, Dataset}
66
import fr.polytechnique.cmap.cnam.Main
77
import fr.polytechnique.cmap.cnam.filtering._
8+
import fr.polytechnique.cmap.cnam.filtering.exposures.ExposuresTransformer
89

910
/**
1011
* Created by sathiya on 09/11/16.
@@ -75,7 +76,10 @@ object CoxMain extends Main {
7576
drugFlatEvents
7677
.union(diseaseFlatEvents)
7778
.union(followUpFlatEvents)
78-
val exposures = CoxExposuresTransformer.transform(flatEventsForExposures).cache()
79+
// val exposures = CoxExposuresTransformer.transform(flatEventsForExposures).cache()
80+
81+
val exposuresConfig = FilteringConfig.exposuresConfig
82+
val exposures = ExposuresTransformer(exposuresConfig).transform(flatEventsForExposures).cache()
7983

8084
logger.info("Caching exposures...")
8185
logger.info("Number of exposures: " + exposures.count)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package fr.polytechnique.cmap.cnam.filtering.cox
2+
3+
import org.apache.spark.sql.functions._
4+
import org.apache.spark.sql.hive.HiveContext
5+
import org.apache.spark.sql.{DataFrame, Dataset}
6+
import fr.polytechnique.cmap.cnam.Main
7+
import fr.polytechnique.cmap.cnam.filtering._
8+
import fr.polytechnique.cmap.cnam.filtering.exposures.ExposuresTransformer
9+
10+
/**
 * Cox featuring job built on the configurable [[ExposuresTransformer]].
 *
 * `run` obtains the flat events from `FilteringMain` and hands them to `coxFeaturing`, which
 * derives follow-up events, exposures and Cox features and writes them below the configured
 * Cox output path.
 */
object NewCoxMain extends Main {

  def appName = "CoxFeaturing"

  /**
   * Runs the filtering job and then the Cox featuring on its output.
   *
   * @param sqlContext context forwarded to `FilteringMain.run`
   * @param argsMap arguments forwarded to `FilteringMain.run` and to `coxFeaturing`
   * @return the result of `coxFeaturing`
   * @throws NoSuchElementException if `FilteringMain.run` yields no dataset
   */
  def run(sqlContext: HiveContext, argsMap: Map[String, String]): Option[Dataset[_]] = {

    logger.info("Running FilteringMain...")
    // Same exception type as Option.get, but with a message that names the failing step.
    val flatEvents: Dataset[FlatEvent] = FilteringMain.run(sqlContext, argsMap).getOrElse(
      throw new NoSuchElementException(
        "FilteringMain.run returned no flat events; cannot compute Cox features"
      )
    )
    coxFeaturing(flatEvents, argsMap)
  }

  /**
   * Builds the Cox features from flat events and writes the intermediate and final datasets.
   *
   * Written outputs, all below `<coxFeatures output root>/<cancerDefinition>`:
   * `eventsSummary`, `exposures` and `cox` as parquet, plus `cox.csv`.
   *
   * @param flatEvents events to process; its SQL context is used for reading and implicits
   * @param argsMap optional "conf" and "env" entries are copied into the SQL context settings
   * @return `Some` wrapping the Cox features dataset
   */
  def coxFeaturing(flatEvents: Dataset[FlatEvent], argsMap: Map[String, String]): Option[Dataset[_]] = {
    import flatEvents.sqlContext.implicits._

    val sqlContext = flatEvents.sqlContext

    argsMap.get("conf").foreach(sqlContext.setConf("conf", _))
    argsMap.get("env").foreach(sqlContext.setConf("env", _))

    val cancerDefinition: String = FilteringConfig.cancerDefinition
    val outputRoot = FilteringConfig.outputPaths.coxFeatures
    val outputDir = s"$outputRoot/$cancerDefinition"

    val dcirFlat: DataFrame = sqlContext.read.parquet(FilteringConfig.inputPaths.dcir)

    val drugFlatEvents = flatEvents.filter(_.category == "molecule")
    val diseaseFlatEvents = flatEvents.filter(_.category == "disease")

    // One Patient per distinct (patientID, gender, birthDate, deathDate) found in the events.
    val patients: Dataset[Patient] = flatEvents
      .map(
        x => Patient(
          x.patientID,
          x.gender,
          x.birthDate,
          x.deathDate)
      ).distinct

    logger.info("Number of drug events: " + drugFlatEvents.count)
    // NOTE(review): the message below announces caching, but neither drugFlatEvents nor
    // diseaseFlatEvents is cached in this method - confirm whether .cache() was intended.
    logger.info("Caching disease events...")
    logger.info("Number of disease events: " + diseaseFlatEvents.count)

    logger.info("Preparing for Cox with the following parameters:")
    logger.info(CoxConfig.toString)

    logger.info("(Lazy) Transforming Follow-up events...")
    val observationFlatEvents = CoxObservationPeriodTransformer.transform(drugFlatEvents)

    // Trackloss events come from the DCIR source and are merged with their patient record
    // to obtain FlatEvents.
    val tracklossEvents: Dataset[Event] = TrackLossTransformer.transform(Sources(dcir=Some(dcirFlat)))
    val tracklossFlatEvents = tracklossEvents
      .as("left")
      .joinWith(patients.as("right"), col("left.patientID") === col("right.patientID"))
      .map((FlatEvent.merge _).tupled)
      .cache()

    val followUpFlatEvents = CoxFollowUpEventsTransformer.transform(
      drugFlatEvents
        .union(diseaseFlatEvents)
        .union(observationFlatEvents)
        .union(tracklossFlatEvents)
    ).cache()

    logger.info("(Lazy) Transforming exposures...")
    val flatEventsForExposures =
      drugFlatEvents
        .union(diseaseFlatEvents)
        .union(followUpFlatEvents)
    val exposuresConfig = FilteringConfig.exposuresConfig
    val exposures = ExposuresTransformer(exposuresConfig).transform(flatEventsForExposures).cache()

    logger.info("Caching exposures...")
    logger.info("Number of exposures: " + exposures.count)

    logger.info("(Lazy) Transforming Cox features...")
    val coxFlatEvents = exposures.union(followUpFlatEvents)
    val coxFeatures = CoxTransformer.transform(coxFlatEvents)

    val flatEventsSummary = flatEventsForExposures
      .union(observationFlatEvents)
      .union(tracklossFlatEvents)

    logger.info("Writing summary of all cox events and config...")
    flatEventsSummary.toDF.write.parquet(s"$outputDir/eventsSummary")
    logger.info("Writing Exposures...")
    exposures.toDF.write.parquet(s"$outputDir/exposures")

    logger.info("Writing Cox features...")
    // Brings the writeCSV extension used below into scope.
    import CoxFeaturesWriter._
    coxFeatures.toDF.write.parquet(s"$outputDir/cox")
    coxFeatures.writeCSV(s"$outputDir/cox.csv")

    Some(coxFeatures)
  }
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package fr.polytechnique.cmap.cnam.filtering.exposures
2+
3+
import java.sql.Timestamp
4+
import org.apache.spark.sql.DataFrame
5+
import org.apache.spark.sql.functions._
6+
7+
/**
 * Weight aggregator placeholder (presumably the one selected for the "dosage-based" strategy).
 *
 * No dosage computation happens here: every row receives the same constant string in its
 * "weight" column and none of the parameters is consulted.
 */
class DosageBasedWeightAgg(data: DataFrame) extends WeightAggregatorImpl(data) {

  /**
   * Adds (or overwrites) the "weight" column with a constant placeholder string.
   *
   * @param studyStart unused
   * @param cumWindow unused
   * @param cumStartThreshold unused
   * @param cumEndThreshold unused
   * @return `data` with the "weight" column set to the placeholder
   */
  def aggregateWeight(
      studyStart: Option[Timestamp],
      cumWindow: Option[Int],
      cumStartThreshold: Option[Int],
      cumEndThreshold: Option[Int]): DataFrame = {
    val placeholderWeight = lit("dosage-based cumulative weight")
    data.withColumn("weight", placeholderWeight)
  }
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package fr.polytechnique.cmap.cnam.filtering.exposures
2+
3+
import org.apache.spark.sql.DataFrame
4+
5+
/**
 * Mixin that gives a `DataFrame` the `withStartEnd` operation through an implicit conversion.
 *
 * The class mixing this trait in supplies `exposurePeriodStrategy`, which decides which
 * concrete [[ExposurePeriodAdderImpl]] wraps the DataFrame.
 */
trait ExposurePeriodAdder {

  /** Strategy used to pick the implementation; provided by the implementing class. */
  val exposurePeriodStrategy: ExposurePeriodStrategy

  /** Wraps `data` in the implementation matching `exposurePeriodStrategy`. */
  implicit def exposurePeriodImplicits(data: DataFrame): ExposurePeriodAdderImpl =
    exposurePeriodStrategy match {
      case ExposurePeriodStrategy.Unlimited => new UnlimitedExposurePeriodAdder(data)
      case ExposurePeriodStrategy.Limited => new LimitedExposurePeriodAdder(data)
    }
}

/** Base class of the strategy-specific wrappers that add exposure periods to a DataFrame. */
abstract class ExposurePeriodAdderImpl(data: DataFrame) {

  /** Returns a DataFrame derived from the wrapped one using the given exposure parameters. */
  def withStartEnd(minPurchases: Int, startDelay: Int, purchasesWindow: Int): DataFrame
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package fr.polytechnique.cmap.cnam.filtering.exposures
2+
3+
/** Selects how exposure periods are built; see the cases in the companion object. */
sealed trait ExposurePeriodStrategy

object ExposurePeriodStrategy {
  case object Limited extends ExposurePeriodStrategy
  case object Unlimited extends ExposurePeriodStrategy

  /**
   * Parses a strategy name, ignoring case.
   *
   * @param value "limited" or "unlimited" in any letter case
   * @return the matching strategy
   * @throws IllegalArgumentException if `value` is not one of the accepted names
   */
  def fromString(value: String): ExposurePeriodStrategy =
    // Locale.ROOT keeps the lower-casing independent of the JVM default locale
    // (e.g. "LIMITED" would not become "limited" under a Turkish locale).
    value.toLowerCase(java.util.Locale.ROOT) match {
      case "limited" => Limited
      case "unlimited" => Unlimited
      // Fail with an actionable message instead of a bare scala.MatchError.
      case _ => throw new IllegalArgumentException(
        s"Unknown exposure period strategy '$value'. Possible values: 'limited' | 'unlimited'"
      )
    }
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package fr.polytechnique.cmap.cnam.filtering.exposures
2+
3+
import java.sql.Timestamp
4+
import fr.polytechnique.cmap.cnam.filtering.FilteringConfig
5+
import fr.polytechnique.cmap.cnam.filtering.cox.CoxConfig
6+
7+
/**
 * Parameters consumed by the exposures pipeline.
 *
 * The descriptions below follow the comments of the configuration keys from which
 * `FilteringConfig.exposuresConfig` fills each field.
 *
 * @param studyStart start of the study
 * @param diseaseCode disease code handed to the patient filters
 * @param periodStrategy how exposure periods are built ("unlimited" or "limited")
 * @param minPurchases minimum number of purchases required to be considered exposed
 * @param purchasesWindow window within which the minimum number of purchases has to be made
 * @param startDelay number of months after which a patient is considered exposed
 * @param weightAggStrategy how the exposure weight is aggregated
 * @param filterDelayedPatients whether delayed-entry patients are filtered out
 * @param cumulativeExposureWindow number of months to quantile (cumulative strategies)
 * @param cumulativeStartThreshold months within which more than one purchase has to be made
 * @param cumulativeEndThreshold months during which no purchase of the molecule has to be made
 */
case class ExposuresConfig(
    studyStart: Timestamp,
    diseaseCode: String, // todo: ExposuresTransformer should not depend on diseases
    periodStrategy: ExposurePeriodStrategy,
    minPurchases: Int,
    purchasesWindow: Int,
    startDelay: Int,
    weightAggStrategy: WeightAggStrategy,
    filterDelayedPatients: Boolean,
    cumulativeExposureWindow: Int,
    cumulativeStartThreshold: Int,
    cumulativeEndThreshold: Int)

object ExposuresConfig {
  // todo: Remove filters from ExposuresConfig pipeline

  /**
   * Returns the configuration held by the `FilteringConfig` singleton.
   *
   * Kept for compatibility with the current singleton-based configuration strategy; callers
   * can use `.copy()` on the result to change some parameters. An earlier variant that
   * assembled the value from `CoxConfig.exposureDefinition` has been superseded by this
   * delegation.
   */
  def init(): ExposuresConfig = FilteringConfig.exposuresConfig
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package fr.polytechnique.cmap.cnam.filtering.exposures
2+
3+
import org.apache.spark.sql.Dataset
4+
import org.apache.spark.sql.functions._
5+
import fr.polytechnique.cmap.cnam.filtering.cox.CoxFollowUpEventsTransformer
6+
import fr.polytechnique.cmap.cnam.filtering.{DatasetTransformer, FlatEvent}
7+
8+
/**
 * Turns flat events into "exposure" flat events according to an [[ExposuresConfig]].
 *
 * The period and weight strategies required by the mixed-in traits are taken from `config`.
 */
class ExposuresTransformer(config: ExposuresConfig)
  extends DatasetTransformer[FlatEvent, FlatEvent]
  // Todo: remove PatientFilters from the ExposureTransformer pipeline (it should be a separate module)
  with PatientFilters
  with ExposurePeriodAdder
  with WeightAggregator {

  lazy val exposurePeriodStrategy: ExposurePeriodStrategy = config.periodStrategy
  lazy val weightAggStrategy: WeightAggStrategy = config.weightAggStrategy

  /**
   * Computes the exposures found in `input`.
   *
   * @param input flat events; only rows of category "molecule" remain after patient filtering
   * @return one FlatEvent of category "exposure" per distinct
   *         (patient, event, exposure start, exposure end, weight) combination
   */
  def transform(input: Dataset[FlatEvent]): Dataset[FlatEvent] = {
    import CoxFollowUpEventsTransformer.FollowUpFunctions
    import input.sqlContext.implicits._

    // Patient-level filtering runs on all events; only molecule rows are kept afterwards.
    val moleculeEvents = input.toDF
      .withFollowUpPeriodFromEvents
      .filterPatients(config.studyStart, config.diseaseCode, config.filterDelayedPatients)
      .where(col("category") === "molecule")

    // Rows whose exposure starts and ends at the same instant are dropped. Because a
    // comparison with null is not true, rows with a null exposureStart are dropped as well.
    val exposurePeriods = moleculeEvents
      .withStartEnd(config.minPurchases, config.startDelay, config.purchasesWindow)
      .where(col("exposureStart") !== col("exposureEnd"))

    val weightedExposures = exposurePeriods.aggregateWeight(
      Some(config.studyStart),
      Some(config.cumulativeExposureWindow),
      Some(config.cumulativeStartThreshold),
      Some(config.cumulativeEndThreshold)
    )

    // NOTE(review): "eventID" here differs in letter case from col("eventId") below; this
    // presumably relies on case-insensitive column resolution - confirm.
    val distinctExposures = weightedExposures
      .dropDuplicates(Seq("patientID", "eventID", "exposureStart", "exposureEnd", "weight"))

    // Projection onto the FlatEvent layout, with the category forced to "exposure".
    val flatEventColumns = List(
      col("patientID"),
      col("gender"),
      col("birthDate"),
      col("deathDate"),
      lit("exposure").as("category"),
      col("eventId"),
      col("weight"),
      col("exposureStart").as("start"),
      col("exposureEnd").as("end")
    )

    distinctExposures
      .select(flatEventColumns: _*)
      .as[FlatEvent]
  }
}

object ExposuresTransformer {

  /** Creates a transformer for the given configuration. */
  def apply(config: ExposuresConfig): ExposuresTransformer = new ExposuresTransformer(config)
}

0 commit comments

Comments
 (0)