diff --git a/src/main/java/ldbc/finbench/datagen/generation/events/CompanyInvestEvent.java b/src/main/java/ldbc/finbench/datagen/generation/events/CompanyInvestEvent.java deleted file mode 100644 index 75e5fbdf..00000000 --- a/src/main/java/ldbc/finbench/datagen/generation/events/CompanyInvestEvent.java +++ /dev/null @@ -1,47 +0,0 @@ -package ldbc.finbench.datagen.generation.events; - -import java.io.Serializable; -import java.util.List; -import java.util.Random; -import ldbc.finbench.datagen.entities.edges.CompanyInvestCompany; -import ldbc.finbench.datagen.entities.nodes.Company; -import ldbc.finbench.datagen.generation.DatagenParams; -import ldbc.finbench.datagen.util.RandomGeneratorFarm; - -public class CompanyInvestEvent implements Serializable { - private final RandomGeneratorFarm randomFarm; - private final Random randIndex; - - public CompanyInvestEvent() { - randomFarm = new RandomGeneratorFarm(); - randIndex = new Random(DatagenParams.defaultSeed); - } - - public void resetState(int seed) { - randomFarm.resetRandomGenerators(seed); - randIndex.setSeed(seed); - } - - public List<Company> companyInvestPartition(List<Company> investors, List<Company> targets) { - Random numInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_COMPANY_INVEST); - Random chooseInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_COMPANY_INVESTOR); - for (Company target : targets) { - int numInvestors = numInvestorsRand.nextInt( - DatagenParams.maxInvestors - DatagenParams.minInvestors + 1 - ) + DatagenParams.minInvestors; - for (int i = 0; i < numInvestors; i++) { - int index = chooseInvestorRand.nextInt(investors.size()); - Company investor = investors.get(index); - if (cannotInvest(investor, target)) { - continue; - } - CompanyInvestCompany.createCompanyInvestCompany(randomFarm, investor, target); - } - } - return targets; - } - - public boolean cannotInvest(Company investor, Company target) { - return (investor == target) || investor.hasInvestedBy(target) || target.hasInvestedBy(investor); - } -} diff --git a/src/main/java/ldbc/finbench/datagen/generation/events/InvestActivitesEvent.java b/src/main/java/ldbc/finbench/datagen/generation/events/InvestActivitesEvent.java new file mode 100644 index 00000000..1e26065b --- /dev/null +++ b/src/main/java/ldbc/finbench/datagen/generation/events/InvestActivitesEvent.java @@ -0,0 +1,67 @@ +package ldbc.finbench.datagen.generation.events; + +import java.io.Serializable; +import java.util.List; +import java.util.Random; +import ldbc.finbench.datagen.entities.edges.CompanyInvestCompany; +import ldbc.finbench.datagen.entities.edges.PersonInvestCompany; +import ldbc.finbench.datagen.entities.nodes.Company; +import ldbc.finbench.datagen.entities.nodes.Person; +import ldbc.finbench.datagen.generation.DatagenParams; +import ldbc.finbench.datagen.util.RandomGeneratorFarm; + +public class InvestActivitesEvent implements Serializable { + private final RandomGeneratorFarm randomFarm; + + public InvestActivitesEvent() { + randomFarm = new RandomGeneratorFarm(); + } + + public void resetState(int seed) { + randomFarm.resetRandomGenerators(seed); + } + + public List<Company> investPartition(List<Person> personinvestors, List<Company> companyInvestors, + List<Company> targets) { + Random numPersonInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_PERSON_INVEST); + Random choosePersonInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_PERSON_INVESTOR); + Random numCompanyInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_COMPANY_INVEST); + Random chooseCompanyInvestorRand = 
randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_COMPANY_INVESTOR); + for (Company target : targets) { + // Person investors + int numPersonInvestors = numPersonInvestorsRand.nextInt( + DatagenParams.maxInvestors - DatagenParams.minInvestors + 1 + ) + DatagenParams.minInvestors; + for (int i = 0; i < numPersonInvestors; i++) { + int index = choosePersonInvestorRand.nextInt(personinvestors.size()); + Person investor = personinvestors.get(index); + if (cannotInvest(investor, target)) { + continue; + } + PersonInvestCompany.createPersonInvestCompany(randomFarm, investor, target); + } + + // Company investors + int numCompanyInvestors = numCompanyInvestorsRand.nextInt( + DatagenParams.maxInvestors - DatagenParams.minInvestors + 1 + ) + DatagenParams.minInvestors; + for (int i = 0; i < numCompanyInvestors; i++) { + int index = chooseCompanyInvestorRand.nextInt(companyInvestors.size()); + Company investor = companyInvestors.get(index); + if (cannotInvest(investor, target)) { + continue; + } + CompanyInvestCompany.createCompanyInvestCompany(randomFarm, investor, target); + } + } + return targets; + } + + public boolean cannotInvest(Person investor, Company target) { + return target.hasInvestedBy(investor); + } + + public boolean cannotInvest(Company investor, Company target) { + return (investor == target) || investor.hasInvestedBy(target) || target.hasInvestedBy(investor); + } +} diff --git a/src/main/java/ldbc/finbench/datagen/generation/events/PersonInvestEvent.java b/src/main/java/ldbc/finbench/datagen/generation/events/PersonInvestEvent.java deleted file mode 100644 index 297a70f5..00000000 --- a/src/main/java/ldbc/finbench/datagen/generation/events/PersonInvestEvent.java +++ /dev/null @@ -1,48 +0,0 @@ -package ldbc.finbench.datagen.generation.events; - -import java.io.Serializable; -import java.util.List; -import java.util.Random; -import ldbc.finbench.datagen.entities.edges.PersonInvestCompany; -import ldbc.finbench.datagen.entities.nodes.Company; -import ldbc.finbench.datagen.entities.nodes.Person; -import ldbc.finbench.datagen.generation.DatagenParams; -import ldbc.finbench.datagen.util.RandomGeneratorFarm; - -public class PersonInvestEvent implements Serializable { - private final RandomGeneratorFarm randomFarm; - private final Random randIndex; - - public PersonInvestEvent() { - randomFarm = new RandomGeneratorFarm(); - randIndex = new Random(DatagenParams.defaultSeed); - } - - public void resetState(int seed) { - randomFarm.resetRandomGenerators(seed); - randIndex.setSeed(seed); - } - - public List<Company> personInvestPartition(List<Person> investors, List<Company> targets) { - Random numInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_PERSON_INVEST); - Random chooseInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_PERSON_INVESTOR); - for (Company target : targets) { - int numInvestors = numInvestorsRand.nextInt( - DatagenParams.maxInvestors - DatagenParams.minInvestors + 1 - ) + DatagenParams.minInvestors; - for (int i = 0; i < numInvestors; i++) { - int index = chooseInvestorRand.nextInt(investors.size()); - Person investor = investors.get(index); - if (cannotInvest(investor, target)) { - continue; - } - PersonInvestCompany.createPersonInvestCompany(randomFarm, investor, target); - } - } - return targets; - } - - public boolean cannotInvest(Person investor, Company target) { - return target.hasInvestedBy(investor); - } -} diff --git a/src/main/scala/ldbc/finbench/datagen/factors/FactorGenerationStage.scala b/src/main/scala/ldbc/finbench/datagen/factors/FactorGenerationStage.scala 
index 82dc85fb..b178d0cc 100644 --- a/src/main/scala/ldbc/finbench/datagen/factors/FactorGenerationStage.scala +++ b/src/main/scala/ldbc/finbench/datagen/factors/FactorGenerationStage.scala @@ -1,6 +1,5 @@ package ldbc.finbench.datagen.factors -import ldbc.finbench.datagen.LdbcDatagen.log import ldbc.finbench.datagen.util.DatagenStage import org.apache.spark.sql.functions._ import org.apache.spark.sql.{DataFrame, SparkSession, functions => F} @@ -64,196 +63,152 @@ object FactorGenerationStage extends DatagenStage { run(parsedArgs) } - override def run(args: Args) = { factortables(args) } - def factortables(args: Args)(implicit spark: SparkSession) = { - import spark.implicits._ - log.info("[Main] Starting factoring stage") + def readCSV(path: String): DataFrame = { + val csvOptions = Map("header" -> "true", "delimiter" -> "|") + spark.read + .format("csv") + .options(csvOptions) + .load(path) + } - val transferRDD = spark.read - .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") - .option("header", "true") - .option("delimiter", "|") - .load(s"${args.outputDir}/snapshot/AccountTransferAccount.csv") - .select( - $"fromId", - $"toId", - $"amount".cast("double"), - (unix_timestamp( - coalesce( - to_timestamp($"createTime", "yyyy-MM-dd HH:mm:ss.SSS"), - to_timestamp($"createTime", "yyyy-MM-dd HH:mm:ss") - ) - ) * 1000).alias("createTime") + def transformItems( + df: DataFrame, + groupByCol: String, + selectCol: String + ): DataFrame = { + val itemAmountRDD = df + .groupBy(groupByCol, selectCol) + .agg( + max(col("amount")).alias("maxAmount"), + max(col("createTime")).alias("maxCreateTime") ) - val withdrawRDD = spark.read - .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") - .option("header", "true") - .option("delimiter", "|") - .load(s"${args.outputDir}/snapshot/AccountWithdrawAccount.csv") - .select( - $"fromId", - $"toId", - $"amount".cast("double"), - (unix_timestamp( - coalesce( - to_timestamp($"createTime", "yyyy-MM-dd HH:mm:ss.SSS"), - to_timestamp($"createTime", "yyyy-MM-dd HH:mm:ss") - ) - ) * 1000).alias("createTime") + val itemsRDD = itemAmountRDD + .groupBy(groupByCol) + .agg( + collect_list( + array(col(selectCol), col("maxAmount"), col("maxCreateTime")) + ).alias("items") ) - val depositRDD = spark.read - .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") - .option("header", "true") - .option("delimiter", "|") - .load(s"${args.outputDir}/snapshot/LoanDepositAccount.csv") - .select($"accountId", $"loanId") - - val personInvestRDD = spark.read - .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") - .option("header", "true") - .option("delimiter", "|") - .load(s"${args.outputDir}/snapshot/PersonInvestCompany.csv") - .select( - $"investorId", - $"companyId", - (unix_timestamp( - coalesce( - to_timestamp($"createTime", "yyyy-MM-dd HH:mm:ss.SSS"), - to_timestamp($"createTime", "yyyy-MM-dd HH:mm:ss") - ) - ) * 1000).alias("createTime") + itemsRDD + .withColumn( + "items", + F.expr( + "transform(items, array -> concat('[', concat_ws(',', array), ']'))" + ) + ) + .withColumn( + "items", + F.concat(lit("["), F.concat_ws(",", col("items")), lit("]")) ) + } - val OwnRDD = spark.read - .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") - .option("header", "true") - .option("delimiter", "|") - .load(s"${args.outputDir}/snapshot/PersonOwnAccount.csv") - .select($"personId", $"accountId") + def bucketAndCount( + df: DataFrame, + idCol: String, + amountCol: String, + groupCol: String + ): 
DataFrame = { + + val buckets = Array(10000, 30000, 100000, 300000, 1000000, 2000000, 3000000, + 4000000, 5000000, 6000000, 7000000, 8000000, 9000000, 10000000) + + val bucketedRDD = df.withColumn( + "bucket", + when(col(amountCol) <= buckets(0), buckets(0)) + .when(col(amountCol) <= buckets(1), buckets(1)) + .when(col(amountCol) <= buckets(2), buckets(2)) + .when(col(amountCol) <= buckets(3), buckets(3)) + .when(col(amountCol) <= buckets(4), buckets(4)) + .when(col(amountCol) <= buckets(5), buckets(5)) + .when(col(amountCol) <= buckets(6), buckets(6)) + .when(col(amountCol) <= buckets(7), buckets(7)) + .when(col(amountCol) <= buckets(8), buckets(8)) + .when(col(amountCol) <= buckets(9), buckets(9)) + .when(col(amountCol) <= buckets(10), buckets(10)) + .when(col(amountCol) <= buckets(11), buckets(11)) + .when(col(amountCol) <= buckets(12), buckets(12)) + .otherwise(buckets(13)) + ) - val personGuaranteeRDD = spark.read - .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") - .option("header", "true") - .option("delimiter", "|") - .load(s"${args.outputDir}/snapshot/PersonGuaranteePerson.csv") - .select( - $"fromId", - $"toId", - (unix_timestamp( - coalesce( - to_timestamp($"createTime", "yyyy-MM-dd HH:mm:ss.SSS"), - to_timestamp($"createTime", "yyyy-MM-dd HH:mm:ss") - ) - ) * 1000).alias("createTime") + bucketedRDD + .groupBy(idCol) + .pivot("bucket", buckets.map(_.toString)) + .count() + .na + .fill(0) + .withColumnRenamed(idCol, groupCol) + } + + def processByMonth( + df: DataFrame, + idCol: String, + timeCol: String, + newIdColName: String + ): DataFrame = { + val byMonthRDD = df + .withColumn( + "year_month", + date_format((col(timeCol) / 1000).cast("timestamp"), "yyyy-MM") ) + .groupBy(idCol, "year_month") + .count() - def transformItems( - df: DataFrame, - groupByCol: String, - selectCol: String - ): DataFrame = { - val itemAmountRDD = df - .groupBy(groupByCol, selectCol) - .agg( - max($"amount").alias("maxAmount"), - max($"createTime").alias("maxCreateTime") - ) + val timestampedByMonthRDD = byMonthRDD + .withColumn( + "year_month_ts", + unix_timestamp(col("year_month"), "yyyy-MM") * 1000 + ) + .drop("year_month") - val itemsRDD = itemAmountRDD - .groupBy(groupByCol) - .agg(collect_list(array(col(selectCol), $"maxAmount", $"maxCreateTime")).alias("items")) + val pivotRDD = timestampedByMonthRDD + .groupBy(idCol) + .pivot("year_month_ts") + .agg(first("count")) + .na + .fill(0) + .withColumnRenamed(idCol, newIdColName) - val accountItemsRDD = itemsRDD - .withColumn( - "items", - F.expr( - "transform(items, array -> concat('[', concat_ws(',', array), ']'))" - ) - ) - .withColumn( - "items", - F.concat(lit("["), F.concat_ws(",", $"items"), lit("]")) - ) + pivotRDD + } - accountItemsRDD - } + def factortables(args: Args)(implicit spark: SparkSession) = { +// import spark.implicits._ + log.info("[Main] Starting factoring stage") - def bucketAndCount( - df: DataFrame, - idCol: String, - amountCol: String, - groupCol: String - ): DataFrame = { - - val buckets = Array(10000, 30000, 100000, 300000, 1000000, 2000000, - 3000000, 4000000, 5000000, 6000000, 7000000, 8000000, 9000000, 10000000) - - val bucketedRDD = df.withColumn( - "bucket", - when(col(amountCol) <= buckets(0), buckets(0)) - .when(col(amountCol) <= buckets(1), buckets(1)) - .when(col(amountCol) <= buckets(2), buckets(2)) - .when(col(amountCol) <= buckets(3), buckets(3)) - .when(col(amountCol) <= buckets(4), buckets(4)) - .when(col(amountCol) <= buckets(5), buckets(5)) - .when(col(amountCol) <= buckets(6), 
buckets(6)) - .when(col(amountCol) <= buckets(7), buckets(7)) - .when(col(amountCol) <= buckets(8), buckets(8)) - .when(col(amountCol) <= buckets(9), buckets(9)) - .when(col(amountCol) <= buckets(10), buckets(10)) - .when(col(amountCol) <= buckets(11), buckets(11)) - .when(col(amountCol) <= buckets(12), buckets(12)) - .otherwise(buckets(13)) + val transferRDD = readCSV(s"${args.outputDir}/raw/transfer/*.csv") + .select( + col("fromId"), + col("toId"), + col("amount").cast("double"), + col("createTime") ) - val bucketCountsRDD = bucketedRDD - .groupBy(idCol) - .pivot("bucket", buckets.map(_.toString)) - .count() - .na - .fill(0) - .withColumnRenamed(idCol, groupCol) - - bucketCountsRDD - } + val withdrawRDD = readCSV(s"${args.outputDir}/raw/withdraw/*.csv") + .select( + col("fromId"), + col("toId"), + col("amount").cast("double"), + col("createTime") + ) - def processByMonth( - df: DataFrame, - idCol: String, - timeCol: String, - newIdColName: String - ): DataFrame = { - val byMonthRDD = df - .withColumn( - "year_month", - date_format((col(timeCol) / 1000).cast("timestamp"), "yyyy-MM") - ) - .groupBy(idCol, "year_month") - .count() + val depositRDD = readCSV(s"${args.outputDir}/raw/deposit/*.csv") + .select(col("accountId"), col("loanId")) - val timestampedByMonthRDD = byMonthRDD - .withColumn( - "year_month_ts", - unix_timestamp(col("year_month"), "yyyy-MM") * 1000 - ) - .drop("year_month") + val personInvestRDD = readCSV(s"${args.outputDir}/raw/personInvest/*.csv") + .select(col("investorId"), col("companyId"), col("createTime")) - val pivotRDD = timestampedByMonthRDD - .groupBy(idCol) - .pivot("year_month_ts") - .agg(first("count")) - .na - .fill(0) - .withColumnRenamed(idCol, newIdColName) + val ownRDD = readCSV(s"${args.outputDir}/raw/personOwnAccount/*.csv") + .select(col("personId"), col("accountId")) - pivotRDD - } + val personGuaranteeRDD = + readCSV(s"${args.outputDir}/raw/personGuarantee/*.csv") + .select(col("fromId"), col("toId"), col("createTime")) val PersonInvestCompanyRDD = personInvestRDD .groupBy("investorId") @@ -267,7 +222,12 @@ object FactorGenerationStage extends DatagenStage { .save(s"${args.outputDir}/factor_table/person_invest_company") val transferItRDD = - transferRDD.select($"fromId", $"toId", $"amount".cast("double"), $"createTime") + transferRDD.select( + col("fromId"), + col("toId"), + col("amount").cast("double"), + col("createTime") + ) val transferOutAccountItemsRDD = transformItems(transferItRDD, "fromId", "toId") @@ -292,14 +252,14 @@ object FactorGenerationStage extends DatagenStage { .save(s"${args.outputDir}/factor_table/account_transfer_in_items") val transferOutLRDD = transferItRDD - .select($"fromId", $"toId") - .groupBy($"fromId") - .agg(F.collect_list($"toId").alias("transfer_out_list")) - .select($"fromId".alias("account_id"), $"transfer_out_list") + .select(col("fromId"), col("toId")) + .groupBy(col("fromId")) + .agg(F.collect_list(col("toId")).alias("transfer_out_list")) + .select(col("fromId").alias("account_id"), col("transfer_out_list")) val transferOutListRDD = transferOutLRDD.withColumn( "transfer_out_list", - F.concat(lit("["), F.concat_ws(",", $"transfer_out_list"), lit("]")) + F.concat(lit("["), F.concat_ws(",", col("transfer_out_list")), lit("]")) ) transferOutListRDD @@ -311,14 +271,14 @@ object FactorGenerationStage extends DatagenStage { .save(s"${args.outputDir}/factor_table/account_transfer_out_list") val fromAccounts = transferRDD.select( - $"fromId".alias("account_id"), - $"toId".alias("corresponding_account_id"), - 
$"createTime" + col("fromId").alias("account_id"), + col("toId").alias("corresponding_account_id"), + col("createTime") ) val toAccounts = transferRDD.select( - $"toId".alias("account_id"), - $"fromId".alias("corresponding_account_id"), - $"createTime" + col("toId").alias("account_id"), + col("fromId").alias("corresponding_account_id"), + col("createTime") ) val allAccounts = fromAccounts.union(toAccounts) @@ -328,7 +288,7 @@ object FactorGenerationStage extends DatagenStage { .agg(collect_set("corresponding_account_id").alias("account_list")) val accountCountDF = accountListRDD - .select($"account_id", F.size($"account_list").alias("sum")) + .select(col("account_id"), F.size(col("account_list")).alias("sum")) accountCountDF.write .option("header", "true") @@ -338,7 +298,7 @@ object FactorGenerationStage extends DatagenStage { val transferAccountListRDD = accountListRDD.withColumn( "account_list", - F.concat(lit("["), F.concat_ws(",", $"account_list"), lit("]")) + F.concat(lit("["), F.concat_ws(",", col("account_list")), lit("]")) ) transferAccountListRDD @@ -428,7 +388,12 @@ object FactorGenerationStage extends DatagenStage { .save(s"${args.outputDir}/factor_table/trans_withdraw_month") val withdrawInRDD = - withdrawRDD.select($"fromId", $"toId", $"amount".cast("double"), $"createTime") + withdrawRDD.select( + col("fromId"), + col("toId"), + col("amount").cast("double"), + col("createTime") + ) val combinedRDD = transferItRDD.union(withdrawInRDD) val transformedAccountItemsRDD = @@ -453,15 +418,15 @@ object FactorGenerationStage extends DatagenStage { .save(s"${args.outputDir}/factor_table/account_withdraw_in_items") val transferOutAmountRDD = - transferItRDD.select($"fromId", $"amount".cast("double")) + transferItRDD.select(col("fromId"), col("amount").cast("double")) val transferInAmountRDD = - transferItRDD.select($"toId", $"amount".cast("double")) + transferItRDD.select(col("toId"), col("amount").cast("double")) val transactionsAmountRDD = transferOutAmountRDD .union(withdrawRDD.select(col("fromId"), col("amount").cast("double"))) val withdrawInBucketAmountRDD = - withdrawInRDD.select($"toId", $"amount".cast("double")) + withdrawInRDD.select(col("toId"), col("amount").cast("double")) val bucketCountsRDD = bucketAndCount(transactionsAmountRDD, "fromId", "amount", "account_id") @@ -498,8 +463,8 @@ object FactorGenerationStage extends DatagenStage { .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") .save(s"${args.outputDir}/factor_table/transfer_out_bucket") - val PersonOwnAccountRDD = OwnRDD - .select($"personId", $"accountId") + val PersonOwnAccountRDD = ownRDD + .select(col("personId"), col("accountId")) .groupBy("personId") .agg(coalesce(collect_set("accountId"), array()).alias("account_list")) .select( @@ -515,17 +480,17 @@ object FactorGenerationStage extends DatagenStage { .save(s"${args.outputDir}/factor_table/person_account_list") val PersonGuaranteeListRDD = personGuaranteeRDD - .select($"fromId", $"toId") + .select(col("fromId"), col("toId")) .groupBy("fromId") .agg(coalesce(collect_set("toId"), array()).alias("guaranteee_list")) val PersonGuaranteePersonRDD = PersonGuaranteeListRDD.withColumn( "guaranteee_list", - F.concat(lit("["), F.concat_ws(",", $"guaranteee_list"), lit("]")) + F.concat(lit("["), F.concat_ws(",", col("guaranteee_list")), lit("]")) ) val PersonGuaranteeCount = PersonGuaranteeListRDD - .select($"fromId", F.size($"guaranteee_list").alias("sum")) + .select(col("fromId"), F.size(col("guaranteee_list")).alias("sum")) 
PersonGuaranteePersonRDD.write .option("header", "true") diff --git a/src/main/scala/ldbc/finbench/datagen/generation/serializers/ActivitySerializer.scala b/src/main/scala/ldbc/finbench/datagen/generation/ActivitySerializer.scala similarity index 99% rename from src/main/scala/ldbc/finbench/datagen/generation/serializers/ActivitySerializer.scala rename to src/main/scala/ldbc/finbench/datagen/generation/ActivitySerializer.scala index d69b32ed..1416967e 100644 --- a/src/main/scala/ldbc/finbench/datagen/generation/serializers/ActivitySerializer.scala +++ b/src/main/scala/ldbc/finbench/datagen/generation/ActivitySerializer.scala @@ -1,4 +1,4 @@ -package ldbc.finbench.datagen.generation.serializers +package ldbc.finbench.datagen.generation import ldbc.finbench.datagen.entities.edges._ import ldbc.finbench.datagen.entities.nodes._ diff --git a/src/main/scala/ldbc/finbench/datagen/generation/ActivitySimulator.scala b/src/main/scala/ldbc/finbench/datagen/generation/ActivitySimulator.scala index c3b4fc7f..bbcea2b9 100644 --- a/src/main/scala/ldbc/finbench/datagen/generation/ActivitySimulator.scala +++ b/src/main/scala/ldbc/finbench/datagen/generation/ActivitySimulator.scala @@ -3,7 +3,6 @@ package ldbc.finbench.datagen.generation import ldbc.finbench.datagen.config.DatagenConfiguration import ldbc.finbench.datagen.entities.nodes._ import ldbc.finbench.datagen.generation.generators.{ActivityGenerator, SparkCompanyGenerator, SparkMediumGenerator, SparkPersonGenerator} -import ldbc.finbench.datagen.generation.serializers.ActivitySerializer import ldbc.finbench.datagen.io.Writer import ldbc.finbench.datagen.io.raw.RawSink import ldbc.finbench.datagen.util.Logging diff --git a/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala b/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala index 2e0810b0..185bf5ac 100644 --- a/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala +++ b/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala @@ -82,12 +82,9 @@ class ActivityGenerator()(implicit spark: SparkSession) personRDD: RDD[Person], companyRDD: RDD[Company] ): RDD[Company] = { - val persons = spark.sparkContext.broadcast(personRDD.collect().toList) - val companies = spark.sparkContext.broadcast(companyRDD.collect().toList) - - val personInvestEvent = new PersonInvestEvent() - val companyInvestEvent = new CompanyInvestEvent() + val investActivitesEvent = new InvestActivitesEvent + val numPartitions = companyRDD.getNumPartitions companyRDD .sample( withReplacement = false, @@ -95,17 +92,30 @@ class ActivityGenerator()(implicit spark: SparkSession) sampleRandom.nextLong() ) .mapPartitionsWithIndex { (index, targets) => - personInvestEvent.resetState(index) - personInvestEvent - .personInvestPartition(persons.value.asJava, targets.toList.asJava) - .iterator() - .asScala - } - .mapPartitionsWithIndex { (index, targets) => - companyInvestEvent.resetState(index) - companyInvestEvent - .companyInvestPartition( - companies.value.asJava, + val persons = personRDD + .sample( + withReplacement = false, + 1.0 / numPartitions, + sampleRandom.nextLong() + ) + .collect() + .toList + .asJava + val companies = companyRDD + .sample( + withReplacement = false, + 1.0 / numPartitions, + sampleRandom.nextLong() + ) + .collect() + .toList + .asJava + + investActivitesEvent.resetState(index) + investActivitesEvent + .investPartition( + persons, + companies, targets.toList.asJava ) .iterator() @@ -118,23 +128,22 @@ 
class ActivityGenerator()(implicit spark: SparkSession) mediumRDD: RDD[Medium], accountRDD: RDD[Account] ): RDD[Medium] = { - val accountSampleList = spark.sparkContext.broadcast( - accountRDD + val signInEvent = new SignInEvent + val numPartitions = mediumRDD.getNumPartitions + mediumRDD.mapPartitionsWithIndex((index, mediums) => { + val accountSampleList = accountRDD .sample( withReplacement = false, - DatagenParams.accountSignedInFraction, + DatagenParams.accountSignedInFraction / numPartitions, sampleRandom.nextLong() ) .collect() .toList - ) - - val signInEvent = new SignInEvent - mediumRDD.mapPartitionsWithIndex((index, mediums) => { + .asJava signInEvent .signIn( mediums.toList.asJava, - accountSampleList.value.asJava, + accountSampleList, index ) .iterator() @@ -164,23 +173,23 @@ class ActivityGenerator()(implicit spark: SparkSession) loanRDD: RDD[Loan], accountRDD: RDD[Account] ): (RDD[Loan]) = { - val sampledAccounts = spark.sparkContext.broadcast( - accountRDD + val numPartition = loanRDD.getNumPartitions + loanRDD.mapPartitionsWithIndex((index, loans) => { + val sampledAccounts = accountRDD .sample( withReplacement = false, - DatagenParams.loanInvolvedAccountsFraction, + DatagenParams.loanInvolvedAccountsFraction / numPartition, sampleRandom.nextLong() ) .collect() .toList - ) + .asJava - loanRDD.mapPartitionsWithIndex((index, loans) => { val loanSubEvents = new LoanActivitiesEvents loanSubEvents .afterLoanApplied( loans.toList.asJava, - sampledAccounts.value.asJava, + sampledAccounts, index ) .iterator() diff --git a/src/main/scala/ldbc/finbench/datagen/transformation/TransformationStage.scala b/src/main/scala/ldbc/finbench/datagen/transformation/TransformationStage.scala deleted file mode 100644 index d58dc6e8..00000000 --- a/src/main/scala/ldbc/finbench/datagen/transformation/TransformationStage.scala +++ /dev/null @@ -1,204 +0,0 @@ -package ldbc.finbench.datagen.transformation - -import ldbc.finbench.datagen.generation.DatagenParams -import ldbc.finbench.datagen.generation.dictionary.Dictionaries -import ldbc.finbench.datagen.util.sql.qcol -import ldbc.finbench.datagen.util.{DatagenStage, Logging} -import ldbc.finbench.datagen.syntax._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{ - col, - date_format, - date_trunc, - from_unixtime, - lit, - to_timestamp -} -import scopt.OptionParser -import shapeless.lens - -// Note: transformation is not used now. Data conversion is done by python scripts. -object TransformationStage extends DatagenStage with Logging { - private val options: Map[String, String] = - Map("header" -> "true", "delimiter" -> "|") - - case class Args( - outputDir: String = "out", - bulkloadPortion: Double = 0.0, - keepImplicitDeletes: Boolean = false, - simulationStart: Long = 0, - simulationEnd: Long = 0, - irFormat: String = "csv", - format: String = "csv", - formatOptions: Map[String, String] = Map.empty, - epochMillis: Boolean = false, - batchPeriod: String = "day" - ) - - override type ArgsType = Args - - def main(args: Array[String]): Unit = { - val parser = new OptionParser[Args](getClass.getName.dropRight(1)) { - head(appName) - - val args = lens[Args] - - opt[String]('o', "output-dir") - .action((x, c) => args.outputDir.set(c)(x)) - .text( - "path on the cluster filesystem, where Datagen outputs. Can be a URI (e.g S3, ADLS, HDFS) or a " + - "path in which case the default cluster file system is used." 
- ) - - opt[String]("ir-format") - .action((x, c) => args.irFormat.set(c)(x)) - .text("Format of the raw input") - - opt[String]("format") - .action((x, c) => args.format.set(c)(x)) - .text("Output format") - - help('h', "help").text("prints this usage text") - } - val parsedArgs = - parser - .parse(args, Args()) - .getOrElse(throw new RuntimeException("Invalid arguments")) - - run(parsedArgs) - } - - // execute the transform process - override def run(args: Args): Unit = { - - val rawPathPrefix = args.outputDir / "raw" - val outputPathPrefix = args.outputDir / "history_data" - - val filterDeletion = false - - val simulationStart = Dictionaries.dates.getSimulationStart - val simulationEnd = Dictionaries.dates.getSimulationEnd - val bulkLoadThreshold = calculateBulkLoadThreshold( - args.bulkloadPortion, - simulationStart, - simulationEnd - ) - - // val batch_id = (col: Column) => date_format(date_trunc(args.batchPeriod, to_timestamp(col / lit(1000L))), batchPeriodFormat(args.batchPeriod)) - // - // def inBatch(col: Column, batchStart: Long, batchEnd: Long) = - // col >= lit(batchStart) && col < lit(batchEnd) - // - // val batched = (df: DataFrame) => - // df - // .select( - // df.columns.map(qcol) ++ Seq( - // batch_id($"creationDate").as("insert_batch_id"), - // batch_id($"deletionDate").as("delete_batch_id") - // ): _* - // ) - // - // val insertBatchPart = (tpe: EntityType, df: DataFrame, batchStart: Long, batchEnd: Long) => { - // df - // .filter(inBatch($"creationDate", batchStart, batchEnd)) - // .pipe(batched) - // .select( - // Seq($"insert_batch_id".as("batch_id")) ++ columns(tpe, df.columns).map(qcol): _* - // ) - // } - // - // val deleteBatchPart = (tpe: EntityType, df: DataFrame, batchStart: Long, batchEnd: Long) => { - // val idColumns = tpe.primaryKey.map(qcol) - // df - // .filter(inBatch($"deletionDate", batchStart, batchEnd)) - // .filter(if (df.columns.contains("explicitlyDeleted")) col("explicitlyDeleted") else lit(true)) - // .pipe(batched) - // .select(Seq($"delete_batch_id".as("batch_id"), $"deletionDate") ++ idColumns: _*) - // } - - val readRaw = (target: String) => { - spark.read - .format(args.irFormat) - .options(options) - .option("inferSchema", "true") - .load(s"$rawPathPrefix/$target/*.csv") - } - - val extractSnapshot = (df: DataFrame) => { - df.filter( - $"creationDate" < lit(bulkLoadThreshold) - && (!lit(filterDeletion) || $"deletionDate" >= lit(bulkLoadThreshold)) - ) - // .select(_: _*) - } - - val transferSnapshot = extractSnapshot(readRaw("transfer")) - // .select("fromId", "toId", "multiplicityId", "createTime", "deleteTime", "amount", "isExplicitDeleted") - // .map(extractSnapshot) - .withColumn( - "createTime", - from_unixtime( - col("createTime") / 1000, - batchPeriodFormat(args.batchPeriod) - ) - ) - .withColumn( - "deleteTime", - from_unixtime( - col("deleteTime") / 1000, - batchPeriodFormat(args.batchPeriod) - ) - ) - .orderBy("createTime", "deleteTime") - write(transferSnapshot, (outputPathPrefix / "transfer").toString) - - // val accountDf = spark.read.format("csv") - // .option("header", "true") - // .option("delimiter", "|") - // .load("./out/account/part-00000-4b0e57cb-23bb-447f-89f1-e7e71a4ee017-c000.csv") - // - // transferDf.join(accountDf, transferDf("fromId") === accountDf("id"), "left") - // .select() - } - - // def columns(tpe: EntityType, cols: Seq[String]) = tpe match { - // case tpe if tpe.isStatic => cols - // case Edge("Knows", PersonType, PersonType, NN, false, _, _) => - // val rawCols = Set("deletionDate", "explicitlyDeleted", 
"weight") - // cols.filter(!rawCols.contains(_)) - // case _ => - // val rawCols = Set("deletionDate", "explicitlyDeleted") - // cols.filter(!rawCols.contains(_)) - // } - private def write(data: DataFrame, path: String): Unit = { - data - .toDF() - .coalesce(1) - .write - .format("csv") - .options(options) - .option("encoding", "UTF-8") - .mode("overwrite") - .save(path) - } - - private def calculateBulkLoadThreshold( - bulkLoadPortion: Double, - simulationStart: Long, - simulationEnd: Long - ) = { - (simulationEnd - ((simulationEnd - simulationStart) * (1 - bulkLoadPortion)).toLong) - } - - private def batchPeriodFormat(batchPeriod: String) = batchPeriod match { - case "year" => "yyyy" - case "month" => "yyyy-MM" - case "day" => "yyyy-MM-dd" - case "hour" => "yyyy-MM-dd'T'hh" - case "minute" => "yyyy-MM-dd'T'hh:mm" - case "second" => "yyyy-MM-dd'T'hh:mm:ss" - case "millisecond" => "yyyy-MM-dd'T'hh:mm:ss.SSS" - case _ => throw new IllegalArgumentException("Unrecognized partition key") - } -}