@@ -13,7 +13,15 @@ import shapeless._
13
13
14
14
import scala .util .matching .Regex
15
15
16
- case class Factor (requiredEntities : EntityType * )(f : Seq [DataFrame ] => DataFrame ) extends (Seq [DataFrame ] => DataFrame ) {
16
/**
 * Common shape of a factor: a function from a sequence of input DataFrames to one
 * result DataFrame, together with the entity types it declares as its inputs.
 */
trait FactorTrait extends Function1[Seq[DataFrame], DataFrame] {

  /** Entity types this factor declares as required. */
  def requiredEntities: Seq[EntityType]
}
19
+
20
/**
 * A factor whose result is coalesced to a single partition.
 *
 * @param requiredEntities entity types this factor declares as required
 * @param f                the computation applied to the input DataFrames
 */
case class Factor(override val requiredEntities: EntityType*)(f: Seq[DataFrame] => DataFrame) extends FactorTrait {

  /** Runs `f` on the inputs, then coalesces the outcome with `coalesce(1)`. */
  override def apply(v1: Seq[DataFrame]): DataFrame = {
    val computed = f(v1)
    computed.coalesce(1)
  }
}
23
+
24
/**
 * A factor that returns the result of `f` as-is; unlike [[Factor]], this wrapper adds
 * no coalescing step of its own.
 *
 * @param requiredEntities entity types this factor declares as required
 * @param f                the computation applied to the input DataFrames
 */
case class LargeFactor(override val requiredEntities: EntityType*)(f: Seq[DataFrame] => DataFrame) extends FactorTrait {

  /** Delegates straight to `f`. */
  override def apply(v1: Seq[DataFrame]): DataFrame = {
    f.apply(v1)
  }
}
19
27
@@ -160,8 +168,8 @@ object FactorGenerationStage extends DatagenStage with Logging {
160
168
)
161
169
},
162
170
" cityPairsNumFriends" -> Factor (PersonKnowsPersonType , PersonType , PlaceType ) { case Seq (personKnowsPerson, persons, places) =>
163
- val cities = places.where($" type" === " City" ).cache()
164
- val knows = undirectedKnows(personKnowsPerson)
171
+ val cities = places.where($" type" === " City" ).cache()
172
+ val knows = undirectedKnows(personKnowsPerson)
165
173
val countries = places.where($" type" === " Country" ).cache()
166
174
167
175
frequency(
@@ -279,7 +287,8 @@ object FactorGenerationStage extends DatagenStage with Logging {
279
287
)
280
288
},
281
289
" personNumFriends" -> Factor(PersonKnowsPersonType, PersonType) { case Seq(personKnowsPerson, person1) =>
  // Persons on the left, undirected knows edges on the right, joined with the
  // "leftouter" join type on the person id.
  val persons = person1.as(" Person1")
  val friendships = undirectedKnows(personKnowsPerson).as(" knows")
  val knows = persons.join(friendships, $" Person1.id" === $" knows.Person1Id", " leftouter")

  frequency(
    knows,
    value = $" knows.Person2Id",
    by = Seq($" Person1.id", $" Person1.creationDate", $" Person1.deletionDate")
  )
},
@@ -318,7 +327,7 @@ object FactorGenerationStage extends DatagenStage with Logging {
318
327
$" Company.name" .alias(" companyName" ),
319
328
$" Company.id" .alias(" companyId" ),
320
329
$" Person2.creationDate" .alias(" person2creationDate" ),
321
- $" Person2.deletionDate" .alias(" person2deletionDate" ),
330
+ $" Person2.deletionDate" .alias(" person2deletionDate" )
322
331
)
323
332
.distinct()
324
333
},
@@ -331,8 +340,8 @@ object FactorGenerationStage extends DatagenStage with Logging {
331
340
)
332
341
},
333
342
" people4Hops" -> Factor (PersonType , PlaceType , PersonKnowsPersonType ) { case Seq (person, place, knows) =>
334
- val cities = place.where($" type" === " City" ).cache()
335
- val allKnows = undirectedKnows(knows).cache()
343
+ val cities = place.where($" type" === " City" ).cache()
344
+ val allKnows = undirectedKnows(knows).cache()
336
345
val minSampleSize = 100.0
337
346
338
347
val chinesePeopleSample = (relations : DataFrame ) => {
@@ -377,8 +386,8 @@ object FactorGenerationStage extends DatagenStage with Logging {
377
386
.limit(10000 )
378
387
},
379
388
" people2Hops" -> Factor (PersonType , PlaceType , PersonKnowsPersonType ) { case Seq (person, place, knows) =>
380
- val cities = place.where($" type" === " City" ).cache()
381
- val allKnows = undirectedKnows(knows).cache()
389
+ val cities = place.where($" type" === " City" ).cache()
390
+ val allKnows = undirectedKnows(knows).cache()
382
391
val minSampleSize = 100.0
383
392
384
393
val chinesePeopleSample = (relations : DataFrame ) => {
@@ -422,16 +431,17 @@ object FactorGenerationStage extends DatagenStage with Logging {
422
431
.sort($" Person1Id" , $" Person2Id" )
423
432
.limit(10000 )
424
433
},
425
- " sameUniversityKnows" -> Factor (PersonKnowsPersonType , PersonStudyAtUniversityType ) {
426
- case Seq (personKnowsPerson, studyAt) =>
427
- undirectedKnows(personKnowsPerson)
428
- .join(studyAt.as(" studyAt1" ), $" studyAt1.personId" === $" knows.person1Id" )
429
- .join(studyAt.as(" studyAt2" ), $" studyAt2.personId" === $" knows.person2Id" )
430
- .where($" studyAt1.universityId" === $" studyAt2.universityId" )
431
- .select(
432
- $" knows.person1Id" .as(" person1Id" ),
433
- $" knows.person2Id" .as(" person2Id" )
434
- )
434
" sameUniversityKnows" -> LargeFactor(PersonKnowsPersonType, PersonStudyAtUniversityType) { case Seq(personKnowsPerson, studyAt) =>
  // Target one output partition per ten input partitions, rounded up, and never fewer than one.
  // The divisor must be a Double: `Int / Int` truncates before Math.ceil is applied, which
  // would turn the ceil into a no-op (e.g. 15 partitions would yield 1 instead of 2).
  val size = Math.max(Math.ceil(personKnowsPerson.rdd.getNumPartitions / 10.0).toInt, 1)

  // Pairs of persons who know each other and have a study-at record at the same university.
  undirectedKnows(personKnowsPerson)
    .join(studyAt.as(" studyAt1"), $" studyAt1.personId" === $" knows.person1Id")
    .join(studyAt.as(" studyAt2"), $" studyAt2.personId" === $" knows.person2Id")
    .where($" studyAt1.universityId" === $" studyAt2.universityId")
    .select(
      $" knows.person1Id".as(" person1Id"),
      $" knows.person2Id".as(" person2Id")
    )
    .coalesce(size)
}
436
446
)
437
447
}
0 commit comments