@@ -2,15 +2,16 @@ package ldbc.snb.datagen.factors
2
2
3
3
import ldbc .snb .datagen .factors .io .FactorTableSink
4
4
import ldbc .snb .datagen .io .graphs .GraphSource
5
- import ldbc .snb .datagen .model .{Graph , Mode , graphs }
5
+ import ldbc .snb .datagen .model .{EntityType , graphs }
6
6
import ldbc .snb .datagen .{SparkApp , model }
7
7
import ldbc .snb .datagen .syntax ._
8
8
import ldbc .snb .datagen .util .Logging
9
- import org .apache .spark .sql .functions .{broadcast , count , date_trunc }
9
+ import org .apache .spark .sql .functions .{broadcast , count , date_trunc , sum }
10
10
import org .apache .spark .sql .{Column , DataFrame , SparkSession }
11
11
12
/**
 * A factor-table computation bundled with the graph entities it reads.
 *
 * @param requiredEntities entity types whose DataFrames the computation consumes, listed in
 *                         the same order in which the DataFrames are handed to `apply`
 * @param f                the computation itself; receives one DataFrame per required entity
 */
case class Factor(requiredEntities: EntityType*)(f: Seq[DataFrame] => DataFrame) extends (Seq[DataFrame] => DataFrame) {

  /**
   * Runs the wrapped computation.
   *
   * @param v1 one DataFrame per entry of `requiredEntities`, in the same order
   * @return the DataFrame produced by `f`
   * @throws IllegalArgumentException if the number of inputs differs from the number of
   *                                  required entities
   */
  override def apply(v1: Seq[DataFrame]): DataFrame = {
    // Factor bodies are typically written as `{ case Seq(a, b) => ... }`; checking the arity
    // here reports a mismatch clearly instead of as a MatchError raised inside the closure.
    require(
      v1.size == requiredEntities.size,
      s"Factor expects ${requiredEntities.size} input DataFrame(s) but received ${v1.size}"
    )
    f(v1)
  }
}
15
16
16
17
object FactorGenerationStage extends SparkApp with Logging {
@@ -26,23 +27,37 @@ object FactorGenerationStage extends SparkApp with Logging {
26
27
27
28
GraphSource (model.graphs.Raw .graphDef, args.outputDir, " csv" )
28
29
.read
29
- .pipe(g => rawFactors.map { case (name, calc) => FactorTable (name, calc(g), g) })
30
+ .pipe(g => rawFactors.map { case (name, calc) =>
31
+ val resolvedEntities = calc.requiredEntities.foldLeft(Seq .empty[DataFrame ])((args, et) => args :+ g.entities(et))
32
+ FactorTable (name, calc(resolvedEntities), g)
33
+ })
30
34
.foreach(_.write(FactorTableSink (args.outputDir)))
31
35
}
32
36
33
/**
 * Aggregates `value` within each group and orders the groups by that aggregate.
 *
 * @param df    input rows
 * @param value column passed to the aggregate
 * @param by    grouping columns; they are also used, ascending, as tie-breakers in the ordering
 * @param agg   aggregate applied to `value` per group; defaults to `count`
 * @return the `by` columns followed by a `frequency` column, sorted by `frequency` descending
 *         and then by the `by` columns ascending
 */
private def frequency(df: DataFrame, value: Column, by: Seq[Column], agg: Column => Column = count): DataFrame =
  df
    .groupBy(by: _*).agg(agg(value).as("frequency"))
    .select(by :+ $"frequency": _*)
    .orderBy($"frequency".desc +: by.map(_.asc): _*)
42
+
43
+
44
/**
 * Computes, per tag, the frequency of that tag across comment-tag and post-tag edges.
 *
 * @param commentHasTag edges with `CommentId` and `TagId` columns
 * @param postHasTag    edges with `PostId` and `TagId` columns
 * @param tag           tags with `id` and `name` columns
 * @return columns `tagId`, `tagName` and `frequency`, in the order produced by `frequency`
 */
private def messageTags(commentHasTag: DataFrame, postHasTag: DataFrame, tag: DataFrame): DataFrame = {
  // Both edge tables are projected onto a common (id, TagId) shape before being combined.
  // NOTE(review): `|+|` comes from `ldbc.snb.datagen.syntax`; presumably a union of the two
  // DataFrames - confirm against its definition.
  val messageHasTag =
    commentHasTag.select($"CommentId".as("id"), $"TagId") |+| postHasTag.select($"PostId".as("id"), $"TagId")

  frequency(
    messageHasTag.as("MessageHasTag").join(tag.as("Tag"), $"Tag.id" === $"MessageHasTag.TagId"),
    value = $"MessageHasTag.TagId",
    by = Seq($"Tag.id", $"Tag.name")
  ).select($"Tag.id".as("tagId"), $"Tag.name".as("tagName"), $"frequency")
}
53
+
54
+ import graphs .Raw .entities ._
38
55
39
56
private val rawFactors = Map (
40
- " countryNumPersons" -> RawFactor { graph =>
41
- val places = graph.entities(graphs.Raw .entities.Place ).cache()
57
+ " countryNumPersons" -> Factor (Place , Person ) { case Seq (places, persons) =>
42
58
val cities = places.where($" type" === " City" )
43
59
val countries = places.where($" type" === " Country" )
44
60
45
- val persons = graph.entities(graphs.Raw .entities.Person )
46
61
frequency(
47
62
persons.as(" Person" )
48
63
.join(broadcast(cities.as(" City" )), $" City.id" === $" Person.LocationCityId" )
@@ -51,20 +66,14 @@ object FactorGenerationStage extends SparkApp with Logging {
51
66
by = Seq ($" Country.id" , $" Country.name" )
52
67
)
53
68
},
54
- " countryNumMessages" -> RawFactor { graph =>
55
- val comments = graph.entities(graphs.Raw .entities.Comment )
56
- val posts = graph.entities(graphs.Raw .entities.Post )
69
+ " countryNumMessages" -> Factor (Comment , Post ) { case Seq (comments, posts) =>
57
70
frequency(
58
71
comments.select($" id" , $" LocationCountryId" ) |+| posts.select($" id" , $" LocationCountryId" ),
59
72
value = $" id" ,
60
73
by = Seq ($" LocationCountryId" )
61
74
)
62
75
},
63
- " cityPairsNumFriends" -> RawFactor { graph =>
64
- val personKnowsPerson = graph.entities(graphs.Raw .entities.PersonKnowsPerson )
65
- val persons = graph.entities(graphs.Raw .entities.Person ).cache()
66
-
67
- val places = graph.entities(graphs.Raw .entities.Place )
76
+ " cityPairsNumFriends" -> Factor (PersonKnowsPerson , Person , Place ) { case Seq (personKnowsPerson, persons, places) =>
68
77
val cities = places.where($" type" === " City" ).cache()
69
78
70
79
frequency(
@@ -81,14 +90,10 @@ object FactorGenerationStage extends SparkApp with Logging {
81
90
$" City2.id" .alias(" city2Id" ),
82
91
$" City1.name" .alias(" city1Name" ),
83
92
$" City2.name" .alias(" city2Name" ),
84
- $" count "
93
+ $" frequency "
85
94
)
86
95
},
87
- " countryPairsNumFriends" -> RawFactor { graph =>
88
- val personKnowsPerson = graph.entities(graphs.Raw .entities.PersonKnowsPerson )
89
- val persons = graph.entities(graphs.Raw .entities.Person ).cache()
90
-
91
- val places = graph.entities(graphs.Raw .entities.Place )
96
+ " countryPairsNumFriends" -> Factor (PersonKnowsPerson , Person , Place ) { case Seq (personKnowsPerson, persons, places) =>
92
97
val cities = places.where($" type" === " City" ).cache()
93
98
val countries = places.where($" type" === " Country" ).cache()
94
99
@@ -108,24 +113,54 @@ object FactorGenerationStage extends SparkApp with Logging {
108
113
$" Country2.id" .alias(" country2Id" ),
109
114
$" Country1.name" .alias(" country1Name" ),
110
115
$" Country2.name" .alias(" country2Name" ),
111
- $" count "
116
+ $" frequency "
112
117
)
113
118
},
114
- " messageCreationDays" -> RawFactor { graph =>
115
- val comments = graph.entities(graphs.Raw .entities.Comment )
116
- val posts = graph.entities(graphs.Raw .entities.Post )
119
+ " messageCreationDays" -> Factor (Comment , Post ) { case Seq (comments, posts) =>
117
120
(comments.select($" creationDate" ) |+| posts.select($" creationDate" ))
118
121
.select(date_trunc(" day" , $" creationDate" ).as(" creationDay" ))
119
122
.distinct()
120
123
},
121
- " messageLengths" -> RawFactor { graph =>
122
- val comments = graph.entities(graphs.Raw .entities.Comment )
123
- val posts = graph.entities(graphs.Raw .entities.Post )
124
+ " messageLengths" -> Factor (Comment , Post ) { case Seq (comments, posts) =>
124
125
frequency(
125
126
comments.select($" id" , $" length" ) |+| posts.select($" id" , $" length" ),
126
127
value = $" id" ,
127
128
by = Seq ($" length" )
128
129
)
130
+ },
131
+ " messageTags" -> Factor (CommentHasTag , PostHasTag , Tag ) { case Seq (commentHasTag, postHasTag, tag) =>
132
+ messageTags(commentHasTag, postHasTag, tag).cache()
133
+ },
134
+ " messageTagClasses" -> Factor (CommentHasTag , PostHasTag , Tag , TagClass ) { case Seq (commentHasTag, postHasTag, tag, tagClass) =>
135
+ frequency(
136
+ messageTags(commentHasTag, postHasTag, tag).as(" MessageTags" )
137
+ .join(tag.as(" Tag" ), $" MessageTags.tagId" === $" Tag.id" )
138
+ .join(tagClass.as(" TagClass" ), $" Tag.TypeTagClassId" === $" TagClass.id" ),
139
+ value = $" frequency" ,
140
+ by = Seq ($" TagClass.id" , $" TagClass.name" ),
141
+ agg = sum
142
+ )
143
+ },
144
+ " personNumFriends" -> Factor (PersonKnowsPerson ) { case Seq (knows) =>
145
+ frequency(knows, value= $" Person2Id" , by= Seq ($" Person1Id" ))
146
+ },
147
+ " postLanguages" -> Factor (Post ) { case Seq (post) =>
148
+ frequency(post.where($" language" .isNotNull), value= $" id" , by= Seq ($" language" ))
149
+ },
150
+ " tagClassNumTags" -> Factor (TagClass , Tag ) { case Seq (tagClass, tag) =>
151
+ frequency(
152
+ tag.as(" Tag" ).join(tagClass.as(" TagClass" ), $" Tag.TypeTagClassId" === $" TagClass.id" ),
153
+ value = $" Tag.id" ,
154
+ by = Seq ($" TagClass.id" , $" TagClass.name" )
155
+ )
156
+ },
157
+ " companiesNumEmployees" -> Factor (Organisation , PersonWorkAtCompany ) { case Seq (organisation, workAt) =>
158
+ val company = organisation.where($" Type" === " Company" )
159
+ frequency(
160
+ company.as(" Company" ).join(workAt.as(" WorkAt" ), $" WorkAt.CompanyId" === $" Company.id" ),
161
+ value = $" WorkAt.PersonId" ,
162
+ by = Seq ($" Company.id" , $" Company.name" )
163
+ )
129
164
}
130
165
)
131
166
}
0 commit comments