Skip to content

Commit fadf300

Browse files
committed
implement the rest of the factor tables
1 parent b799983 commit fadf300

File tree

2 files changed

+70
-35
lines changed

2 files changed

+70
-35
lines changed

src/main/scala/ldbc/snb/datagen/factors/FactorGenerationStage.scala

Lines changed: 68 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@ package ldbc.snb.datagen.factors
22

33
import ldbc.snb.datagen.factors.io.FactorTableSink
44
import ldbc.snb.datagen.io.graphs.GraphSource
5-
import ldbc.snb.datagen.model.{Graph, Mode, graphs}
5+
import ldbc.snb.datagen.model.{EntityType, graphs}
66
import ldbc.snb.datagen.{SparkApp, model}
77
import ldbc.snb.datagen.syntax._
88
import ldbc.snb.datagen.util.Logging
9-
import org.apache.spark.sql.functions.{broadcast, count, date_trunc}
9+
import org.apache.spark.sql.functions.{broadcast, count, date_trunc, sum}
1010
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
1111

12-
case class RawFactor(private val f: Graph[Mode.Raw.type] => DataFrame) extends (Graph[Mode.Raw.type] => DataFrame) {
13-
override def apply(v1: Graph[Mode.Raw.type]): DataFrame = f(v1)
12+
13+
case class Factor(requiredEntities: EntityType*)(f: Seq[DataFrame] => DataFrame) extends (Seq[DataFrame] => DataFrame) {
14+
override def apply(v1: Seq[DataFrame]): DataFrame = f(v1)
1415
}
1516

1617
object FactorGenerationStage extends SparkApp with Logging {
@@ -26,23 +27,37 @@ object FactorGenerationStage extends SparkApp with Logging {
2627

2728
GraphSource(model.graphs.Raw.graphDef, args.outputDir, "csv")
2829
.read
29-
.pipe(g => rawFactors.map { case (name, calc) => FactorTable(name, calc(g), g) })
30+
.pipe(g => rawFactors.map { case (name, calc) =>
31+
val resolvedEntities = calc.requiredEntities.foldLeft(Seq.empty[DataFrame])((args, et) => args :+ g.entities(et))
32+
FactorTable(name, calc(resolvedEntities), g)
33+
})
3034
.foreach(_.write(FactorTableSink(args.outputDir)))
3135
}
3236

33-
private def frequency(df: DataFrame, value: Column, by: Seq[Column]) =
37+
private def frequency(df: DataFrame, value: Column, by: Seq[Column], agg: Column => Column = count) =
3438
df
35-
.groupBy(by: _*).agg(count(value).as("count"))
36-
.select(by :+ $"count": _*)
37-
.orderBy($"count".desc +: by.map(_.asc): _*)
39+
.groupBy(by: _*).agg(agg(value).as("frequency"))
40+
.select(by :+ $"frequency": _*)
41+
.orderBy($"frequency".desc +: by.map(_.asc): _*)
42+
43+
44+
private def messageTags(commentHasTag: DataFrame, postHasTag: DataFrame, tag: DataFrame) = {
45+
val messageHasTag = commentHasTag.select($"CommentId".as("id"), $"TagId") |+| postHasTag.select($"PostId".as("id"), $"TagId")
46+
47+
frequency(
48+
messageHasTag.as("MessageHasTag").join(tag.as("Tag"), $"Tag.id" === $"MessageHasTag.TagId"),
49+
value = $"MessageHasTag.TagId",
50+
by = Seq($"Tag.id", $"Tag.name")
51+
).select($"Tag.id".as("tagId"), $"Tag.name".as("tagName"), $"frequency")
52+
}
53+
54+
import graphs.Raw.entities._
3855

3956
private val rawFactors = Map(
40-
"countryNumPersons" -> RawFactor { graph =>
41-
val places = graph.entities(graphs.Raw.entities.Place).cache()
57+
"countryNumPersons" -> Factor(Place, Person) { case Seq(places, persons) =>
4258
val cities = places.where($"type" === "City")
4359
val countries = places.where($"type" === "Country")
4460

45-
val persons = graph.entities(graphs.Raw.entities.Person)
4661
frequency(
4762
persons.as("Person")
4863
.join(broadcast(cities.as("City")), $"City.id" === $"Person.LocationCityId")
@@ -51,20 +66,14 @@ object FactorGenerationStage extends SparkApp with Logging {
5166
by = Seq($"Country.id", $"Country.name")
5267
)
5368
},
54-
"countryNumMessages" -> RawFactor { graph =>
55-
val comments = graph.entities(graphs.Raw.entities.Comment)
56-
val posts = graph.entities(graphs.Raw.entities.Post)
69+
"countryNumMessages" -> Factor(Comment, Post) { case Seq(comments, posts) =>
5770
frequency(
5871
comments.select($"id", $"LocationCountryId") |+| posts.select($"id", $"LocationCountryId"),
5972
value = $"id",
6073
by = Seq($"LocationCountryId")
6174
)
6275
},
63-
"cityPairsNumFriends" -> RawFactor { graph =>
64-
val personKnowsPerson = graph.entities(graphs.Raw.entities.PersonKnowsPerson)
65-
val persons = graph.entities(graphs.Raw.entities.Person).cache()
66-
67-
val places = graph.entities(graphs.Raw.entities.Place)
76+
"cityPairsNumFriends" -> Factor(PersonKnowsPerson, Person, Place) { case Seq(personKnowsPerson, persons, places) =>
6877
val cities = places.where($"type" === "City").cache()
6978

7079
frequency(
@@ -81,14 +90,10 @@ object FactorGenerationStage extends SparkApp with Logging {
8190
$"City2.id".alias("city2Id"),
8291
$"City1.name".alias("city1Name"),
8392
$"City2.name".alias("city2Name"),
84-
$"count"
93+
$"frequency"
8594
)
8695
},
87-
"countryPairsNumFriends" -> RawFactor { graph =>
88-
val personKnowsPerson = graph.entities(graphs.Raw.entities.PersonKnowsPerson)
89-
val persons = graph.entities(graphs.Raw.entities.Person).cache()
90-
91-
val places = graph.entities(graphs.Raw.entities.Place)
96+
"countryPairsNumFriends" -> Factor(PersonKnowsPerson, Person, Place) { case Seq(personKnowsPerson, persons, places) =>
9297
val cities = places.where($"type" === "City").cache()
9398
val countries = places.where($"type" === "Country").cache()
9499

@@ -108,24 +113,54 @@ object FactorGenerationStage extends SparkApp with Logging {
108113
$"Country2.id".alias("country2Id"),
109114
$"Country1.name".alias("country1Name"),
110115
$"Country2.name".alias("country2Name"),
111-
$"count"
116+
$"frequency"
112117
)
113118
},
114-
"messageCreationDays" -> RawFactor { graph =>
115-
val comments = graph.entities(graphs.Raw.entities.Comment)
116-
val posts = graph.entities(graphs.Raw.entities.Post)
119+
"messageCreationDays" -> Factor(Comment, Post) { case Seq(comments, posts) =>
117120
(comments.select($"creationDate") |+| posts.select($"creationDate"))
118121
.select(date_trunc("day", $"creationDate").as("creationDay"))
119122
.distinct()
120123
},
121-
"messageLengths" -> RawFactor { graph =>
122-
val comments = graph.entities(graphs.Raw.entities.Comment)
123-
val posts = graph.entities(graphs.Raw.entities.Post)
124+
"messageLengths" -> Factor(Comment, Post) { case Seq(comments, posts) =>
124125
frequency(
125126
comments.select($"id", $"length") |+| posts.select($"id", $"length"),
126127
value = $"id",
127128
by = Seq($"length")
128129
)
130+
},
131+
"messageTags" -> Factor(CommentHasTag, PostHasTag, Tag) { case Seq(commentHasTag, postHasTag, tag) =>
132+
messageTags(commentHasTag, postHasTag, tag).cache()
133+
},
134+
"messageTagClasses" -> Factor(CommentHasTag, PostHasTag, Tag, TagClass) { case Seq(commentHasTag, postHasTag, tag, tagClass) =>
135+
frequency(
136+
messageTags(commentHasTag, postHasTag, tag).as("MessageTags")
137+
.join(tag.as("Tag"), $"MessageTags.tagId" === $"Tag.id")
138+
.join(tagClass.as("TagClass"), $"Tag.TypeTagClassId" === $"TagClass.id"),
139+
value = $"frequency",
140+
by = Seq($"TagClass.id", $"TagClass.name"),
141+
agg = sum
142+
)
143+
},
144+
"personNumFriends" -> Factor(PersonKnowsPerson) { case Seq(knows) =>
145+
frequency(knows, value=$"Person2Id", by=Seq($"Person1Id"))
146+
},
147+
"postLanguages" -> Factor(Post) { case Seq(post) =>
148+
frequency(post.where($"language".isNotNull), value=$"id", by=Seq($"language"))
149+
},
150+
"tagClassNumTags" -> Factor(TagClass, Tag) { case Seq(tagClass, tag) =>
151+
frequency(
152+
tag.as("Tag").join(tagClass.as("TagClass"), $"Tag.TypeTagClassId" === $"TagClass.id"),
153+
value = $"Tag.id",
154+
by = Seq($"TagClass.id", $"TagClass.name")
155+
)
156+
},
157+
"companiesNumEmployees" -> Factor(Organisation, PersonWorkAtCompany) { case Seq(organisation, workAt) =>
158+
val company = organisation.where($"Type" === "Company")
159+
frequency(
160+
company.as("Company").join(workAt.as("WorkAt"), $"WorkAt.CompanyId" === $"Company.id"),
161+
value = $"WorkAt.PersonId",
162+
by = Seq($"Company.id", $"Company.name")
163+
)
129164
}
130165
)
131166
}

src/main/scala/ldbc/snb/datagen/io/utils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ object utils {
99

1010
def fileExists(path: String)(implicit spark: SparkSession): Boolean = {
1111
val hadoopPath = new Path(path)
12-
FileSystem.get(new URI(path), spark.sparkContext.hadoopConfiguration).exists(hadoopPath)
12+
val fs = FileSystem.get(URI.create(path), spark.sparkContext.hadoopConfiguration)
13+
fs.exists(hadoopPath)
1314
}
14-
1515
}

0 commit comments

Comments
 (0)