|
1 | 1 | package ldbc.snb.datagen.transformation.transform
|
2 | 2 |
|
3 |
| -import ldbc.snb.datagen.model.{Graph, Mode} |
| 3 | +import ldbc.snb.datagen.model.Cardinality.NN |
| 4 | +import ldbc.snb.datagen.model.EntityType.Edge |
| 5 | +import ldbc.snb.datagen.model.{EntityType, Graph, Mode} |
4 | 6 | import ldbc.snb.datagen.util.Logging
|
| 7 | +import ldbc.snb.datagen.util.sql._ |
| 8 | +import ldbc.snb.datagen.syntax._ |
5 | 9 | import org.apache.spark.sql.DataFrame
|
| 10 | +import org.apache.spark.sql.functions.{col, lit, to_timestamp} |
6 | 11 |
|
case class RawToInteractiveTransform(mode: Mode.Interactive, simulationStart: Long, simulationEnd: Long)
  extends Transform[Mode.Raw.type, Mode.Interactive]
  with Logging {
  log.debug(s"Interactive Transformation parameters: $mode")

  // Cut-off point (epoch millis) between the bulk-loaded snapshot and the
  // update streams, derived from the configured bulk-load portion of the
  // simulated time range.
  val bulkLoadThreshold = RawToInteractiveTransform.calculateBulkLoadThreshold(mode.bulkLoadPortion, simulationStart, simulationEnd)

  /**
   * Turns a Raw-mode graph into an Interactive-mode graph: for every entity,
   * normalize its date columns and keep only the snapshot part (rows created
   * before the bulk-load threshold and not yet deleted at it).
   */
  override def transform(input: In): Out = {
    val snapshotEntities = input.entities.map { case (entityType, dataset) =>
      val withConvertedDates = IrToRawTransform.convertDates(entityType, dataset)
      entityType -> RawToInteractiveTransform.snapshotPart(entityType, withConvertedDates, bulkLoadThreshold, filterDeletion = true)
    }
    Graph[Mode.Interactive](isAttrExploded = input.isAttrExploded, isEdgesExploded = input.isEdgesExploded, mode, snapshotEntities)
  }
}
|
| 30 | + |
object RawToInteractiveTransform {

  /**
   * Columns of `tpe` to retain in the Interactive output, given the Raw
   * column list `cols`. Static entities keep every column; dynamic entities
   * drop the Raw-only bookkeeping columns (deletion metadata, plus `weight`
   * on the Person-knows-Person edge). Original column order is preserved.
   */
  def columns(tpe: EntityType, cols: Seq[String]) = {
    val droppedColumns: Set[String] = tpe match {
      case t if t.isStatic                              => Set.empty
      case Edge("Knows", "Person", "Person", NN, false) => Set("deletionDate", "explicitlyDeleted", "weight")
      case _                                            => Set("deletionDate", "explicitlyDeleted")
    }
    cols.filterNot(droppedColumns)
  }

  /**
   * Epoch-millisecond timestamp before which entities belong to the
   * bulk-loaded snapshot. `bulkLoadPortion` is the fraction of the simulated
   * time range [simulationStart, simulationEnd] covered by the snapshot.
   */
  def calculateBulkLoadThreshold(bulkLoadPortion: Double, simulationStart: Long, simulationEnd: Long) =
    simulationEnd - ((simulationEnd - simulationStart) * (1 - bulkLoadPortion)).toLong

  /**
   * Snapshot part of `df` for entity `tpe`: rows created strictly before the
   * bulk-load threshold and, when `filterDeletion` is set, still alive at it.
   * Static entities pass through untouched; dynamic ones additionally drop
   * the Raw-only columns (see [[columns]]).
   *
   * The threshold is divided by 1000 because `to_timestamp` interprets a
   * numeric argument as seconds since the epoch.
   */
  def snapshotPart(tpe: EntityType, df: DataFrame, bulkLoadThreshold: Long, filterDeletion: Boolean) = {
    val thresholdTs = to_timestamp(lit(bulkLoadThreshold / 1000))
    tpe match {
      case t if t.isStatic => df
      case t =>
        val snapshot = df.filter(
          $"creationDate" < thresholdTs &&
            (!lit(filterDeletion) || $"deletionDate" >= thresholdTs)
        )
        snapshot.select(columns(t, df.columns).map(name => col(qualified(name))): _*)
    }
  }
}
0 commit comments