|
1 | 1 | package ldbc.snb.datagen.factors
|
2 | 2 |
|
3 |
| -import ldbc.snb.datagen.SparkApp |
4 |
| -import ldbc.snb.datagen.factors.Factors.countryNumPersons |
5 | 3 | import ldbc.snb.datagen.factors.io.FactorTableSink
|
6 | 4 | import ldbc.snb.datagen.io.graphs.GraphSource
|
7 |
| -import ldbc.snb.datagen.model |
| 5 | +import ldbc.snb.datagen.model.{Graph, Mode, graphs} |
| 6 | +import ldbc.snb.datagen.{SparkApp, model} |
8 | 7 | import ldbc.snb.datagen.syntax._
|
9 | 8 | import ldbc.snb.datagen.util.Logging
|
10 |
| -import org.apache.spark.sql.SparkSession |
| 9 | +import org.apache.spark.sql.functions.{broadcast, count, date_trunc} |
| 10 | +import org.apache.spark.sql.{Column, DataFrame, SparkSession} |
| 11 | + |
| 12 | +case class RawFactor(private val f: Graph[Mode.Raw.type] => DataFrame) extends (Graph[Mode.Raw.type] => DataFrame) { |
| 13 | + override def apply(v1: Graph[Mode.Raw.type]): DataFrame = f(v1) |
| 14 | +} |
11 | 15 |
|
12 | 16 | object FactorGenerationStage extends SparkApp with Logging {
|
13 | 17 | override def appName: String = "LDBC SNB Datagen for Spark: Factor Generation Stage"
|
14 | 18 |
|
15 | 19 | case class Args(outputDir: String = "out")
|
16 | 20 |
|
17 | 21 | def run(args: Args)(implicit spark: SparkSession): Unit = {
|
18 |
| - import ldbc.snb.datagen.io.instances._ |
| 22 | + import ldbc.snb.datagen.factors.io.instances._ |
19 | 23 | import ldbc.snb.datagen.io.Reader.ops._
|
20 | 24 | import ldbc.snb.datagen.io.Writer.ops._
|
21 |
| - import ldbc.snb.datagen.factors.io.instances._ |
| 25 | + import ldbc.snb.datagen.io.instances._ |
22 | 26 |
|
23 | 27 | GraphSource(model.graphs.Raw.graphDef, args.outputDir, "csv")
|
24 | 28 | .read
|
25 |
| - .pipe(countryNumPersons) |
26 |
| - .write(FactorTableSink(args.outputDir)) |
| 29 | + .pipe(g => rawFactors.map { case (name, calc) => FactorTable(name, calc(g), g) }) |
| 30 | + .foreach(_.write(FactorTableSink(args.outputDir))) |
27 | 31 | }
|
28 | 32 |
|
| 33 | + private def frequency(df: DataFrame, value: Column, by: Seq[Column]) = |
| 34 | + df |
| 35 | + .groupBy(by: _*).agg(count(value).as("count")) |
| 36 | + .select(by :+ $"count": _*) |
| 37 | + .orderBy($"count".desc +: by.map(_.asc): _*) |
| 38 | + |
| 39 | + private val rawFactors = Map( |
| 40 | + "countryNumPersons" -> RawFactor { graph => |
| 41 | + val places = graph.entities(graphs.Raw.entities.Place).cache() |
| 42 | + val cities = places.where($"type" === "City") |
| 43 | + val countries = places.where($"type" === "Country") |
| 44 | + |
| 45 | + val persons = graph.entities(graphs.Raw.entities.Person) |
| 46 | + frequency( |
| 47 | + persons.as("Person") |
| 48 | + .join(broadcast(cities.as("City")), $"City.id" === $"Person.LocationCityId") |
| 49 | + .join(broadcast(countries.as("Country")), $"Country.id" === $"City.PartOfPlaceId"), |
| 50 | + value = $"Person.id", |
| 51 | + by = Seq($"Country.id", $"Country.name") |
| 52 | + ) |
| 53 | + }, |
| 54 | + "countryNumMessages" -> RawFactor { graph => |
| 55 | + val comments = graph.entities(graphs.Raw.entities.Comment) |
| 56 | + val posts = graph.entities(graphs.Raw.entities.Post) |
| 57 | + frequency( |
| 58 | + comments.select($"id", $"LocationCountryId") |+| posts.select($"id", $"LocationCountryId"), |
| 59 | + value = $"id", |
| 60 | + by = Seq($"LocationCountryId") |
| 61 | + ) |
| 62 | + }, |
| 63 | + "cityPairsNumFriends" -> RawFactor { graph => |
| 64 | + val personKnowsPerson = graph.entities(graphs.Raw.entities.PersonKnowsPerson) |
| 65 | + val persons = graph.entities(graphs.Raw.entities.Person).cache() |
| 66 | + |
| 67 | + val places = graph.entities(graphs.Raw.entities.Place) |
| 68 | + val cities = places.where($"type" === "City").cache() |
| 69 | + |
| 70 | + frequency( |
| 71 | + personKnowsPerson.alias("Knows") |
| 72 | + .join(persons.as("Person1"), $"Person1.id" === $"Knows.Person1Id") |
| 73 | + .join(cities.as("City1"), $"City1.id" === "Person1.LocationCityId") |
| 74 | + .join(persons.as("Person2"), $"Person2.id" === $"Knows.Person2Id") |
| 75 | + .join(cities.as("City2"), $"City2.id" === "Person2.LocationCityId") |
| 76 | + .where($"City1.id" < $"City2.id"), |
| 77 | + value = $"*", |
| 78 | + by = Seq($"City1.id", $"City2.id", $"City1.name", $"City2.name") |
| 79 | + ).select( |
| 80 | + $"City1.id".alias("city1Id"), |
| 81 | + $"City2.id".alias("city2Id"), |
| 82 | + $"City1.name".alias("city1Name"), |
| 83 | + $"City2.name".alias("city2Name"), |
| 84 | + $"count" |
| 85 | + ) |
| 86 | + }, |
| 87 | + "countryPairsNumFriends" -> RawFactor { graph => |
| 88 | + val personKnowsPerson = graph.entities(graphs.Raw.entities.PersonKnowsPerson) |
| 89 | + val persons = graph.entities(graphs.Raw.entities.Person).cache() |
| 90 | + |
| 91 | + val places = graph.entities(graphs.Raw.entities.Place) |
| 92 | + val cities = places.where($"type" === "City").cache() |
| 93 | + val countries = places.where($"type" === "Country").cache() |
| 94 | + |
| 95 | + frequency( |
| 96 | + personKnowsPerson.alias("Knows") |
| 97 | + .join(persons.as("Person1"), $"Person1.id" === $"Knows.Person1Id") |
| 98 | + .join(cities.as("City1"), $"City1.id" === "Person1.LocationCityId") |
| 99 | + .join(countries.as("Country1"), $"Country1.id" === "City1.PartOfPlaceId") |
| 100 | + .join(persons.as("Person2"), $"Person2.id" === $"Knows.Person2Id") |
| 101 | + .join(cities.as("City2"), $"City2.id" === "Person2.LocationCityId") |
| 102 | + .join(countries.as("Country2"), $"Country2.id" === "City2.PartOfPlaceId") |
| 103 | + .where($"Country1.id" < $"Country2.id"), |
| 104 | + value = $"*", |
| 105 | + by = Seq($"Country1.id", $"Country2.id", $"Country1.name", $"Country2.name") |
| 106 | + ).select( |
| 107 | + $"Country1.id".alias("country1Id"), |
| 108 | + $"Country2.id".alias("country2Id"), |
| 109 | + $"Country1.name".alias("country1Name"), |
| 110 | + $"Country2.name".alias("country2Name"), |
| 111 | + $"count" |
| 112 | + ) |
| 113 | + }, |
| 114 | + "messageCreationDays" -> RawFactor { graph => |
| 115 | + val comments = graph.entities(graphs.Raw.entities.Comment) |
| 116 | + val posts = graph.entities(graphs.Raw.entities.Post) |
| 117 | + (comments.select($"creationDate") |+| posts.select($"creationDate")) |
| 118 | + .select(date_trunc("day", $"creationDate").as("creationDay")) |
| 119 | + .distinct() |
| 120 | + }, |
| 121 | + "messageLengths" -> RawFactor { graph => |
| 122 | + val comments = graph.entities(graphs.Raw.entities.Comment) |
| 123 | + val posts = graph.entities(graphs.Raw.entities.Post) |
| 124 | + frequency( |
| 125 | + comments.select($"id", $"length") |+| posts.select($"id", $"length"), |
| 126 | + value = $"id", |
| 127 | + by = Seq($"length") |
| 128 | + ) |
| 129 | + } |
| 130 | + ) |
29 | 131 | }
|
| 132 | + |
| 133 | + |
0 commit comments