Skip to content

Commit b799983

Browse files
committed
implement a few more factor tables
1 parent 908cde7 commit b799983

File tree

5 files changed

+128
-57
lines changed

5 files changed

+128
-57
lines changed

src/main/scala/ldbc/snb/datagen/factors/Factor.scala

Lines changed: 0 additions & 34 deletions
This file was deleted.
src/main/scala/ldbc/snb/datagen/factors/FactorGenerationStage.scala (filename header lost in page capture; inferred from the file's package and object declarations)

Lines changed: 112 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,133 @@
11
package ldbc.snb.datagen.factors
22

3-
import ldbc.snb.datagen.SparkApp
4-
import ldbc.snb.datagen.factors.Factors.countryNumPersons
53
import ldbc.snb.datagen.factors.io.FactorTableSink
64
import ldbc.snb.datagen.io.graphs.GraphSource
7-
import ldbc.snb.datagen.model
5+
import ldbc.snb.datagen.model.{Graph, Mode, graphs}
6+
import ldbc.snb.datagen.{SparkApp, model}
87
import ldbc.snb.datagen.syntax._
98
import ldbc.snb.datagen.util.Logging
10-
import org.apache.spark.sql.SparkSession
9+
import org.apache.spark.sql.functions.{broadcast, count, date_trunc}
10+
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
11+
12+
/** A factor-table computation over the raw graph.
  *
  * Wraps a `Graph[Mode.Raw.type] => DataFrame` function so factor definitions
  * can be stored in a map and applied uniformly by the generation stage.
  */
case class RawFactor(private val f: Graph[Mode.Raw.type] => DataFrame) extends (Graph[Mode.Raw.type] => DataFrame) {
  // Delegate application straight to the wrapped computation.
  override def apply(graph: Graph[Mode.Raw.type]): DataFrame = f(graph)
}
1115

1216
/** Spark stage that reads the raw graph from `outputDir`, computes every
  * registered factor table and writes each one via [[FactorTableSink]].
  */
object FactorGenerationStage extends SparkApp with Logging {
  override def appName: String = "LDBC SNB Datagen for Spark: Factor Generation Stage"

  /** @param outputDir directory the raw graph is read from and factors are written to */
  case class Args(outputDir: String = "out")

  /** Entry point: read raw CSV graph, evaluate all `rawFactors`, write each table. */
  def run(args: Args)(implicit spark: SparkSession): Unit = {
    import ldbc.snb.datagen.factors.io.instances._
    import ldbc.snb.datagen.io.Reader.ops._
    import ldbc.snb.datagen.io.Writer.ops._
    import ldbc.snb.datagen.io.instances._

    GraphSource(model.graphs.Raw.graphDef, args.outputDir, "csv")
      .read
      .pipe(g => rawFactors.map { case (name, calc) => FactorTable(name, calc(g), g) })
      .foreach(_.write(FactorTableSink(args.outputDir)))
  }

  /** Counts occurrences of `value` grouped by the `by` columns.
    * Result columns: `by :+ count`, ordered by descending count then ascending keys
    * so output is deterministic.
    */
  private def frequency(df: DataFrame, value: Column, by: Seq[Column]) =
    df
      .groupBy(by: _*).agg(count(value).as("count"))
      .select(by :+ $"count": _*)
      .orderBy($"count".desc +: by.map(_.asc): _*)

  // Registry of factor-table name -> computation over the raw graph.
  private val rawFactors = Map(
    "countryNumPersons" -> RawFactor { graph =>
      // Place is small; cache once and filter for the two subtypes we need.
      val places = graph.entities(graphs.Raw.entities.Place).cache()
      val cities = places.where($"type" === "City")
      val countries = places.where($"type" === "Country")

      val persons = graph.entities(graphs.Raw.entities.Person)
      frequency(
        persons.as("Person")
          // broadcast: place dimension tables are tiny relative to Person
          .join(broadcast(cities.as("City")), $"City.id" === $"Person.LocationCityId")
          .join(broadcast(countries.as("Country")), $"Country.id" === $"City.PartOfPlaceId"),
        value = $"Person.id",
        by = Seq($"Country.id", $"Country.name")
      )
    },
    "countryNumMessages" -> RawFactor { graph =>
      val comments = graph.entities(graphs.Raw.entities.Comment)
      val posts = graph.entities(graphs.Raw.entities.Post)
      frequency(
        // |+| unions the two message kinds on a common projection
        comments.select($"id", $"LocationCountryId") |+| posts.select($"id", $"LocationCountryId"),
        value = $"id",
        by = Seq($"LocationCountryId")
      )
    },
    "cityPairsNumFriends" -> RawFactor { graph =>
      val personKnowsPerson = graph.entities(graphs.Raw.entities.PersonKnowsPerson)
      val persons = graph.entities(graphs.Raw.entities.Person).cache()

      val places = graph.entities(graphs.Raw.entities.Place)
      val cities = places.where($"type" === "City").cache()

      frequency(
        personKnowsPerson.alias("Knows")
          .join(persons.as("Person1"), $"Person1.id" === $"Knows.Person1Id")
          // FIX: was `=== "Person1.LocationCityId"` (a bare string). Spark compares the
          // column against that literal string, so the join matched nothing. It must be
          // a Column reference.
          .join(cities.as("City1"), $"City1.id" === $"Person1.LocationCityId")
          .join(persons.as("Person2"), $"Person2.id" === $"Knows.Person2Id")
          // FIX: same string-literal-vs-Column bug as above.
          .join(cities.as("City2"), $"City2.id" === $"Person2.LocationCityId")
          // keep each unordered pair once
          .where($"City1.id" < $"City2.id"),
        value = $"*",
        by = Seq($"City1.id", $"City2.id", $"City1.name", $"City2.name")
      ).select(
        $"City1.id".alias("city1Id"),
        $"City2.id".alias("city2Id"),
        $"City1.name".alias("city1Name"),
        $"City2.name".alias("city2Name"),
        $"count"
      )
    },
    "countryPairsNumFriends" -> RawFactor { graph =>
      val personKnowsPerson = graph.entities(graphs.Raw.entities.PersonKnowsPerson)
      val persons = graph.entities(graphs.Raw.entities.Person).cache()

      val places = graph.entities(graphs.Raw.entities.Place)
      val cities = places.where($"type" === "City").cache()
      val countries = places.where($"type" === "Country").cache()

      frequency(
        personKnowsPerson.alias("Knows")
          .join(persons.as("Person1"), $"Person1.id" === $"Knows.Person1Id")
          // FIX (x4 below): join keys were bare string literals instead of Column
          // references, which Spark treats as literal values — the joins were empty.
          .join(cities.as("City1"), $"City1.id" === $"Person1.LocationCityId")
          .join(countries.as("Country1"), $"Country1.id" === $"City1.PartOfPlaceId")
          .join(persons.as("Person2"), $"Person2.id" === $"Knows.Person2Id")
          .join(cities.as("City2"), $"City2.id" === $"Person2.LocationCityId")
          .join(countries.as("Country2"), $"Country2.id" === $"City2.PartOfPlaceId")
          // keep each unordered pair once
          .where($"Country1.id" < $"Country2.id"),
        value = $"*",
        by = Seq($"Country1.id", $"Country2.id", $"Country1.name", $"Country2.name")
      ).select(
        $"Country1.id".alias("country1Id"),
        $"Country2.id".alias("country2Id"),
        $"Country1.name".alias("country1Name"),
        $"Country2.name".alias("country2Name"),
        $"count"
      )
    },
    "messageCreationDays" -> RawFactor { graph =>
      val comments = graph.entities(graphs.Raw.entities.Comment)
      val posts = graph.entities(graphs.Raw.entities.Post)
      // Distinct calendar days on which any message (comment or post) was created.
      (comments.select($"creationDate") |+| posts.select($"creationDate"))
        .select(date_trunc("day", $"creationDate").as("creationDay"))
        .distinct()
    },
    "messageLengths" -> RawFactor { graph =>
      val comments = graph.entities(graphs.Raw.entities.Comment)
      val posts = graph.entities(graphs.Raw.entities.Post)
      // Histogram of message lengths across both message kinds.
      frequency(
        comments.select($"id", $"length") |+| posts.select($"id", $"length"),
        value = $"id",
        by = Seq($"length")
      )
    }
  )
}
132+
133+

src/main/scala/ldbc/snb/datagen/factors/FactorTable.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ import org.apache.spark.sql.DataFrame
55

66

77
/** Static definition of a factor table: its name and the graph definition it is derived from.
  *
  * @tparam M generation mode of the source graph
  */
case class FactorTableDef[M <: Mode](
  name: String,
  sourceDef: GraphDef[M]
)
1111

1212
/** A materialized factor table: the computed data together with the graph it was derived from.
  *
  * @tparam M generation mode of the source graph
  */
case class FactorTable[M <: Mode](
  name: String,
  data: DataFrame,
  source: Graph[M]
)

src/main/scala/ldbc/snb/datagen/factors/io/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ package object io {
1818
/** Writes a factor table under `<sink.path>/factors/<format>/<graph path>/<name>`.
  * Coalesces to a single partition so each factor table lands in one output file.
  */
override def write(self: FactorTable[M], sink: FactorTableSink): Unit = {
  val p = (sink.path / "factors" / sink.format / PathComponent[GraphLike[M]].path(self.source) / self.name).toString()
  self.data.coalesce(1).write(DataFrameSink(p, sink.format))
  log.info(s"Factor table ${self.name} written")
}
2223
}
24+
2325
/** Implicit Writer instances for factor tables, mixed into the io package object. */
trait WriterInstances {
  // One writer per mode; resolved implicitly at the write call site.
  implicit def factorTableWriter[M <: Mode]: Writer.Aux[FactorTableSink, FactorTable[M]] = new FactorTableWriter[M]
}

src/main/scala/ldbc/snb/datagen/io/graphs.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
package ldbc.snb.datagen.io
22

3-
import ldbc.snb.datagen.model.EntityType.{Attr, Edge, Node}
4-
import ldbc.snb.datagen.model.{Batched, BatchedEntity, EntityType, Graph, GraphDef, GraphLike, Mode}
5-
import ldbc.snb.datagen.util.{Logging, SparkUI}
6-
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
7-
import shapeless.{Generic, Poly1}
83
import better.files._
94
import ldbc.snb.datagen.io.dataframes.{DataFrameSink, DataFrameSource}
5+
import ldbc.snb.datagen.model.EntityType.{Attr, Edge, Node}
106
import ldbc.snb.datagen.model.Mode.Raw
7+
import ldbc.snb.datagen.model._
8+
import ldbc.snb.datagen.util.{Logging, SparkUI}
119
import org.apache.spark.sql.types.StructType
12-
import shapeless._
10+
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
11+
import shapeless.{Generic, Poly1}
1312

1413
import scala.collection.immutable.TreeMap
1514

1615
object graphs {
1716

18-
import dataframes.instances._
19-
import Writer.ops._
2017
import Reader.ops._
18+
import Writer.ops._
19+
import dataframes.instances._
2120

2221
case class GraphSink(
2322
path: String,

0 commit comments

Comments
 (0)