@@ -8,8 +8,9 @@ import ldbc.snb.datagen.model.Mode.Raw
 import ldbc.snb.datagen.syntax._
 import ldbc.snb.datagen.transformation.transform.ConvertDates
 import ldbc.snb.datagen.util.{DatagenStage, Logging}
+import org.apache.spark.graphx
 import org.apache.spark.sql.functions.{broadcast, col, count, date_trunc, expr, floor, lit, sum}
-import org.apache.spark.sql.{Column, DataFrame, functions}
+import org.apache.spark.sql.{Column, DataFrame, Row, functions}
 import shapeless._

 import scala.util.matching.Regex
@@ -146,6 +147,20 @@ object FactorGenerationStage extends DatagenStage with Logging {
     ).select($"Tag.id".as("tagId"), $"Tag.name".as("tagName"), $"frequency")
   }

+  private def sameUniversityKnows(personKnowsPerson: DataFrame, studyAt: DataFrame) = {
+    undirectedKnowsTemporal(personKnowsPerson)
+      .join(studyAt.as("studyAt1"), $"studyAt1.personId" === $"knows.person1Id")
+      .join(studyAt.as("studyAt2"), $"studyAt2.personId" === $"knows.person2Id")
+      .where($"studyAt1.universityId" === $"studyAt2.universityId")
+      .select(
+        $"knows.person1Id".as("person1Id"),
+        $"knows.person2Id".as("person2Id"),
+        functions.greatest($"knows.creationDate", $"studyAt1.creationDate", $"studyAt2.creationDate").alias("creationDate"),
+        functions.least($"knows.deletionDate", $"studyAt1.deletionDate", $"studyAt2.deletionDate").alias("deletionDate")
+      )
+      .where($"creationDate" < $"deletionDate")
+  }
+
   import model.raw._

   private val rawFactors = Map(
@@ -451,21 +466,6 @@ object FactorGenerationStage extends DatagenStage with Logging {
       val sampleFractionPersonPairs = Math.min(10000.0 / personPairs.count(), 1.0)
       personPairs.sample(sampleFractionPersonPairs, 42)
     },
-    "sameUniversityKnows" -> LargeFactor(PersonKnowsPersonType, PersonStudyAtUniversityType) { case Seq(personKnowsPerson, studyAt) =>
-      val size = Math.max(Math.ceil(personKnowsPerson.rdd.getNumPartitions / 10).toInt, 1)
-      undirectedKnowsTemporal(personKnowsPerson)
-        .join(studyAt.as("studyAt1"), $"studyAt1.personId" === $"knows.person1Id")
-        .join(studyAt.as("studyAt2"), $"studyAt2.personId" === $"knows.person2Id")
-        .where($"studyAt1.universityId" === $"studyAt2.universityId")
-        .select(
-          $"knows.person1Id".as("person1Id"),
-          $"knows.person2Id".as("person2Id"),
-          functions.greatest($"knows.creationDate", $"studyAt1.creationDate", $"studyAt2.creationDate").alias("creationDate"),
-          functions.least($"knows.deletionDate", $"studyAt1.deletionDate", $"studyAt2.deletionDate").alias("deletionDate")
-        )
-        .where($"creationDate" < $"deletionDate")
-        .coalesce(size)
-    },
     // -- interactive --
     // first names
     "personFirstNames" -> Factor(PersonType) { case Seq(person) =>
@@ -691,5 +691,39 @@ object FactorGenerationStage extends DatagenStage with Logging {

       numFriendOfFriendCompanies
     },
+    "sameUniversityConnected" -> LargeFactor(PersonType, PersonKnowsPersonType, PersonStudyAtUniversityType) { case Seq(person, personKnowsPerson, studyAt) =>
+      val s = spark
+      import s.implicits._
+      val vertices = person.select("id").rdd.map(row => (row.getAs[Long]("id"), ()))
+
+      val edges = sameUniversityKnows(personKnowsPerson, studyAt).rdd.map(row =>
+        graphx.Edge(row.getAs[Long]("person1Id"), row.getAs[Long]("person2Id"), ())
+      )
+      val graph = graphx.Graph(vertices, edges, ())
+      val cc = graph.connectedComponents().vertices
+        .toDF("PersonId", "Component")
+
+      val counts = cc.groupBy("Component").agg(count("*").as("count"))
+
+      cc.join(counts, Seq("Component")).select("PersonId", "Component", "count")
+
+    },
+    "personKnowsPersonConnected" -> LargeFactor(PersonType, PersonKnowsPersonType) { case Seq(person, personKnowsPerson) =>
+      val s = spark
+      import s.implicits._
+      val vertices = person.select("id").rdd.map(row => (row.getAs[Long]("id"), ()))
+
+      val edges = personKnowsPerson.rdd.map(row =>
+        graphx.Edge(row.getAs[Long]("Person1Id"), row.getAs[Long]("Person2Id"), ())
+      )
+      val graph = graphx.Graph(vertices, edges, ())
+      val cc = graph.connectedComponents().vertices
+        .toDF("PersonId", "Component")
+
+      val counts = cc.groupBy("Component").agg(count("*").as("count"))
+
+      cc.join(counts, Seq("Component")).select("PersonId", "Component", "count")
+
+    }
   )
 }
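
The two new factors follow the same pattern: build a GraphX graph whose vertices are person ids and whose edges are knows relations, run connectedComponents(), and join each person's component label with the size of that component. Below is a minimal, self-contained sketch of that pattern, not part of the PR: the toy vertex and edge data, the local master, the app name and the object name are assumptions for illustration, and spark-graphx is assumed to be on the classpath.

// Minimal sketch of the GraphX connected-components pattern used by the new factors.
import org.apache.spark.graphx
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.count

object ConnectedComponentSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("cc-sketch").getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext

    // Vertices carry no payload: (personId, ())
    val vertices = sc.parallelize(Seq((1L, ()), (2L, ()), (3L, ()), (4L, ())))
    // Edges carry no payload either; connectedComponents() treats them as undirected.
    val edges = sc.parallelize(Seq(graphx.Edge(1L, 2L, ()), graphx.Edge(2L, 3L, ())))

    val graph = graphx.Graph(vertices, edges, ())

    // Each vertex ends up labelled with the lowest vertex id in its component.
    val cc = graph.connectedComponents().vertices.toDF("PersonId", "Component")

    // Attach the component size, as the new factors do via groupBy + join.
    val counts = cc.groupBy("Component").agg(count("*").as("count"))
    cc.join(counts, Seq("Component")).select("PersonId", "Component", "count").show()
    // Expected: persons 1, 2, 3 share component 1 (count 3); person 4 is its own component (count 1).

    spark.stop()
  }
}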