From 39d2fdb51233ed9b1aaf3adaa3267853f5e58c0f Mon Sep 17 00:00:00 2001 From: frreiss Date: Tue, 1 Nov 2016 23:00:17 -0700 Subject: [PATCH 001/534] [SPARK-17475][STREAMING] Delete CRC files if the filesystem doesn't use checksum files ## What changes were proposed in this pull request? When the metadata logs for various parts of Structured Streaming are stored on non-HDFS filesystems such as NFS or ext4, the HDFSMetadataLog class leaves hidden HDFS-style checksum (CRC) files in the log directory, one file per batch. This PR modifies HDFSMetadataLog so that it detects the use of a filesystem that doesn't use CRC files and removes the CRC files. ## How was this patch tested? Modified an existing test case in HDFSMetadataLogSuite to check whether HDFSMetadataLog correctly removes CRC files on the local POSIX filesystem. Ran the entire regression suite. Author: frreiss Closes #15027 from frreiss/fred-17475. (cherry picked from commit 620da3b4828b3580c7ed7339b2a07938e6be1bb1) Signed-off-by: Reynold Xin --- .../spark/sql/execution/streaming/HDFSMetadataLog.scala | 5 +++++ .../sql/execution/streaming/HDFSMetadataLogSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index c7235320fd6bd..9a0f87cf0498c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -148,6 +148,11 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) // It will fail if there is an existing file (someone has committed the batch) logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") fileManager.rename(tempPath, batchIdToPath(batchId)) + + // SPARK-17475: HDFSMetadataLog should not leak CRC files + // If the underlying filesystem didn't rename the CRC file, delete it. + val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") + if (fileManager.exists(crcPath)) fileManager.delete(crcPath) return } catch { case e: IOException if isFileAlreadyExistsException(e) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 9c1d26dcb2241..d03e08d9a576c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -119,6 +119,12 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { assert(metadataLog.get(1).isEmpty) assert(metadataLog.get(2).isDefined) assert(metadataLog.getLatest().get._1 == 2) + + // There should be exactly one file, called "2", in the metadata directory. + // This check also tests for regressions of SPARK-17475 + val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + assert(allFiles.size == 1) + assert(allFiles(0).getName() == "2") } } From e6509c2459e7ece3c3c6bcd143b8cc71f8f4d5c8 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 2 Nov 2016 14:15:10 +0800 Subject: [PATCH 002/534] [SPARK-18183][SPARK-18184] Fix INSERT [INTO|OVERWRITE] TABLE ... PARTITION for Datasource tables There are a couple issues with the current 2.1 behavior when inserting into Datasource tables with partitions managed by Hive. 
(1) OVERWRITE TABLE ... PARTITION will actually overwrite the entire table instead of just the specified partition. (2) INSERT|OVERWRITE does not work with partitions that have custom locations. This PR fixes both of these issues for Datasource tables managed by Hive. The behavior for legacy tables or when `manageFilesourcePartitions = false` is unchanged. There is one other issue in that INSERT OVERWRITE with dynamic partitions will overwrite the entire table instead of just the updated partitions, but this behavior is pretty complicated to implement for Datasource tables. We should address that in a future release. Unit tests. Author: Eric Liang Closes #15705 from ericl/sc-4942. (cherry picked from commit abefe2ec428dc24a4112c623fb6fbe4b2ca60a2b) Signed-off-by: Reynold Xin --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 9 +++- .../plans/logical/basicLogicalOperators.scala | 19 ++++++- .../sql/catalyst/parser/PlanParserSuite.scala | 15 ++++-- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../datasources/CatalogFileIndex.scala | 5 +- .../datasources/DataSourceStrategy.scala | 30 +++++++++-- .../InsertIntoDataSourceCommand.scala | 6 +-- .../spark/sql/hive/HiveStrategies.scala | 3 +- .../CreateHiveTableAsSelectCommand.scala | 5 +- .../PartitionProviderCompatibilitySuite.scala | 52 +++++++++++++++++++ 11 files changed, 129 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 66e52ca68af19..e901683be6854 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -367,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, false) + Map.empty, logicalPlan, OverwriteOptions(overwrite), false) def as(alias: String): LogicalPlan = logicalPlan match { case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 38e9bb6c162ad..ac1577b3abb4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -177,12 +177,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. 
Specified " + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) } + val overwrite = ctx.OVERWRITE != null + val overwritePartition = + if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) { + Some(partitionKeys.map(t => (t._1, t._2.get))) + } else { + None + } InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, query, - ctx.OVERWRITE != null, + OverwriteOptions(overwrite, overwritePartition), ctx.EXISTS != null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a48974c6322ad..7a15c2285d584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.catalog.CatalogTypes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -345,18 +346,32 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true) } +/** + * Options for writing new data into a table. + * + * @param enabled whether to overwrite existing data in the table. + * @param specificPartition only data in the specified partition will be overwritten. + */ +case class OverwriteOptions( + enabled: Boolean, + specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) { + if (specificPartition.isDefined) { + assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.") + } +} + case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean, + overwrite: OverwriteOptions, ifNotExists: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty - assert(overwrite || !ifNotExists) + assert(overwrite.enabled || !ifNotExists) assert(partition.values.forall(_.nonEmpty) || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && table.resolved diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index ca86304d4d400..7400f3430e99c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -180,7 +180,16 @@ class PlanParserSuite extends PlanTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + InsertIntoTable( + table("s"), partition, plan, + OverwriteOptions( + overwrite, + if (overwrite && partition.nonEmpty) { + Some(partition.map(kv => (kv._1, kv._2.get))) + } else { + None + }), + ifNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -196,9 +205,9 @@ class PlanParserSuite extends PlanTest { val 
plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), OverwriteOptions(false), ifNotExists = false).union( InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + table("u"), Map.empty, plan2, OverwriteOptions(false), ifNotExists = false))) } test ("insert with if not exists") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 11dd1df909938..700f4835ac89a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Union} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, OverwriteOptions, Union} import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, CreateTable, DataSource, HadoopFsRelation} import org.apache.spark.sql.types.StructType @@ -259,7 +259,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], child = df.logicalPlan, - overwrite = mode == SaveMode.Overwrite, + overwrite = OverwriteOptions(mode == SaveMode.Overwrite), ifNotExists = false)).toRdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 092aabc89a36c..443a2ec033a98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -67,7 +67,10 @@ class CatalogFileIndex( val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => - PartitionPath(p.toRow(partitionSchema), p.storage.locationUri.get) + val path = new Path(p.storage.locationUri.get) + val fs = path.getFileSystem(hadoopConf) + PartitionPath( + p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) new PrunedInMemoryFileIndex( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 34b77cab65def..47c1f9d3fac1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.fs.Path + import org.apache.spark.internal.Logging import 
org.apache.spark.rdd.RDD import org.apache.spark.sql._ @@ -174,14 +176,32 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths }.flatten - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - if (overwrite && inputPaths.contains(outputPath)) { + val mode = if (overwrite.enabled) SaveMode.Overwrite else SaveMode.Append + if (overwrite.enabled && inputPaths.contains(outputPath)) { throw new AnalysisException( "Cannot overwrite a path that is also being read from.") } + val overwritingSinglePartition = (overwrite.specificPartition.isDefined && + t.sparkSession.sessionState.conf.manageFilesourcePartitions && + l.catalogTable.get.partitionProviderIsHive) + + val effectiveOutputPath = if (overwritingSinglePartition) { + val partition = t.sparkSession.sessionState.catalog.getPartition( + l.catalogTable.get.identifier, overwrite.specificPartition.get) + new Path(partition.storage.locationUri.get) + } else { + outputPath + } + + val effectivePartitionSchema = if (overwritingSinglePartition) { + Nil + } else { + query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + } + def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { - if (l.catalogTable.isDefined && + if (l.catalogTable.isDefined && updatedPartitions.nonEmpty && l.catalogTable.get.partitionColumnNames.nonEmpty && l.catalogTable.get.partitionProviderIsHive) { val metastoreUpdater = AlterTableAddPartitionCommand( @@ -194,8 +214,8 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { } val insertCmd = InsertIntoHadoopFsRelationCommand( - outputPath, - query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), + effectiveOutputPath, + effectivePartitionSchema, t.bucketSpec, t.fileFormat, refreshPartitionsCallback, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index b2ff68a833fea..2eba1e9986acd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OverwriteOptions} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.sources.InsertableRelation @@ -30,7 +30,7 @@ import org.apache.spark.sql.sources.InsertableRelation case class InsertIntoDataSourceCommand( logicalRelation: LogicalRelation, query: LogicalPlan, - overwrite: Boolean) + overwrite: OverwriteOptions) extends RunnableCommand { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) @@ -40,7 +40,7 @@ case class InsertIntoDataSourceCommand( val data = Dataset.ofRows(sparkSession, query) // Apply the schema of the existing table to the new data. val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + relation.insert(df, overwrite.enabled) // Invalidate the cache. 
sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9d2930948d6ba..ce1e3eb1a5bc9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -46,7 +46,8 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable( table: MetastoreRelation, partition, child, overwrite, ifNotExists) => - InsertIntoHiveTable(table, partition, planLater(child), overwrite, ifNotExists) :: Nil + InsertIntoHiveTable( + table, partition, planLater(child), overwrite.enabled, ifNotExists) :: Nil case CreateTable(tableDesc, mode, Some(query)) if tableDesc.provider.get == "hive" => val newTableDesc = if (tableDesc.storage.serde.isEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index ef5a5a001fb6f..cac43597aef21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -21,7 +21,7 @@ import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, OverwriteOptions} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.hive.MetastoreRelation @@ -88,7 +88,8 @@ case class CreateHiveTableAsSelectCommand( } else { try { sparkSession.sessionState.executePlan(InsertIntoTable( - metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd + metastoreRelation, Map(), query, overwrite = OverwriteOptions(true), + ifNotExists = false)).toRdd } catch { case NonFatal(e) => // drop the created table. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 5f16960fb1496..ac435bf6195b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -134,4 +134,56 @@ class PartitionProviderCompatibilitySuite } } } + + test("insert overwrite partition of legacy datasource table overwrites entire table") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir) + spark.sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(100)""".stripMargin) + assert(spark.sql("select * from test").count() == 100) + + // Dynamic partitions case + spark.sql("insert overwrite table test select id, id from range(10)".stripMargin) + assert(spark.sql("select * from test").count() == 10) + } + } + } + } + + test("insert overwrite partition of new datasource table overwrites just partition") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir) + sql("msck repair table test") + spark.sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(100)""".stripMargin) + assert(spark.sql("select * from test").count() == 104) + + // Test overwriting a partition that has a custom location + withTempDir { dir2 => + sql( + s"""alter table test partition (partCol=1) + |set location '${dir2.getAbsolutePath}'""".stripMargin) + assert(sql("select * from test").count() == 4) + sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(30)""".stripMargin) + sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(20)""".stripMargin) + assert(sql("select * from test").count() == 24) + } + } + } + } + } } From 85dd073743946383438aabb9f1281e6075f25cc5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 1 Nov 2016 23:37:03 -0700 Subject: [PATCH 003/534] [SPARK-18192] Support all file formats in structured streaming ## What changes were proposed in this pull request? This patch adds support for all file formats in structured streaming sinks. This is actually a very small change thanks to all the previous refactoring done using the new internal commit protocol API. ## How was this patch tested? Updated FileStreamSinkSuite to add test cases for json, text, and parquet. Author: Reynold Xin Closes #15711 from rxin/SPARK-18192. 
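As an illustration of what this enables (not code from the patch itself — the paths and app name below are placeholders), a streaming query can now target any `FileFormat`-based sink rather than only Parquet:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: stream text files in and write them back out as JSON.
// Before this change, only "parquet" was accepted as a file-based streaming sink format.
val spark = SparkSession.builder().appName("file-sink-formats-sketch").getOrCreate()

val query = spark.readStream
  .format("text")
  .load("/tmp/streaming/input")
  .writeStream
  .format("json")
  .option("checkpointLocation", "/tmp/streaming/checkpoint")
  .start("/tmp/streaming/output")

query.awaitTermination()
```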
(cherry picked from commit a36653c5b7b2719f8bfddf4ddfc6e1b828ac9af1) Signed-off-by: Reynold Xin --- .../execution/datasources/DataSource.scala | 8 +-- .../sql/streaming/FileStreamSinkSuite.scala | 62 +++++++++---------- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d980e6a15aabe..3f956c427655e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -29,7 +29,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -37,7 +36,6 @@ import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} @@ -292,7 +290,7 @@ case class DataSource( case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode) - case parquet: parquet.ParquetFileFormat => + case fileFormat: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") @@ -301,7 +299,7 @@ case class DataSource( throw new IllegalArgumentException( s"Data source $className does not support $outputMode output mode") } - new FileStreamSink(sparkSession, path, parquet, partitionColumns, options) + new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, options) case _ => throw new UnsupportedOperationException( @@ -516,7 +514,7 @@ case class DataSource( val plan = data.logicalPlan plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { throw new AnalysisException( - s"Unable to resolve ${name} given [${plan.output.map(_.name).mkString(", ")}]") + s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") }.asInstanceOf[Attribute] } // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 902cf05344716..0f140f94f630e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql._ +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} @@ 
-142,42 +142,38 @@ class FileStreamSinkSuite extends StreamTest { } } - test("FileStreamSink - supported formats") { - def testFormat(format: Option[String]): Unit = { - val inputData = MemoryStream[Int] - val ds = inputData.toDS() + test("FileStreamSink - parquet") { + testFormat(None) // should not throw error as default format parquet when not specified + testFormat(Some("parquet")) + } - val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath - val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + test("FileStreamSink - text") { + testFormat(Some("text")) + } - var query: StreamingQuery = null + test("FileStreamSink - json") { + testFormat(Some("text")) + } - try { - val writer = - ds.map(i => (i, i * 1000)) - .toDF("id", "value") - .writeStream - if (format.nonEmpty) { - writer.format(format.get) - } - query = writer - .option("checkpointLocation", checkpointDir) - .start(outputDir) - } finally { - if (query != null) { - query.stop() - } - } - } + def testFormat(format: Option[String]): Unit = { + val inputData = MemoryStream[Int] + val ds = inputData.toDS() - testFormat(None) // should not throw error as default format parquet when not specified - testFormat(Some("parquet")) - val e = intercept[UnsupportedOperationException] { - testFormat(Some("text")) - } - Seq("text", "not support", "stream").foreach { s => - assert(e.getMessage.contains(s)) + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + var query: StreamingQuery = null + + try { + val writer = ds.map(i => (i, i * 1000)).toDF("id", "value").writeStream + if (format.nonEmpty) { + writer.format(format.get) + } + query = writer.option("checkpointLocation", checkpointDir).start(outputDir) + } finally { + if (query != null) { + query.stop() + } } } - } From 4c4bf87acf2516a72b59f4e760413f80640dca1e Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 1 Nov 2016 23:39:53 -0700 Subject: [PATCH 004/534] [SPARK-18144][SQL] logging StreamingQueryListener$QueryStartedEvent ## What changes were proposed in this pull request? The PR fixes the bug that the QueryStartedEvent is not logged the postToAll() in the original code is actually calling StreamingQueryListenerBus.postToAll() which has no listener at all....we shall post by sparkListenerBus.postToAll(s) and this.postToAll() to trigger local listeners as well as the listeners registered in LiveListenerBus zsxwing ## How was this patch tested? The following snapshot shows that QueryStartedEvent has been logged correctly ![image](https://cloud.githubusercontent.com/assets/678008/19821553/007a7d28-9d2d-11e6-9f13-49851559cdaa.png) Author: CodingCat Closes #15675 from CodingCat/SPARK-18144. 
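As an illustration of the user-visible effect (the listener class below is hypothetical and mirrors the regression test added in this patch), a registered `StreamingQueryListener` should now receive `onQueryStarted` exactly once per started query:

```scala
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// Hypothetical listener that counts onQueryStarted callbacks; with the fix, the
// count should match the number of queries started on this session.
class CountingListener extends StreamingQueryListener {
  val started = new AtomicInteger(0)
  override def onQueryStarted(event: QueryStartedEvent): Unit = started.incrementAndGet()
  override def onQueryProgress(event: QueryProgressEvent): Unit = ()
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = ()
}

val spark = SparkSession.builder().appName("listener-sketch").getOrCreate()
spark.streams.addListener(new CountingListener)
```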
(cherry picked from commit 85c5424d466f4a5765c825e0e2ab30da97611285) Signed-off-by: Shixiong Zhu --- .../streaming/StreamingQueryListenerBus.scala | 10 +++++++++- .../spark/sql/streaming/StreamingQuerySuite.scala | 7 ++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index fc2190d39da4f..22e4c6380fcd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -41,6 +41,8 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) def post(event: StreamingQueryListener.Event) { event match { case s: QueryStartedEvent => + sparkListenerBus.post(s) + // post to local listeners to trigger callbacks postToAll(s) case _ => sparkListenerBus.post(event) @@ -50,7 +52,13 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case e: StreamingQueryListener.Event => - postToAll(e) + // SPARK-18144: we broadcast QueryStartedEvent to all listeners attached to this bus + // synchronously and the ones attached to LiveListenerBus asynchronously. Therefore, + // we need to ignore QueryStartedEvent if this method is called within SparkListenerBus + // thread + if (!LiveListenerBus.withinListenerThread.value || !e.isInstanceOf[QueryStartedEvent]) { + postToAll(e) + } case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 464c443beb6e7..31b7fe0b04da9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -290,7 +290,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { // A StreamingQueryListener that gets the query status after the first completed trigger val listener = new StreamingQueryListener { @volatile var firstStatus: StreamingQueryStatus = null - override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { } + @volatile var queryStartedEvent = 0 + override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { + queryStartedEvent += 1 + } override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = { if (firstStatus == null) firstStatus = queryProgress.queryStatus } @@ -303,6 +306,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { q.processAllAvailable() eventually(timeout(streamingTimeout)) { assert(listener.firstStatus != null) + // test if QueryStartedEvent callback is called for only once + assert(listener.queryStartedEvent === 1) } listener.firstStatus } finally { From 3b624bedf0f0ecd5dcfcc262a3ca8b4e33662533 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 2 Nov 2016 00:08:30 -0700 Subject: [PATCH 005/534] [SPARK-17532] Add lock debugging info to thread dumps. ## What changes were proposed in this pull request? This adds information to the web UI thread dump page about the JVM locks held by threads and the locks that threads are blocked waiting to acquire. This should help find cases where lock contention is causing Spark applications to run slowly. 
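For reference, the lock information comes from the standard `java.lang.management` API; here is a standalone sketch of the underlying calls (simplified relative to the actual change, which threads this data through `ThreadStackTrace` into the executors page):

```scala
import java.lang.management.ManagementFactory

// Dump all threads with monitor and synchronizer details, then print what each
// thread is blocked on and which locks it currently holds.
val threadInfos = ManagementFactory.getThreadMXBean
  .dumpAllThreads(true, true) // lockedMonitors = true, lockedSynchronizers = true
  .filter(_ != null)

threadInfos.sortBy(_.getThreadId).foreach { info =>
  val heldLocks = (info.getLockedMonitors.map(_.toString) ++
    info.getLockedSynchronizers.map(_.toString)).toSet
  val blockedOn = Option(info.getLockInfo).map(_.toString).getOrElse("<none>")
  println(s"${info.getThreadId} ${info.getThreadName} [${info.getThreadState}] " +
    s"blocked on: $blockedOn, holding: ${heldLocks.mkString(", ")}")
}
```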
## How was this patch tested? Tested by applying this patch and viewing the change in the web UI. ![thread-lock-info](https://cloud.githubusercontent.com/assets/87915/18493057/6e5da870-79c3-11e6-8c20-f54c18a37544.png) Additions: - A "Thread Locking" column with the locks held by the thread or that are blocking the thread - Links from the a blocked thread to the thread holding the lock - Stack frames show where threads are inside `synchronized` blocks, "holding Monitor(...)" Author: Ryan Blue Closes #15088 from rdblue/SPARK-17532-add-thread-lock-info. (cherry picked from commit 2dc048081668665f85623839d5f663b402e42555) Signed-off-by: Reynold Xin --- .../org/apache/spark/ui/static/table.js | 3 +- .../ui/exec/ExecutorThreadDumpPage.scala | 12 +++++++ .../apache/spark/util/ThreadStackTrace.scala | 6 +++- .../scala/org/apache/spark/util/Utils.scala | 34 ++++++++++++++++--- 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 14b06bfe860ed..0315ebf5c48a9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -36,7 +36,7 @@ function toggleThreadStackTrace(threadId, forceAdd) { if (stackTrace.length == 0) { var stackTraceText = $('#' + threadId + "_td_stacktrace").html() var threadCell = $("#thread_" + threadId + "_tr") - threadCell.after("
" +
+        threadCell.after("
" +
             stackTraceText +  "
") } else { if (!forceAdd) { @@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) { $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover"); } function onSearchStringChange() { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index a0ef80d9bdae0..c6a07445f2a35 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -48,6 +48,16 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } }.map { thread => val threadId = thread.threadId + val blockedBy = thread.blockedByThreadId match { + case Some(blockedByThreadId) => + + case None => Text("") + } + val heldLocks = thread.holdingLocks.mkString(", ") + {threadId} {thread.threadName} {thread.threadState} + {blockedBy}{heldLocks} {thread.stackTrace} } @@ -86,6 +97,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage Thread ID Thread Name Thread State + Thread Locks {dumpRows} diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala index d4e0ad93b966a..b1217980faf1f 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -24,4 +24,8 @@ private[spark] case class ThreadStackTrace( threadId: Long, threadName: String, threadState: Thread.State, - stackTrace: String) + stackTrace: String, + blockedByThreadId: Option[Long], + blockedByLock: String, + holdingLocks: Seq[String]) + diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6027b07c0fee8..22c28fba2087e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.lang.management.ManagementFactory +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -2096,15 +2096,41 @@ private[spark] object Utils extends Logging { } } + private implicit class Lock(lock: LockInfo) { + def lockString: String = { + lock match { + case monitor: MonitorInfo => + s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode}})" + case _ => + s"Lock(${lock.getClassName}@${lock.getIdentityHashCode}})" + } + } + } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ def getThreadDump(): Array[ThreadStackTrace] = { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. 
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) threadInfos.sortBy(_.getThreadId).map { case threadInfo => - val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") - ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, - threadInfo.getThreadState, stackTrace) + val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap + val stackTrace = threadInfo.getStackTrace.map { frame => + monitors.get(frame) match { + case Some(monitor) => + monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" + case None => + frame.toString + } + }.mkString("\n") + + // use a set to dedup re-entrant locks that are held at multiple places + val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString) + ++ threadInfo.getLockedMonitors.map(_.lockString) + ).toSet + + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState, + stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), + Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq) } } From ab8da1413836591fecbc75a2515875bf3e50527f Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Wed, 2 Nov 2016 09:10:34 +0000 Subject: [PATCH 006/534] [SPARK-18198][DOC][STREAMING] Highlight code snippets ## What changes were proposed in this pull request? This patch uses `{% highlight lang %}...{% endhighlight %}` to highlight code snippets in the `Structured Streaming Kafka010 integration doc` and the `Spark Streaming Kafka010 integration doc`. This patch consists of two commits: - the first commit fixes only the leading spaces -- this is large - the second commit adds the highlight instructions -- this is much simpler and easier to review ## How was this patch tested? SKIP_API=1 jekyll build ## Screenshots **Before** ![snip20161101_3](https://cloud.githubusercontent.com/assets/15843379/19894258/47746524-a087-11e6-9a2a-7bff2d428d44.png) **After** ![snip20161101_1](https://cloud.githubusercontent.com/assets/15843379/19894324/8bebcd1e-a087-11e6-835b-88c4d2979cfa.png) Author: Liwei Lin Closes #15715 from lw-lin/doc-highlight-code-snippet. (cherry picked from commit 98ede49496d0d7b4724085083d4f24436b92a7bf) Signed-off-by: Sean Owen --- docs/streaming-kafka-0-10-integration.md | 391 +++++++++--------- .../structured-streaming-kafka-integration.md | 156 +++---- 2 files changed, 287 insertions(+), 260 deletions(-) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index c1ef396907db7..b645d3c3a4b53 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -17,69 +17,72 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea
- import org.apache.kafka.clients.consumer.ConsumerRecord - import org.apache.kafka.common.serialization.StringDeserializer - import org.apache.spark.streaming.kafka010._ - import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent - import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe - - val kafkaParams = Map[String, Object]( - "bootstrap.servers" -> "localhost:9092,anotherhost:9092", - "key.deserializer" -> classOf[StringDeserializer], - "value.deserializer" -> classOf[StringDeserializer], - "group.id" -> "use_a_separate_group_id_for_each_stream", - "auto.offset.reset" -> "latest", - "enable.auto.commit" -> (false: java.lang.Boolean) - ) - - val topics = Array("topicA", "topicB") - val stream = KafkaUtils.createDirectStream[String, String]( - streamingContext, - PreferConsistent, - Subscribe[String, String](topics, kafkaParams) - ) - - stream.map(record => (record.key, record.value)) - +{% highlight scala %} +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.serialization.StringDeserializer +import org.apache.spark.streaming.kafka010._ +import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent +import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe + +val kafkaParams = Map[String, Object]( + "bootstrap.servers" -> "localhost:9092,anotherhost:9092", + "key.deserializer" -> classOf[StringDeserializer], + "value.deserializer" -> classOf[StringDeserializer], + "group.id" -> "use_a_separate_group_id_for_each_stream", + "auto.offset.reset" -> "latest", + "enable.auto.commit" -> (false: java.lang.Boolean) +) + +val topics = Array("topicA", "topicB") +val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Subscribe[String, String](topics, kafkaParams) +) + +stream.map(record => (record.key, record.value)) +{% endhighlight %} Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html)
- import java.util.*; - import org.apache.spark.SparkConf; - import org.apache.spark.TaskContext; - import org.apache.spark.api.java.*; - import org.apache.spark.api.java.function.*; - import org.apache.spark.streaming.api.java.*; - import org.apache.spark.streaming.kafka010.*; - import org.apache.kafka.clients.consumer.ConsumerRecord; - import org.apache.kafka.common.TopicPartition; - import org.apache.kafka.common.serialization.StringDeserializer; - import scala.Tuple2; - - Map kafkaParams = new HashMap<>(); - kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); - kafkaParams.put("key.deserializer", StringDeserializer.class); - kafkaParams.put("value.deserializer", StringDeserializer.class); - kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); - kafkaParams.put("auto.offset.reset", "latest"); - kafkaParams.put("enable.auto.commit", false); - - Collection topics = Arrays.asList("topicA", "topicB"); - - final JavaInputDStream> stream = - KafkaUtils.createDirectStream( - streamingContext, - LocationStrategies.PreferConsistent(), - ConsumerStrategies.Subscribe(topics, kafkaParams) - ); - - stream.mapToPair( - new PairFunction, String, String>() { - @Override - public Tuple2 call(ConsumerRecord record) { - return new Tuple2<>(record.key(), record.value()); - } - }) +{% highlight java %} +import java.util.*; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.kafka010.*; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.StringDeserializer; +import scala.Tuple2; + +Map kafkaParams = new HashMap<>(); +kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); +kafkaParams.put("key.deserializer", StringDeserializer.class); +kafkaParams.put("value.deserializer", StringDeserializer.class); +kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); +kafkaParams.put("auto.offset.reset", "latest"); +kafkaParams.put("enable.auto.commit", false); + +Collection topics = Arrays.asList("topicA", "topicB"); + +final JavaInputDStream> stream = + KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Subscribe(topics, kafkaParams) + ); + +stream.mapToPair( + new PairFunction, String, String>() { + @Override + public Tuple2 call(ConsumerRecord record) { + return new Tuple2<>(record.key(), record.value()); + } + }) +{% endhighlight %}
@@ -109,32 +112,35 @@ If you have a use case that is better suited to batch processing, you can create
- // Import dependencies and create kafka params as in Create Direct Stream above - - val offsetRanges = Array( - // topic, partition, inclusive starting offset, exclusive ending offset - OffsetRange("test", 0, 0, 100), - OffsetRange("test", 1, 0, 100) - ) +{% highlight scala %} +// Import dependencies and create kafka params as in Create Direct Stream above - val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +val offsetRanges = Array( + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange("test", 0, 0, 100), + OffsetRange("test", 1, 0, 100) +) +val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +{% endhighlight %}
- // Import dependencies and create kafka params as in Create Direct Stream above - - OffsetRange[] offsetRanges = { - // topic, partition, inclusive starting offset, exclusive ending offset - OffsetRange.create("test", 0, 0, 100), - OffsetRange.create("test", 1, 0, 100) - }; - - JavaRDD> rdd = KafkaUtils.createRDD( - sparkContext, - kafkaParams, - offsetRanges, - LocationStrategies.PreferConsistent() - ); +{% highlight java %} +// Import dependencies and create kafka params as in Create Direct Stream above + +OffsetRange[] offsetRanges = { + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange.create("test", 0, 0, 100), + OffsetRange.create("test", 1, 0, 100) +}; + +JavaRDD> rdd = KafkaUtils.createRDD( + sparkContext, + kafkaParams, + offsetRanges, + LocationStrategies.PreferConsistent() +); +{% endhighlight %}
@@ -144,29 +150,33 @@ Note that you cannot use `PreferBrokers`, because without the stream there is no
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - rdd.foreachPartition { iter => - val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") - } - } +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.foreachPartition { iter => + val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } +} +{% endhighlight %}
- stream.foreachRDD(new VoidFunction>>() { - @Override - public void call(JavaRDD> rdd) { - final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); - rdd.foreachPartition(new VoidFunction>>() { - @Override - public void call(Iterator> consumerRecords) { - OffsetRange o = offsetRanges[TaskContext.get().partitionId()]; - System.out.println( - o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset()); - } - }); - } - }); +{% highlight java %} +stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + rdd.foreachPartition(new VoidFunction>>() { + @Override + public void call(Iterator> consumerRecords) { + OffsetRange o = offsetRanges[TaskContext.get().partitionId()]; + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset()); + } + }); + } +}); +{% endhighlight %}
@@ -183,25 +193,28 @@ Kafka has an offset commit API that stores offsets in a special Kafka topic. By
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - - // some time later, after outputs have completed - stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) - } - +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + // some time later, after outputs have completed + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) +} +{% endhighlight %} As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics.
- stream.foreachRDD(new VoidFunction>>() { - @Override - public void call(JavaRDD> rdd) { - OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); - - // some time later, after outputs have completed - ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges); - } - }); +{% highlight java %} +stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + + // some time later, after outputs have completed + ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges); + } +}); +{% endhighlight %}
@@ -210,64 +223,68 @@ For data stores that support transactions, saving offsets in the same transactio
- // The details depend on your data store, but the general idea looks like this +{% highlight scala %} +// The details depend on your data store, but the general idea looks like this - // begin from the the offsets committed to the database - val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => - new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") - }.toMap +// begin from the the offsets committed to the database +val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => + new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") +}.toMap - val stream = KafkaUtils.createDirectStream[String, String]( - streamingContext, - PreferConsistent, - Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets) - ) +val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets) +) - stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - val results = yourCalculation(rdd) + val results = yourCalculation(rdd) - // begin your transaction + // begin your transaction - // update results - // update offsets where the end of existing offsets matches the beginning of this batch of offsets - // assert that offsets were updated correctly + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly - // end your transaction - } + // end your transaction +} +{% endhighlight %}
- // The details depend on your data store, but the general idea looks like this - - // begin from the the offsets committed to the database - Map fromOffsets = new HashMap<>(); - for (resultSet : selectOffsetsFromYourDatabase) - fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset")); - } - - JavaInputDStream> stream = KafkaUtils.createDirectStream( - streamingContext, - LocationStrategies.PreferConsistent(), - ConsumerStrategies.Assign(fromOffsets.keySet(), kafkaParams, fromOffsets) - ); - - stream.foreachRDD(new VoidFunction>>() { - @Override - public void call(JavaRDD> rdd) { - OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); - - Object results = yourCalculation(rdd); - - // begin your transaction - - // update results - // update offsets where the end of existing offsets matches the beginning of this batch of offsets - // assert that offsets were updated correctly - - // end your transaction - } - }); +{% highlight java %} +// The details depend on your data store, but the general idea looks like this + +// begin from the the offsets committed to the database +Map fromOffsets = new HashMap<>(); +for (resultSet : selectOffsetsFromYourDatabase) + fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset")); +} + +JavaInputDStream> stream = KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Assign(fromOffsets.keySet(), kafkaParams, fromOffsets) +); + +stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + + Object results = yourCalculation(rdd); + + // begin your transaction + + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly + + // end your transaction + } +}); +{% endhighlight %}
@@ -277,25 +294,29 @@ The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html
- val kafkaParams = Map[String, Object]( - // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS - "security.protocol" -> "SSL", - "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", - "ssl.truststore.password" -> "test1234", - "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", - "ssl.keystore.password" -> "test1234", - "ssl.key.password" -> "test1234" - ) +{% highlight scala %} +val kafkaParams = Map[String, Object]( + // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS + "security.protocol" -> "SSL", + "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", + "ssl.truststore.password" -> "test1234", + "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", + "ssl.keystore.password" -> "test1234", + "ssl.key.password" -> "test1234" +) +{% endhighlight %}
- Map kafkaParams = new HashMap(); - // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS - kafkaParams.put("security.protocol", "SSL"); - kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks"); - kafkaParams.put("ssl.truststore.password", "test1234"); - kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks"); - kafkaParams.put("ssl.keystore.password", "test1234"); - kafkaParams.put("ssl.key.password", "test1234"); +{% highlight java %} +Map kafkaParams = new HashMap(); +// the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS +kafkaParams.put("security.protocol", "SSL"); +kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks"); +kafkaParams.put("ssl.truststore.password", "test1234"); +kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks"); +kafkaParams.put("ssl.keystore.password", "test1234"); +kafkaParams.put("ssl.key.password", "test1234"); +{% endhighlight %}
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index a6c3b3a9024d8..c4c9fb3f7d3db 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -19,97 +19,103 @@ application. See the [Deploying](#deploying) subsection below.
+{% highlight scala %} - // Subscribe to 1 topic - val ds1 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") - .load() - ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to 1 topic +val ds1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() +ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to multiple topics - val ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") - .load() - ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to multiple topics +val ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() +ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to a pattern - val ds3 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") - .load() - ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to a pattern +val ds3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() +ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] +{% endhighlight %}
+{% highlight java %}
- // Subscribe to 1 topic
- Dataset<Row> ds1 = spark
- .readStream()
- .format("kafka")
- .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
- .option("subscribe", "topic1")
- .load()
- ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to 1 topic
+Dataset<Row> ds1 = spark
+ .readStream()
+ .format("kafka")
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+ .option("subscribe", "topic1")
+ .load()
+ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- // Subscribe to multiple topics
- Dataset<Row> ds2 = spark
- .readStream()
- .format("kafka")
- .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
- .option("subscribe", "topic1,topic2")
- .load()
- ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to multiple topics
+Dataset<Row> ds2 = spark
+ .readStream()
+ .format("kafka")
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+ .option("subscribe", "topic1,topic2")
+ .load()
+ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- // Subscribe to a pattern
- Dataset<Row> ds3 = spark
- .readStream()
- .format("kafka")
- .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
- .option("subscribePattern", "topic.*")
- .load()
- ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to a pattern
+Dataset<Row> ds3 = spark
+ .readStream()
+ .format("kafka")
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+ .option("subscribePattern", "topic.*")
+ .load()
+ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+{% endhighlight %}
+{% highlight python %}
- # Subscribe to 1 topic
- ds1 = spark
- .readStream()
- .format("kafka")
- .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
- .option("subscribe", "topic1")
- .load()
- ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to 1 topic
+ds1 = spark
+ .readStream()
+ .format("kafka")
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+ .option("subscribe", "topic1")
+ .load()
+ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- # Subscribe to multiple topics
- ds2 = spark
- .readStream
- .format("kafka")
- .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
- .option("subscribe", "topic1,topic2")
- .load()
- ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to multiple topics
+ds2 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+ .option("subscribe", "topic1,topic2")
+ .load()
+ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- # Subscribe to a pattern
- ds3 = spark
- .readStream()
- .format("kafka")
- .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
- .option("subscribePattern", "topic.*")
- .load()
- ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to a pattern
+ds3 = spark
+ .readStream()
+ .format("kafka")
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+ .option("subscribePattern", "topic.*")
+ .load()
+ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+{% endhighlight %}
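For context, the streaming Datasets created above only start consuming once a sink is attached and the query is started. A minimal sketch against the console sink is shown below; it reuses `ds1` from the Scala example, assumes an existing `SparkSession` named `spark`, and the output mode and checkpoint path are illustrative choices rather than anything this documentation change prescribes.

{% highlight scala %}
import spark.implicits._  // for the (String, String) encoder; assumes a SparkSession named `spark`

val query = ds1
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
  .as[(String, String)]
  .writeStream
  .outputMode("append")                                      // emit each new Kafka record once
  .format("console")                                         // print micro-batches to stdout
  .option("checkpointLocation", "/tmp/kafka-console-ckpt")   // hypothetical checkpoint directory
  .start()

query.awaitTermination()
{% endhighlight %}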
From 176afa5e8b207e28a16e1b22280ed05c10b7b486 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 2 Nov 2016 09:39:15 +0000 Subject: [PATCH 007/534] [SPARK-18076][CORE][SQL] Fix default Locale used in DateFormat, NumberFormat to Locale.US ## What changes were proposed in this pull request? Fix `Locale.US` for all usages of `DateFormat`, `NumberFormat` ## How was this patch tested? Existing tests. Author: Sean Owen Closes #15610 from srowen/SPARK-18076. (cherry picked from commit 9c8deef64efee20a0ddc9b612f90e77c80aede60) Signed-off-by: Sean Owen --- .../org/apache/spark/SparkHadoopWriter.scala | 8 +++---- .../apache/spark/deploy/SparkHadoopUtil.scala | 4 ++-- .../apache/spark/deploy/master/Master.scala | 5 ++-- .../apache/spark/deploy/worker/Worker.scala | 4 ++-- .../org/apache/spark/rdd/HadoopRDD.scala | 5 ++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- .../apache/spark/rdd/PairRDDFunctions.scala | 4 ++-- .../status/api/v1/JacksonMessageWriter.scala | 4 ++-- .../spark/status/api/v1/SimpleDateParam.scala | 6 ++--- .../scala/org/apache/spark/ui/UIUtils.scala | 3 ++- .../spark/util/logging/RollingPolicy.scala | 6 ++--- .../org/apache/spark/util/UtilsSuite.scala | 2 +- .../deploy/rest/mesos/MesosRestServer.scala | 11 ++++----- .../mllib/pmml/export/PMMLModelExport.scala | 4 ++-- .../expressions/datetimeExpressions.scala | 17 ++++++------- .../expressions/stringExpressions.scala | 2 +- .../spark/sql/catalyst/json/JSONOptions.scala | 6 +++-- .../sql/catalyst/util/DateTimeUtils.scala | 6 ++--- .../expressions/DateExpressionsSuite.scala | 24 +++++++++---------- .../catalyst/util/DateTimeUtilsSuite.scala | 6 ++--- .../datasources/csv/CSVInferSchema.scala | 4 ++-- .../datasources/csv/CSVOptions.scala | 5 ++-- .../sql/execution/metric/SQLMetrics.scala | 2 +- .../sql/execution/streaming/socket.scala | 4 ++-- .../apache/spark/sql/DateFunctionsSuite.scala | 11 +++++---- .../execution/datasources/csv/CSVSuite.scala | 9 +++---- .../datasources/csv/CSVTypeCastSuite.scala | 9 ++++--- .../hive/execution/InsertIntoHiveTable.scala | 9 +++---- .../spark/sql/hive/hiveWriterContainers.scala | 4 ++-- .../sql/sources/SimpleTextRelation.scala | 3 ++- .../apache/spark/streaming/ui/UIUtils.scala | 8 ++++--- 31 files changed, 103 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6550d703bc860..7f75a393bf8ff 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.IOException import java.text.NumberFormat import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path @@ -67,12 +67,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { def setup(jobid: Int, splitid: Int, attemptid: Int) { setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now), + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), jobid, splitID, attemptID, conf.value) } def open() { - val numfmt = NumberFormat.getInstance() + val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) @@ -162,7 +162,7 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { private[spark] 
object SparkHadoopWriter { def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(time) new JobID(jobtrackerID, id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 3f54ecc17ac33..23156072c3ebe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.io.IOException import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -357,7 +357,7 @@ class SparkHadoopUtil extends Logging { * @return a printable string value. */ private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { - val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US) val buffer = new StringBuilder(128) buffer.append(token.toString) try { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8c91aa15167c4..4618e6117a4fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -51,7 +51,8 @@ private[deploy] class Master( private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0bedd9a20a969..8b1c6bf2e5fd5 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{Date, UUID} +import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -68,7 +68,7 @@ private[deploy] class Worker( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 diff --git 
a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1cf3938de098..36a2f5c87e372 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.immutable.Map import scala.reflect.ClassTag @@ -243,7 +243,8 @@ class HadoopRDD[K, V]( var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime), + HadoopRDD.addLocalConfiguration( + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index baf31fb658870..488e777fea371 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.reflect.ClassTag @@ -79,7 +79,7 @@ class NewHadoopRDD[K, V]( // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) formatter.format(new Date()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 068f4ed8ad745..67baad1c51bca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.nio.ByteBuffer import java.text.SimpleDateFormat -import java.util.{Date, HashMap => JHashMap} +import java.util.{Date, HashMap => JHashMap, Locale} import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ @@ -1079,7 +1079,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
val hadoopConf = conf val job = NewAPIHadoopJob.getInstance(hadoopConf) - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(new Date()) val stageId = self.id val jobConfiguration = job.getConfiguration diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index f6a9f9c5573db..76af33c1a18db 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -21,7 +21,7 @@ import java.lang.annotation.Annotation import java.lang.reflect.Type import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat -import java.util.{Calendar, SimpleTimeZone} +import java.util.{Calendar, Locale, SimpleTimeZone} import javax.ws.rs.Produces import javax.ws.rs.core.{MediaType, MultivaluedMap} import javax.ws.rs.ext.{MessageBodyWriter, Provider} @@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ private[spark] object JacksonMessageWriter { def makeISODateFormat: SimpleDateFormat = { - val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US) val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) iso8601.setCalendar(cal) iso8601 diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index 0c71cd2382225..d8d5e8958b23c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.text.{ParseException, SimpleDateFormat} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status @@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status private[v1] class SimpleDateParam(val originalValue: String) { val timestamp: Long = { - val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US) try { format.parse(originalValue).getTime() } catch { case _: ParseException => - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US) gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) try { gmtDay.parse(originalValue).getTime() diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c0d1a2220f62a..66b097aa8166d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -36,7 +36,8 @@ private[spark] object UIUtils extends Logging { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. 
private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } def formatDate(date: Date): String = dateFormat.get.format(date) diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 5c4238c0381a1..1f263df57c857 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -18,7 +18,7 @@ package org.apache.spark.util.logging import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.internal.Logging @@ -59,7 +59,7 @@ private[spark] class TimeBasedRollingPolicy( } @volatile private var nextRolloverTime = calculateNextRolloverTime() - private val formatter = new SimpleDateFormat(rollingFileSuffixPattern) + private val formatter = new SimpleDateFormat(rollingFileSuffixPattern, Locale.US) /** Should rollover if current time has exceeded next rollover time */ def shouldRollover(bytesToBeWritten: Long): Boolean = { @@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy( } @volatile private var bytesWrittenSinceRollover = 0L - val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS") + val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US) /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 15ef32f21d90c..feacfb7642f27 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -264,7 +264,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val hour = minute * 60 def str: (Long) => String = Utils.msDurationToString(_) - val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() + val sep = new DecimalFormatSymbols(Locale.US).getDecimalSeparator assert(str(123) === "123 ms") assert(str(second) === "1" + sep + "0 s") diff --git a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 3b96488a129a9..ff60b88c6d533 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.rest.mesos import java.io.File import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.atomic.AtomicLong import javax.servlet.http.HttpServletResponse @@ -62,11 +62,10 @@ private[mesos] class MesosSubmitRequestServlet( private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private def newDriverId(submitDate: Date): String = { - "driver-%s-%04d".format( - createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) - } + // For application IDs + private def createDateFormat = new 
SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + private def newDriverId(submitDate: Date): String = + f"driver-${createDateFormat.format(submitDate)}-${nextDriverNumber.incrementAndGet()}%04d" /** * Build a driver description from the fields specified in the submit request. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 426bb818c9266..f5ca1c221d66b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.pmml.export import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.beans.BeanProperty @@ -34,7 +34,7 @@ private[mllib] trait PMMLModelExport { val version = getClass.getPackage.getImplementationVersion val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.US).format(new Date())) val header = new Header() .setApplication(app) .setTimestamp(timestamp) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7ab68a13e09cf..67c078ae5e264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.util.Try @@ -331,7 +331,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override protected def nullSafeEval(timestamp: Any, format: Any): Any = { - val sdf = new SimpleDateFormat(format.toString) + val sdf = new SimpleDateFormat(format.toString, Locale.US) UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } @@ -400,7 +400,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: SimpleDateFormat = - Try(new SimpleDateFormat(constFormat.toString)).getOrElse(null) + Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null) override def eval(input: InternalRow): Any = { val t = left.eval(input) @@ -425,7 +425,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { null } else { val formatString = f.asInstanceOf[UTF8String].toString - Try(new SimpleDateFormat(formatString).parse( + Try(new SimpleDateFormat(formatString, Locale.US).parse( t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) } } @@ -520,7 +520,7 @@ case class FromUnixTime(sec: Expression, format: Expression) private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: SimpleDateFormat = - Try(new SimpleDateFormat(constFormat.toString)).getOrElse(null) + Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null) 
override def eval(input: InternalRow): Any = { val time = left.eval(input) @@ -539,9 +539,10 @@ case class FromUnixTime(sec: Expression, format: Expression) if (f == null) { null } else { - Try(UTF8String.fromString(new SimpleDateFormat( - f.asInstanceOf[UTF8String].toString).format(new java.util.Date( - time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + Try( + UTF8String.fromString(new SimpleDateFormat(f.toString, Locale.US). + format(new java.util.Date(time.asInstanceOf[Long] * 1000L))) + ).getOrElse(null) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 1bcbb6cfc9246..25a5e3fd7da73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1415,7 +1415,7 @@ case class Sentences( val locale = if (languageStr != null && countryStr != null) { new Locale(languageStr.toString, countryStr.toString) } else { - Locale.getDefault + Locale.US } getSentences(string.asInstanceOf[UTF8String].toString, locale) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index aec18922ea6c8..c45970658cf07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.json +import java.util.Locale + import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.commons.lang3.time.FastDateFormat @@ -56,11 +58,11 @@ private[sql] class JSONOptions( // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 0b643a5b84268..235ca8d2633a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import javax.xml.bind.DatatypeConverter import scala.annotation.tailrec @@ -79,14 +79,14 @@ object DateTimeUtils { // `SimpleDateFormat` is not thread-safe. val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) } } // `SimpleDateFormat` is not thread-safe. 
private val threadLocalDateFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd") + new SimpleDateFormat("yyyy-MM-dd", Locale.US) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 6118a34d29eaa..35cea25ba0b7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -30,8 +30,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { import IntegralLiteralTestUtils._ - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val sdfDate = new SimpleDateFormat("yyyy-MM-dd", Locale.US) val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) @@ -49,7 +49,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("DayOfYear") { - val sdfDay = new SimpleDateFormat("D") + val sdfDay = new SimpleDateFormat("D", Locale.US) (0 to 3).foreach { m => (0 to 5).foreach { i => val c = Calendar.getInstance() @@ -411,9 +411,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("from_unixtime") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) checkEvaluation( FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0))) checkEvaluation(FromUnixTime( @@ -430,11 +430,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val date1 = Date.valueOf("2015-07-24") checkEvaluation( UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) @@ -466,11 +466,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("to_unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val date1 = Date.valueOf("2015-07-24") checkEvaluation( ToUnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 4f516d006458e..e0a9a0c3d5c00 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils._ @@ -68,8 +68,8 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(d2.toString === d1.toString) } - val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z") + val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z", Locale.US) checkFromToJavaDate(new Date(100)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3ab775c909238..1981d8607c0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -247,7 +247,7 @@ private[csv] object CSVTypeCast { case options.positiveInf => Float.PositiveInfinity case _ => Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue()) } case _: DoubleType => datum match { @@ -256,7 +256,7 @@ private[csv] object CSVTypeCast { case options.positiveInf => Double.PositiveInfinity case _ => Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue()) } case _: BooleanType => datum.toBoolean case dt: DecimalType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 014614eb997a5..5903729c11fc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets +import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat @@ -104,11 +105,11 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. 
val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 0cc1edd196bc8..dbc27d8b237f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -102,7 +102,7 @@ object SQLMetrics { */ def stringValue(metricsType: String, values: Seq[Long]): String = { if (metricsType == SUM_METRIC) { - val numberFormat = NumberFormat.getIntegerInstance(Locale.ENGLISH) + val numberFormat = NumberFormat.getIntegerInstance(Locale.US) numberFormat.format(values.sum) } else { val strFormat: Long => String = if (metricsType == SIZE_METRIC) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index c662e7c6bc775..042977f870b8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -21,7 +21,7 @@ import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ListBuffer @@ -37,7 +37,7 @@ object TextSocketSource { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) - val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index f7aa3b747ae5d..e05b2252ee346 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ @@ -55,8 +56,8 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = NOW()"""), Row(true)) } - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val sdfDate = new SimpleDateFormat("yyyy-MM-dd", Locale.US) val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) @@ -395,11 +396,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { } test("from_unixtime") { - val sdf1 = new 
SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd HH-mm-ss" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") checkAnswer( df.select(from_unixtime(col("a"))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index f7c22c6c93f7a..8209b5bd7f9de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -21,6 +21,7 @@ import java.io.File import java.nio.charset.UnsupportedCharsetException import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType @@ -487,7 +488,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .select("date") .collect() - val dateFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm") + val dateFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm", Locale.US) val expected = Seq(Seq(new Timestamp(dateFormat.parse("26/08/2015 18:00").getTime)), Seq(new Timestamp(dateFormat.parse("27/10/2014 18:30").getTime)), @@ -509,7 +510,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .select("date") .collect() - val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm") + val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm", Locale.US) val expected = Seq( new Date(dateFormat.parse("26/08/2015 18:00").getTime), new Date(dateFormat.parse("27/10/2014 18:30").getTime), @@ -728,7 +729,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("inferSchema", "false") .load(iso8601timestampsPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ") + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) val expectedTimestamps = timestamps.collect().map { r => // This should be ISO8601 formatted string. Row(iso8501.format(r.toSeq.head)) @@ -761,7 +762,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("inferSchema", "false") .load(iso8601datesPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd") + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US) val expectedDates = dates.collect().map { r => // This should be ISO8601 formatted string. 
Row(iso8501.format(r.toSeq.head)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 51832a13cfe0b..c74406b9cbfbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -144,13 +144,12 @@ class CSVTypeCastSuite extends SparkFunSuite { DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) } - test("Float and Double Types are cast correctly with Locale") { + test("Float and Double Types are cast without respect to platform default Locale") { val originalLocale = Locale.getDefault try { - val locale : Locale = new Locale("fr", "FR") - Locale.setDefault(locale) - assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0) - assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0) + Locale.setDefault(new Locale("fr", "FR")) + assert(CSVTypeCast.castTo("1,00", FloatType) == 100.0) // Would parse as 1.0 in fr-FR + assert(CSVTypeCast.castTo("1,00", DoubleType) == 100.0) } finally { Locale.setDefault(originalLocale) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 2843100fb3b36..05164d774ccaf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.IOException import java.net.URI import java.text.SimpleDateFormat -import java.util.{Date, Random} - -import scala.collection.JavaConverters._ +import java.util.{Date, Locale, Random} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -60,9 +58,8 @@ case class InsertIntoHiveTable( private def executionId: String = { val rand: Random = new Random - val format: SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS") - val executionId: String = "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong) - return executionId + val format = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS", Locale.US) + "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong) } private def getStagingDir(inputPath: Path, hadoopConf: Configuration): Path = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ea88276bb96c0..e53c3e4d4833b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import java.text.NumberFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.JavaConverters._ @@ -95,7 +95,7 @@ private[hive] class SparkHiveWriterContainer( } protected def getOutputName: String = { - val numberFormat = NumberFormat.getInstance() + val numberFormat = NumberFormat.getInstance(Locale.US) numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 64d0ecbeefc98..cecfd99098659 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import java.text.NumberFormat +import java.util.Locale import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -141,7 +142,7 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) class AppendingTextOutputFormat(path: String) extends TextOutputFormat[NullWritable, Text] { - val numberFormat = NumberFormat.getInstance() + val numberFormat = NumberFormat.getInstance(Locale.US) numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index 9b1c939e9329f..84ecf81abfbf1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.ui import java.text.SimpleDateFormat -import java.util.TimeZone +import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit import scala.xml.Node @@ -80,11 +80,13 @@ private[streaming] object UIUtils { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val batchTimeFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } private val batchTimeFormatWithMilliseconds = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS", Locale.US) } /** From 41491e54080742f6e4a1e80a72cd9f46a9336e31 Mon Sep 17 00:00:00 2001 From: eyal farago Date: Wed, 2 Nov 2016 11:12:20 +0100 Subject: [PATCH 008/534] [SPARK-16839][SQL] Simplify Struct creation code path ## What changes were proposed in this pull request? Simplify struct creation, especially the aspect of `CleanupAliases` which missed some aliases when handling trees created by `CreateStruct`. This PR includes: 1. A failing test (create struct with nested aliases, some of the aliases survive `CleanupAliases`). 2. A fix that transforms `CreateStruct` into a `CreateNamedStruct` constructor, effectively eliminating `CreateStruct` from all expression trees. 3. A `NamePlaceHolder` used by `CreateStruct` when column names cannot be extracted from unresolved `NamedExpression`. 4. A new Analyzer rule that resolves `NamePlaceHolder` into a string literal once the `NamedExpression` is resolved. 5. `CleanupAliases` code was simplified as it no longer has to deal with `CreateStruct`'s top level columns. ## How was this patch tested? Running all tests-suits in package org.apache.spark.sql, especially including the analysis suite, making sure added test initially fails, after applying suggested fix rerun the entire analysis package successfully. Modified few tests that expected `CreateStruct` which is now transformed into `CreateNamedStruct`. 
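For context, the user-visible effect of lowering `struct` to `CreateNamedStruct` can be sketched as follows; this is only an illustration of the expected behaviour, assuming an existing `SparkSession` named `spark`, and the commented expectations are not output copied from the patch.

{% highlight scala %}
import org.apache.spark.sql.functions._
import spark.implicits._  // for the $"..." column syntax; assumes a SparkSession named `spark`

// Field names of the struct come from the aliases on its inputs; the point of the
// fix is that these top-level aliases survive analysis once struct() is lowered
// to CreateNamedStruct.
val df = spark.range(2).select(struct($"id".as("a"), lit("x").as("b")).as("s"))

// Expect a struct column `s` with fields named `a` and `b`.
df.printSchema()

// The analyzed plan should show CreateNamedStruct (name/value pairs) rather than CreateStruct.
println(df.queryExecution.analyzed.numberedTreeString)
{% endhighlight %}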
Author: eyal farago Author: Herman van Hovell Author: eyal farago Author: Eyal Farago Author: Hyukjin Kwon Author: eyalfa Closes #15718 from hvanhovell/SPARK-16839-2. (cherry picked from commit f151bd1af8a05d4b6c901ebe6ac0b51a4a1a20df) Signed-off-by: Herman van Hovell --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 53 ++--- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/expressions/Projection.scala | 2 - .../expressions/complexTypeCreator.scala | 212 ++++++------------ .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++- .../expressions/ComplexTypeSuite.scala | 1 - .../scala/org/apache/spark/sql/Column.scala | 3 + .../command/AnalyzeColumnCommand.scala | 4 +- .../sql-tests/results/group-by.sql.out | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 20 +- .../resources/sqlgen/subquery_in_having_2.sql | 2 +- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 12 +- 14 files changed, 169 insertions(+), 198 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 806019d7524ff..d7fe6b32822a7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f8f4799322b3b..5011f2fdbf9b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.{TreeNodeRef} +import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ @@ -83,6 +83,7 @@ class Analyzer( ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: + ResolveCreateNamedStruct :: ResolveDeserializer :: ResolveNewInstance :: ResolveUpCast :: @@ -653,11 +654,12 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) - case c: CreateStruct if containsStar(c.children) => - c.copy(children = c.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o 
:: Nil - }) + case c: CreateNamedStruct if containsStar(c.valExprs) => + val newChildren = c.children.grouped(2).flatMap { + case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children + case kv => kv + } + c.copy(children = newChildren.toList ) case c: CreateArray if containsStar(c.children) => c.copy(children = c.children.flatMap { case s: Star => s.expand(child, resolver) @@ -1141,7 +1143,7 @@ class Analyzer( case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => // Get the left hand side expressions. val expressions = e match { - case CreateStruct(exprs) => exprs + case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => @@ -2072,18 +2074,8 @@ object EliminateUnions extends Rule[LogicalPlan] { */ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { - var stop = false e.transformDown { - // CreateStruct is a special case, we need to retain its top level Aliases as they decide the - // name of StructField. We also need to stop transform down this expression, or the Aliases - // under CreateStruct will be mistakenly trimmed. - case c: CreateStruct if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case Alias(child, _) if !stop => child + case Alias(child, _) => child } } @@ -2116,15 +2108,8 @@ object CleanupAliases extends Rule[LogicalPlan] { case a: AppendColumns => a case other => - var stop = false other transformExpressionsDown { - case c: CreateStruct if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case Alias(child, _) if !stop => child + case Alias(child, _) => child } } } @@ -2217,3 +2202,19 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } + +/** + * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
+ */ +object ResolveCreateNamedStruct extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case e: CreateNamedStruct if !e.resolved => + val children = e.children.grouped(2).flatMap { + case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => + Seq(Literal(e.name), e) + case kv => + kv + } + CreateNamedStruct(children.toList) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3e836ca375e2e..b028d07fb8d0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -357,7 +357,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[CreateStruct]("struct"), + CreateStruct.registryEntry, // misc functions expression[AssertTrue]("assert_true"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a81fa1ce3adcc..03e054d098511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -119,7 +119,6 @@ object UnsafeProjection { */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(unsafeExprs) @@ -145,7 +144,6 @@ object UnsafeProjection { subexpressionEliminationEnabled: Boolean): UnsafeProjection = { val e = exprs.map(BindReferences.bindReference(_, inputSchema)) .map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 917aa0873130b..dbfb2996ec9d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -172,101 +174,71 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } /** - * Returns a Row containing the evaluation of all children expressions. 
+ * An expression representing a not yet available attribute name. This expression is unevaluable + * and as its name suggests it is a temporary place holder until we're able to determine the + * actual attribute name. */ -@ExpressionDescription( - usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") -case class CreateStruct(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - +case object NamePlaceholder extends LeafExpression with Unevaluable { + override lazy val resolved: Boolean = false + override def foldable: Boolean = false override def nullable: Boolean = false + override def dataType: DataType = StringType + override def prettyName: String = "NamePlaceholder" + override def toString: String = prettyName +} - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) +/** + * Returns a Row containing the evaluation of all children expressions. + */ +object CreateStruct extends FunctionBuilder { + def apply(children: Seq[Expression]): CreateNamedStruct = { + CreateNamedStruct(children.zipWithIndex.flatMap { + case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) + case (e: NamedExpression, _) => Seq(NamePlaceholder, e) + case (e, index) => Seq(Literal(s"col${index + 1}"), e) + }) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericInternalRow].getName - val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") - - ev.copy(code = s""" - boolean ${ev.isNull} = false; - this.$values = new Object[${children.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - children.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + - s""" - final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; - """) + /** + * Entry to use in the function registry. + */ + val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { + val info: ExpressionInfo = new ExpressionInfo( + "org.apache.spark.sql.catalyst.expressions.NamedStruct", + null, + "struct", + "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", + "") + ("struct", (info, this)) } - - override def prettyName: String = "struct" } - /** - * Creates a struct with the given field names and values - * - * @param children Seq(name1, val1, name2, val2, ...) + * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]]. */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.") -// scalastyle:on line.size.limit -case class CreateNamedStruct(children: Seq[Expression]) extends Expression { +trait CreateNamedStructLike extends Expression { + lazy val (nameExprs, valExprs) = children.grouped(2).map { + case Seq(name, value) => (name, value) + }.toList.unzip - /** - * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this - * StructType. - */ - def flatten: Seq[NamedExpression] = valExprs.zip(names).map { - case (v, n) => Alias(v, n.toString)() - } + lazy val names = nameExprs.map(_.eval(EmptyRow)) - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + override def nullable: Boolean = false - private lazy val names = nameExprs.map(_.eval(EmptyRow)) + override def foldable: Boolean = valExprs.forall(_.foldable) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, Metadata.empty) + case (name, expr) => + val metadata = expr match { + case ne: NamedExpression => ne.metadata + case _ => Metadata.empty + } + StructField(name.toString, expr.dataType, expr.nullable, metadata) } StructType(fields) } - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") @@ -274,8 +246,8 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Only foldable StringType expressions are allowed to appear at odd position , got :" + - s" ${invalidNames.mkString(",")}") + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { @@ -284,9 +256,29 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } } + /** + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } + override def eval(input: InternalRow): Any = { InternalRow(valExprs.map(_.eval(input)): _*) } +} + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit +case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName @@ -316,44 +308,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "named_struct" } -/** - * Returns a Row containing the evaluation of all children expressions. 
This is a variant that - * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with - * this expression automatically at runtime. - */ -case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val resolved: Boolean = childrenResolved - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval = GenerateUnsafeProjection.createCode(ctx, children) - ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) - } - - override def prettyName: String = "struct_unsafe" -} - - /** * Creates a struct with the given field names and values. This is a variant that returns * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with @@ -361,31 +315,7 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { - - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) - - override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) - } - StructType(fields) - } - - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(valExprs.map(_.eval(input)): _*) - } - +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ac1577b3abb4d..4b151c81d8f8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -688,8 +688,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // inline table comes in two styles: // style 1: values (1), (2), (3) -- multiple columns are supported // style 2: values 1, 2, 3 -- only a single column is supported here - case CreateStruct(children) => children // style 1 - case child => Seq(child) // style 2 + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 590774c043040..817de48de2798 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import org.scalatest.ShouldMatchers + import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -class AnalysisSuite extends AnalysisTest { + +class AnalysisSuite extends AnalysisTest with ShouldMatchers { import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { @@ -218,9 +221,36 @@ class AnalysisSuite extends AnalysisTest { // CreateStruct is a special case that we should not trim Alias for it. plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) - checkAnalysis(plan, plan) - plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) - checkAnalysis(plan, plan) + expected = testRelation.select(CreateNamedStruct(Seq( + Literal(a.name), a, + Literal("a+1"), (a + 1))).as("col")) + checkAnalysis(plan, expected) + } + + test("Analysis may leave unnecassary aliases") { + val att1 = testRelation.output.head + var plan = testRelation.select( + CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"), + att1 + ) + val prevPlan = getAnalyzer(true).execute(plan) + plan = prevPlan.select(CreateArray(Seq( + CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"), + /** alias should be eliminated by [[CleanupAliases]] */ + "col".attr.as("col2") + )).as("arr")) + plan = getAnalyzer(true).execute(plan) + + val expectedPlan = prevPlan.select( + CreateArray(Seq( + CreateNamedStruct(Seq( + Literal(att1.name), att1, + Literal("a_plus_1"), (att1 + 1))), + 'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull + )).as("arr") + ) + + checkAnalysis(plan, expectedPlan) } test("SPARK-10534: resolve attribute references in order by clause") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 0c307b2b8576b..c21c6de32c0ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -243,7 +243,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val b = AttributeReference("b", IntegerType)() checkMetadata(CreateStruct(Seq(a, b))) checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) - checkMetadata(CreateStructUnsafe(Seq(a, b))) checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 249408e0fbce4..7a131b30eafd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -186,6 +186,9 @@ class Column(val expr: Expression) extends Logging { case a: AggregateExpression if 
a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) + // Wait until the struct is resolved. This will generate a nicer looking alias. + case struct: CreateNamedStructLike => UnresolvedAlias(struct) + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index f873f34a845ef..6141fab4aff0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -137,7 +137,7 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateStruct = { + private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -168,7 +168,7 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) case StringType => getStruct(stringColumnStat(attr, relativeSD)) diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index a91f04e098b18..af6c930d64b76 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -87,7 +87,7 @@ struct -- !query 9 SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1 -- !query 9 schema -struct> +struct> -- !query 9 output diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6eb571b91ffab..90000445dffb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -190,6 +190,12 @@ private[hive] class TestHiveSparkSession( new File(Thread.currentThread().getContextClassLoader.getResource(path).getFile) } + private def quoteHiveFile(path : String) = if (Utils.isWindows) { + getHiveFile(path).getPath.replace('\\', '/') + } else { + getHiveFile(path).getPath + } + def getWarehousePath(): String = { val tempConf = new SQLConf sc.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } @@ -225,16 +231,16 @@ private[hive] class TestHiveSparkSession( val hiveQTestUtilTables: Seq[TestTable] = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), TestTable("src1", "CREATE TABLE src1 (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () 
=> { sql( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { sql( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -244,7 +250,7 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { sql( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -269,7 +275,7 @@ private[hive] class TestHiveSparkSession( sql( s""" - |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' |INTO TABLE src_thrift """.stripMargin) }), @@ -308,7 +314,7 @@ private[hive] class TestHiveSparkSession( |) """.stripMargin.cmd, s""" - |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/episodes.avro")}' |INTO TABLE episodes """.stripMargin.cmd ), @@ -379,7 +385,7 @@ private[hive] class TestHiveSparkSession( TestTable("src_json", s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql index de0116a4dcbaf..cdda29af50e37 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -7,4 +7,4 @@ having b.key in (select a.key where a.value > 'val_9' and a.value = min(b.value)) order by b.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (named_struct('gen_attr_0', `gen_attr_0`, 'gen_attr_4', `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM 
(SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index c7f10e569fa4d..12d18dc87ceb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst import java.nio.charset.StandardCharsets import java.nio.file.{Files, NoSuchFileException, Paths} +import scala.io.Source import scala.util.control.NonFatal import org.apache.spark.sql.Column @@ -109,12 +110,15 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { Files.write(path, answerText.getBytes(StandardCharsets.UTF_8)) } else { val goldenFileName = s"sqlgen/$answerFile.sql" - val resourceFile = getClass.getClassLoader.getResource(goldenFileName) - if (resourceFile == null) { + val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName) + if (resourceStream == null) { throw new NoSuchFileException(goldenFileName) } - val path = resourceFile.getPath - val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8) + val answerText = try { + Source.fromInputStream(resourceStream).mkString + } finally { + resourceStream.close + } val sqls = answerText.split(separator) assert(sqls.length == 2, "Golden sql files should have a separator.") val expectedSQL = sqls(1).trim() From 9be069125f7e94df9d862f307b87965baf9416e3 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 2 Nov 2016 11:29:26 -0700 Subject: [PATCH 009/534] [SPARK-17683][SQL] Support ArrayType in Literal.apply ## What changes were proposed in this pull request? This pr is to add pattern-matching entries for array data in `Literal.apply`. ## How was this patch tested? Added tests in `LiteralExpressionSuite`. Author: Takeshi YAMAMURO Closes #15257 from maropu/SPARK-17683. 
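For illustration only (not part of this patch): a minimal sketch of what `Literal.apply` accepts after the change. The expected `dataType`s follow from the new code path, which wraps the inferred element type in `ArrayType(elementType)` (so `containsNull = true`).

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

// Arrays of primitives and of supported object types are wrapped directly
// instead of failing with "Unsupported literal type".
val intArrayLit = Literal(Array(1, 2, 3))
val strArrayLit = Literal(Array("a", "b", "c"))

assert(intArrayLit.dataType == ArrayType(IntegerType, containsNull = true))
assert(strArrayLit.dataType == ArrayType(StringType, containsNull = true))
```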
(cherry picked from commit 4af0ce2d96de3397c9bc05684cad290a52486577) Signed-off-by: Reynold Xin --- .../sql/catalyst/expressions/literals.scala | 57 ++++++++++++++++++- .../expressions/LiteralExpressionSuite.scala | 27 ++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index a597a17aadd99..1985e68c94e2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,14 +17,25 @@ package org.apache.spark.sql.catalyst.expressions +import java.lang.{Boolean => JavaBoolean} +import java.lang.{Byte => JavaByte} +import java.lang.{Double => JavaDouble} +import java.lang.{Float => JavaFloat} +import java.lang.{Integer => JavaInteger} +import java.lang.{Long => JavaLong} +import java.lang.{Short => JavaShort} +import java.math.{BigDecimal => JavaBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.util import java.util.Objects import javax.xml.bind.DatatypeConverter +import scala.math.{BigDecimal, BigInt} + import org.json4s.JsonAST._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -46,12 +57,17 @@ object Literal { case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) - case d: java.math.BigDecimal => + case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) + case a: Array[_] => + val elementType = componentTypeToDataType(a.getClass.getComponentType()) + val dataType = ArrayType(elementType) + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + Literal(convert(a), dataType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) case v: Literal => v @@ -59,6 +75,45 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Returns the Spark SQL DataType for a given class object. Since this type needs to be resolved + * in runtime, we use match-case idioms for class objects here. However, there are similar + * functions in other files (e.g., HiveInspectors), so these functions need to merged into one. 
+ */ + private[this] def componentTypeToDataType(clz: Class[_]): DataType = clz match { + // primitive types + case JavaShort.TYPE => ShortType + case JavaInteger.TYPE => IntegerType + case JavaLong.TYPE => LongType + case JavaDouble.TYPE => DoubleType + case JavaByte.TYPE => ByteType + case JavaFloat.TYPE => FloatType + case JavaBoolean.TYPE => BooleanType + + // java classes + case _ if clz == classOf[Date] => DateType + case _ if clz == classOf[Timestamp] => TimestampType + case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[Array[Byte]] => BinaryType + case _ if clz == classOf[JavaShort] => ShortType + case _ if clz == classOf[JavaInteger] => IntegerType + case _ if clz == classOf[JavaLong] => LongType + case _ if clz == classOf[JavaDouble] => DoubleType + case _ if clz == classOf[JavaByte] => ByteType + case _ if clz == classOf[JavaFloat] => FloatType + case _ if clz == classOf[JavaBoolean] => BooleanType + + // other scala classes + case _ if clz == classOf[String] => StringType + case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + + case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType)) + + case _ => throw new AnalysisException(s"Unsupported component type $clz in arrays") + } + /** * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object * into code generation. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 450222d8cbba3..4af4da8a9f0c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -43,6 +44,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, TimestampType), null) checkEvaluation(Literal.create(null, CalendarIntervalType), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) + checkEvaluation(Literal.create(null, ArrayType(StringType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) } @@ -122,5 +124,28 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } } - // TODO(davies): add tests for ArrayType, MapType and StructType + test("array") { + def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = { + val toCatalyst = (a: Array[_], elementType: DataType) => { + CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a) + } + checkEvaluation(Literal(a), toCatalyst(a, elementType)) + } + checkArrayLiteral(Array(1, 2, 3), IntegerType) + checkArrayLiteral(Array("a", "b", "c"), StringType) + checkArrayLiteral(Array(1.0, 4.0), DoubleType) + checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, 
CalendarInterval.MICROS_PER_HOUR), + CalendarIntervalType) + } + + test("unsupported types (map and struct) in literals") { + def checkUnsupportedTypeInLiteral(v: Any): Unit = { + val errMsgMap = intercept[RuntimeException] { + Literal(v) + } + assert(errMsgMap.getMessage.startsWith("Unsupported literal type")) + } + checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) + checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) + } } From a885d5bbce9dba66b394850b3aac51ae97cb18dd Mon Sep 17 00:00:00 2001 From: buzhihuojie Date: Wed, 2 Nov 2016 11:36:20 -0700 Subject: [PATCH 010/534] [SPARK-17895] Improve doc for rangeBetween and rowsBetween ## What changes were proposed in this pull request? Copied description for row and range based frame boundary from https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala#L56 Added examples to show different behavior of rangeBetween and rowsBetween when involving duplicate values. Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. Author: buzhihuojie Closes #15727 from david-weiluo-ren/improveDocForRangeAndRowsBetween. (cherry picked from commit 742e0fea5391857964e90d396641ecf95cac4248) Signed-off-by: Reynold Xin --- .../apache/spark/sql/expressions/Window.scala | 55 +++++++++++++++++++ .../spark/sql/expressions/WindowSpec.scala | 55 +++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 0b26d863cac5d..327bc379d4132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -121,6 +121,32 @@ object Window { * and [[Window.currentRow]] to specify special boundary values, rather than using integral * values directly. * + * A row based boundary is based on the position of the row within the partition. + * An offset indicates the number of rows above or below the current row, the frame for the + * current row starts or ends. For instance, given a row based sliding frame with a lower bound + * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from + * index 4 to index 6. + * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 2| + * | 1| a| 3| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the @@ -144,6 +170,35 @@ object Window { * and [[Window.currentRow]] to specify special boundary values, rather than using integral * values directly. * + * A range based boundary is based on the actual value of the ORDER BY + * expression(s). An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. 
This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical data type. An exception can be made when the offset is 0, + * because no value modification is needed, in this case multiple and non-numeric ORDER BY + * expression are allowed. + * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 4| + * | 1| a| 4| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 1e85b6e7881ad..4a8ce695bd4da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -89,6 +89,32 @@ class WindowSpec private[sql]( * and [[Window.currentRow]] to specify special boundary values, rather than using integral * values directly. * + * A row based boundary is based on the position of the row within the partition. + * An offset indicates the number of rows above or below the current row, the frame for the + * current row starts or ends. For instance, given a row based sliding frame with a lower bound + * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from + * index 4 to index 6. + * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 2| + * | 1| a| 3| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the @@ -111,6 +137,35 @@ class WindowSpec private[sql]( * and [[Window.currentRow]] to specify special boundary values, rather than using integral * values directly. * + * A range based boundary is based on the actual value of the ORDER BY + * expression(s). An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical data type. An exception can be made when the offset is 0, + * because no value modification is needed, in this case multiple and non-numeric ORDER BY + * expression are allowed. 
+ * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 4| + * | 1| a| 4| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the From 0093257ea94d3a197ca061b54c04685d7c1f616a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 2 Nov 2016 11:41:49 -0700 Subject: [PATCH 011/534] [SPARK-14393][SQL] values generated by non-deterministic functions shouldn't change after coalesce or union ## What changes were proposed in this pull request? When a user appended a column using a "nondeterministic" function to a DataFrame, e.g., `rand`, `randn`, and `monotonically_increasing_id`, the expected semantic is the following: - The value in each row should remain unchanged, as if we materialize the column immediately, regardless of later DataFrame operations. However, since we use `TaskContext.getPartitionId` to get the partition index from the current thread, the values from nondeterministic columns might change if we call `union` or `coalesce` after. `TaskContext.getPartitionId` returns the partition index of the current Spark task, which might not be the corresponding partition index of the DataFrame where we defined the column. See the unit tests below or JIRA for examples. This PR uses the partition index from `RDD.mapPartitionWithIndex` instead of `TaskContext` and fixes the partition initialization logic in whole-stage codegen, normal codegen, and codegen fallback. `initializeStatesForPartition(partitionIndex: Int)` was added to `Projection`, `Nondeterministic`, and `Predicate` (codegen) and initialized right after object creation in `mapPartitionWithIndex`. `newPredicate` now returns a `Predicate` instance rather than a function for proper initialization. ## How was this patch tested? Unit tests. (Actually I'm not very confident that this PR fixed all issues without introducing new ones ...) cc: rxin davies Author: Xiangrui Meng Closes #15567 from mengxr/SPARK-14393. 
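For illustration only (not part of this patch), a minimal sketch of the expected semantics described above, assuming a running `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.functions.rand

// A nondeterministic column, conceptually materialized once per row...
val df = spark.range(0, 10, 1, 2).withColumn("r", rand(42L))
val before = df.collect().map(_.getDouble(1)).toSet

// ...should keep the same per-row values when the plan is later coalesced
// (or unioned), because the projection is now initialized with the partition
// index of the RDD that defines the column rather than the task's partition id.
val afterCoalesce = df.coalesce(1).collect().map(_.getDouble(1)).toSet
assert(before == afterCoalesce)
```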
(cherry picked from commit 02f203107b8eda1f1576e36c4f12b0e3bc5e910e) Signed-off-by: Reynold Xin --- .../main/scala/org/apache/spark/rdd/RDD.scala | 16 +++++- .../sql/catalyst/expressions/Expression.scala | 19 +++++-- .../catalyst/expressions/InputFileName.scala | 2 +- .../MonotonicallyIncreasingID.scala | 11 ++-- .../sql/catalyst/expressions/Projection.scala | 22 +++++--- .../expressions/SparkPartitionID.scala | 13 +++-- .../expressions/codegen/CodeGenerator.scala | 14 +++++ .../expressions/codegen/CodegenFallback.scala | 18 +++++-- .../codegen/GenerateMutableProjection.scala | 4 ++ .../codegen/GeneratePredicate.scala | 18 +++++-- .../codegen/GenerateSafeProjection.scala | 4 ++ .../codegen/GenerateUnsafeProjection.scala | 4 ++ .../sql/catalyst/expressions/package.scala | 10 +++- .../sql/catalyst/expressions/predicates.scala | 4 -- .../expressions/randomExpressions.scala | 14 ++--- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../expressions/ExpressionEvalHelper.scala | 5 +- .../CodegenExpressionCachingSuite.scala | 13 +++-- .../sql/execution/DataSourceScanExec.scala | 6 ++- .../spark/sql/execution/ExistingRDD.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 3 +- .../spark/sql/execution/SparkPlan.scala | 4 +- .../sql/execution/WholeStageCodegenExec.scala | 8 ++- .../execution/basicPhysicalOperators.scala | 8 +-- .../columnar/InMemoryTableScanExec.scala | 5 +- .../joins/BroadcastNestedLoopJoinExec.scala | 7 +-- .../joins/CartesianProductExec.scala | 8 +-- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 6 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 52 +++++++++++++++++++ .../hive/execution/HiveTableScanExec.scala | 3 +- 32 files changed, 231 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index db535de9e9bb3..e018af35cb18d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. */ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. 
+ */ private[spark] def mapPartitionsInternal[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9edc1ceff26a7..726a231fd814e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -272,17 +272,28 @@ trait Nondeterministic extends Expression { final override def deterministic: Boolean = false final override def foldable: Boolean = false + @transient private[this] var initialized = false - final def setInitialValues(): Unit = { - initInternal() + /** + * Initializes internal states given the current partition index and mark this as initialized. + * Subclasses should override [[initializeInternal()]]. + */ + final def initialize(partitionIndex: Int): Unit = { + initializeInternal(partitionIndex) initialized = true } - protected def initInternal(): Unit + protected def initializeInternal(partitionIndex: Int): Unit + /** + * @inheritdoc + * Throws an exception if [[initialize()]] is not called yet. + * Subclasses should override [[evalInternal()]]. + */ final override def eval(input: InternalRow = null): Any = { - require(initialized, "nondeterministic expression should be initialized before evaluate") + require(initialized, + s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.") evalInternal(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 96929ecf56375..b6c12c5351119 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def prettyName: String = "input_file_name" - override protected def initInternal(): Unit = {} + override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { InputFileNameHolder.getInputFileName() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 5b4922e0cf2b7..72b8dcca26e2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis @transient private[this] var partitionMask: Long = _ - override protected def initInternal(): Unit = { + override protected def initializeInternal(partitionIndex: Int): Unit = { count = 0L - partitionMask = TaskContext.getPartitionId().toLong << 33 + partitionMask = partitionIndex.toLong << 33 } override def nullable: Boolean = false @@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val 
countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, - s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") + ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") + ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 03e054d098511..476e37e6a9bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * * @param expressions a sequence of expressions that determine the value of each column of the * output row. */ @@ -30,10 +31,12 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) - expressions.foreach(_.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - }) + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { /** * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified * expressions. + * * @param expressions a sequence of expressions that determine the value of each column of the * output row. 
*/ @@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu private[this] val buffer = new Array[Any](expressions.size) - expressions.foreach(_.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - }) + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } private[this] val exprArray = expressions.toArray private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 1f675d5b07270..6bef473cac060 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** - * Expression that returns the current partition id of the Spark task. + * Expression that returns the current partition id. */ @ExpressionDescription( - usage = "_FUNC_() - Returns the current partition id of the Spark task", + usage = "_FUNC_() - Returns the current partition id", extended = "> SELECT _FUNC_();\n 0") case class SparkPartitionID() extends LeafExpression with Nondeterministic { @@ -38,16 +37,16 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override val prettyName = "SPARK_PARTITION_ID" - override protected def initInternal(): Unit = { - partitionId = TaskContext.getPartitionId() + override protected def initializeInternal(partitionIndex: Int): Unit = { + partitionId = partitionIndex } override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm, - s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") + ctx.addMutableState(ctx.JAVA_INT, idTerm, "") + ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6cab50ae1bf8d..9c3c6d3b2a7f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -184,6 +184,20 @@ class CodegenContext { splitExpressions(initCodes, "init", Nil) } + /** + * Code statements to initialize states that depend on the partition index. + * An integer `partitionIndex` will be made available within the scope. 
+ */ + val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty + + def addPartitionInitializationStatement(statement: String): Unit = { + partitionInitializationStatements += statement + } + + def initPartition(): String = { + partitionInitializationStatements.mkString("\n") + } + /** * Holding all the functions those will be added into generated class. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6a5a3e7933eea..0322d1dd6a9ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -25,15 +25,23 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No trait CodegenFallback extends Expression { protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - } - // LeafNode does not need `input` val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW val idx = ctx.references.length ctx.references += this + var childIndex = idx + this.foreach { + case n: Nondeterministic => + // This might add the current expression twice, but it won't hurt. + ctx.references += n + childIndex += 1 + ctx.addPartitionInitializationStatement( + s""" + |((Nondeterministic) references[$childIndex]) + | .initialize(partitionIndex); + """.stripMargin) + case _ => + } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) if (nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 5c4b56b0b224c..4d732445544a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public ${classOf[BaseMutableProjection].getName} target(InternalRow row) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 39aa7b17de6c9..dcd1ed96a298e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -25,19 +25,26 @@ import org.apache.spark.sql.catalyst.expressions._ */ abstract class Predicate { def eval(r: InternalRow): Boolean + + /** + * Initializes internal states given the current partition index. + * This is used by nondeterministic expressions to set initial states. + * The default implementation does nothing. 
+ */ + def initialize(partitionIndex: Int): Unit = {} } /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]]. */ -object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] { +object GeneratePredicate extends CodeGenerator[Expression, Predicate] { protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) - protected def create(predicate: Expression): ((InternalRow) => Boolean) = { + protected def create(predicate: Expression): Predicate = { val ctx = newCodeGenContext() val eval = predicate.genCode(ctx) @@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public boolean eval(InternalRow ${ctx.INPUT_ROW}) { @@ -67,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] - (r: InternalRow) => p.eval(r) + CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 2773e1a666212..b1cb6edefb852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -173,6 +173,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public java.lang.Object apply(java.lang.Object _i) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7cc45372daa5a..7e4c9089a2cb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -380,6 +380,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} // Scala.Function1 need this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1510a4796683c..1b00c9e79da22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -64,7 +64,15 @@ package object expressions { * column of the new row. 
If the schema of the input row is specified, then the given expression * will be bound to that schema. */ - abstract class Projection extends (InternalRow => InternalRow) + abstract class Projection extends (InternalRow => InternalRow) { + + /** + * Initializes internal states given the current partition index. + * This is used by nondeterministic expressions to set initial states. + * The default implementation does nothing. + */ + def initialize(partitionIndex: Int): Unit = {} + } /** * Converts a [[InternalRow]] to another Row given a sequence of expression that define each diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 9394e39aadd9d..c941a576d00d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -31,10 +31,6 @@ object InterpretedPredicate { create(BindReferences.bindReference(expression, inputSchema)) def create(expression: Expression): (InternalRow => Boolean) = { - expression.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index ca200768b2286..e09029f5aab9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -42,8 +42,8 @@ abstract class RDG extends LeafExpression with Nondeterministic { */ @transient protected var rng: XORShiftRandom = _ - override protected def initInternal(): Unit = { - rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + override protected def initializeInternal(partitionIndex: Int): Unit = { + rng = new XORShiftRandom(seed + partitionIndex) } override def nullable: Boolean = false @@ -70,8 +70,9 @@ case class Rand(seed: Long) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") + ctx.addMutableState(className, rngTerm, "") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } @@ -93,8 +94,9 @@ case class Randn(seed: Long) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") + ctx.addMutableState(className, rngTerm, "") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e5e2cd7d27d15..b6ad5db74e3c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1060,6 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Project(projectList, LocalRelation(output, data)) if !projectList.exists(hasUnevaluableExpr) => val projection = new InterpretedProjection(projectList, output) + projection.initialize(0) LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index f0c149c02b9aa..9ceb709185417 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -75,7 +75,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { - case n: Nondeterministic => n.setInitialValues() + case n: Nondeterministic => n.initialize(0) case _ => } expression.eval(inputRow) @@ -121,6 +121,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { @@ -182,12 +183,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { var plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) actual = FromUnsafeProjection(expression.dataType :: Nil)( plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 06dc3bd33b90e..fe5cb8eda824f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -31,19 +31,22 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { // Use an Add to wrap two of them together in case we only initialize the top level expressions. 
val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = UnsafeProjection.create(Seq(expr)) + instance.initialize(0) assert(instance.apply(null).getBoolean(0) === false) } test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr)) + instance.initialize(0) assert(instance.apply(null).getBoolean(0) === false) } test("GeneratePredicate should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GeneratePredicate.generate(expr) - assert(instance.apply(null) === false) + instance.initialize(0) + assert(instance.eval(null) === false) } test("GenerateUnsafeProjection should not share expression instances") { @@ -73,13 +76,13 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GeneratePredicate should not share expression instances") { val expr1 = MutableExpression() val instance1 = GeneratePredicate.generate(expr1) - assert(instance1.apply(null) === false) + assert(instance1.eval(null) === false) val expr2 = MutableExpression() expr2.mutableState = true val instance2 = GeneratePredicate.generate(expr2) - assert(instance1.apply(null) === false) - assert(instance2.apply(null) === true) + assert(instance1.eval(null) === false) + assert(instance2.eval(null) === true) } } @@ -89,7 +92,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { */ case class NondeterministicExpression() extends LeafExpression with Nondeterministic with CodegenFallback { - override protected def initInternal(): Unit = { } + override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): Any = false override def nullable: Boolean = false override def dataType: DataType = BooleanType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index fdd1fa3648251..e485b52b43f76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -71,8 +71,9 @@ case class RowDataSourceScanExec( val unsafeRow = if (outputUnsafeRows) { rdd } else { - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + proj.initialize(index) iter.map(proj) } } @@ -284,8 +285,9 @@ case class FileSourceScanExec( val unsafeRows = { val scan = inputRDD if (needsUnsafeRowConversion) { - scan.mapPartitionsInternal { iter => + scan.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + proj.initialize(index) iter.map(proj) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 455fb5bfbb6f7..aab087cd98716 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -190,8 +190,9 @@ case class RDDScanExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + 
proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 2663129562660..19fbf0c162048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -94,8 +94,9 @@ case class GenerateExec( } val numOutputRows = longMetric("numOutputRows") - rows.mapPartitionsInternal { iter => + rows.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(output, output) + proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 24d0cffef82a2..cadab37a449aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetric @@ -354,7 +354,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } protected def newPredicate( - expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = { GeneratePredicate.generate(expression, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6303483f22fd3..516b9d5444d31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -331,6 +331,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co partitionIndex = index; this.inputs = inputs; ${ctx.initMutableStates()} + ${ctx.initPartition()} } ${ctx.declareAddedFunctions()} @@ -383,10 +384,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } else { // Right now, we support up to two input RDDs. 
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) => - val partitionIndex = TaskContext.getPartitionId() + Iterator((leftIter, rightIter)) + // a small hack to obtain the correct partition index + }.mapPartitionsWithIndex { (index, zippedIter) => + val (leftIter, rightIter) = zippedIter.next() val clazz = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] - buffer.init(partitionIndex, Array(leftIter, rightIter)) + buffer.init(index, Array(leftIter, rightIter)) new Iterator[InternalRow] { override def hasNext: Boolean = { val v = buffer.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a5291e0c12f88..32133f52630cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -70,9 +70,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val project = UnsafeProjection.create(projectList, child.output, subexpressionEliminationEnabled) + project.initialize(index) iter.map(project) } } @@ -205,10 +206,11 @@ case class FilterExec(condition: Expression, child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val predicate = newPredicate(condition, child.output) + predicate.initialize(0) iter.filter { row => - val r = predicate(row) + val r = predicate.eval(row) if (r) numOutputRows += 1 r } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index b87016d5a5696..9028caa446e8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -132,10 +132,11 @@ case class InMemoryTableScanExec( val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers - buffers.mapPartitionsInternal { cachedBatchIterator => + buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) + partitionFilter.initialize(index) // Find the ordinals and data types of the requested columns. 
val (requestedColumnIndices, requestedColumnDataTypes) = @@ -147,7 +148,7 @@ case class InMemoryTableScanExec( val cachedBatchesToScan = if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => - if (!partitionFilter(cachedBatch.stats)) { + if (!partitionFilter.eval(cachedBatch.stats)) { def statsString: String = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index bfe7e3dea45df..f526a19876670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -52,7 +52,7 @@ case class BroadcastNestedLoopJoinExec( UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil } - private[this] def genResultProjection: InternalRow => InternalRow = joinType match { + private[this] def genResultProjection: UnsafeProjection = joinType match { case LeftExistence(j) => UnsafeProjection.create(output, output) case other => @@ -84,7 +84,7 @@ case class BroadcastNestedLoopJoinExec( @transient private lazy val boundCondition = { if (condition.isDefined) { - newPredicate(condition.get, streamed.output ++ broadcast.output) + newPredicate(condition.get, streamed.output ++ broadcast.output).eval _ } else { (r: InternalRow) => true } @@ -366,8 +366,9 @@ case class BroadcastNestedLoopJoinExec( } val numOutputRows = longMetric("numOutputRows") - resultRdd.mapPartitionsInternal { iter => + resultRdd.mapPartitionsWithIndexInternal { (index, iter) => val resultProj = genResultProjection + resultProj.initialize(index) iter.map { r => numOutputRows += 1 resultProj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 15dc9b40662e2..8341fe2ffd078 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -98,15 +98,15 @@ case class CartesianProductExec( val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) - pair.mapPartitionsInternal { iter => + pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { - val boundCondition: (InternalRow) => Boolean = - newPredicate(condition.get, left.output ++ right.output) + val boundCondition = newPredicate(condition.get, left.output ++ right.output) + boundCondition.initialize(index) val joined = new JoinedRow iter.filter { r => - boundCondition(joined(r._1, r._2)) + boundCondition.eval(joined(r._1, r._2)) } } else { iter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 05c5e2f4cd77b..1aef5f6864263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -81,7 +81,7 @@ trait HashJoin { UnsafeProjection.create(streamedKeys) @transient 
private[this] lazy val boundCondition = if (condition.isDefined) { - newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) + newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _ } else { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ecf7cf289f034..ca9c0ed8cec32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -101,7 +101,7 @@ case class SortMergeJoinExec( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => - newPredicate(cond, left.output ++ right.output) + newPredicate(cond, left.output ++ right.output).eval _ }.getOrElse { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 9df56bbf1ef87..fde3b2a528994 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -87,8 +87,9 @@ case class DeserializeToObjectExec( } override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output) + projection.initialize(index) iter.map(projection) } } @@ -124,8 +125,9 @@ case class SerializeFromObjectExec( } override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val projection = UnsafeProjection.create(serializer) + projection.initialize(index) iter.map(projection) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 586a0fffeb7a1..0e9a2c6cf7dec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,13 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import scala.util.Random + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -406,4 +412,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(true), Row(true)) ) } + + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { + import DataFrameFunctionsSuite.CodegenFallbackExpr + for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { + val c = if (codegenFallback) { + Column(CodegenFallbackExpr(v.expr)) + } else { + v + } + withSQLConf( + (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString), + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { + val df = spark.range(0, 4, 1, 4).withColumn("c", c) + val 
rows = df.collect() + val rowsAfterCoalesce = df.coalesce(2).collect() + assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + + val df1 = spark.range(0, 2, 1, 2).withColumn("c", c) + val rows1 = df1.collect() + val df2 = spark.range(2, 4, 1, 2).withColumn("c", c) + val rows2 = df2.collect() + val rowsAfterUnion = df1.union(df2).collect() + assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + } + } + } + + test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + + "coalesce or union") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) + } +} + +object DataFrameFunctionsSuite { + case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback { + override def children: Seq[Expression] = Seq(child) + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + override lazy val resolved = true + override def eval(input: InternalRow): Any = child.eval(input) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 231f204b12b47..c80695bd3e0fe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -154,8 +154,9 @@ case class HiveTableScanExec( val numOutputRows = longMetric("numOutputRows") // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649) val outputSchema = schema - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(outputSchema) + proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) From bd3ea6595788a4fe5399e6c6c666618d8cb6872c Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 2 Nov 2016 11:47:45 -0700 Subject: [PATCH 012/534] [SPARK-18160][CORE][YARN] spark.files & spark.jars should not be passed to driver in yarn mode ## What changes were proposed in this pull request? spark.files is still passed to driver in yarn mode, so SparkContext will still handle it which cause the error in the jira desc. ## How was this patch tested? Tested manually in a 5 node cluster. As this issue only happens in multiple node cluster, so I didn't write test for it. Author: Jeff Zhang Closes #15669 from zjffdu/SPARK-18160. 
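A rough sketch of the failure mode and of the essence of the fix that the diff below applies in `yarn/Client.scala`; the conf construction and file path here are illustrative only, not taken from the patch:

```
// In yarn cluster mode, SparkSubmit ships --files/--jars through the YARN
// distributed cache, but the same entries also end up in spark.files /
// spark.jars. Before this change the driver-side SparkContext would then try
// to handle those client-local paths itself, which is what triggers the error
// described in the JIRA.
import org.apache.spark.SparkConf

val sparkConf = new SparkConf()
  .set("spark.files", "/path/only/on/the/client/config.txt") // illustrative path

// The fix: drop the keys before the AM-side SparkConf is built, since the
// YARN cache already distributes these files (see Client.scala below).
sparkConf.remove("spark.jars")
sparkConf.remove("spark.files")
```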
(cherry picked from commit 3c24299b71e23e159edbb972347b13430f92a465) Signed-off-by: Marcelo Vanzin --- .../scala/org/apache/spark/SparkContext.scala | 29 ++++--------------- .../org/apache/spark/deploy/yarn/Client.scala | 5 +++- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4694790c72cd8..63478c88b057b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1716,29 +1716,12 @@ class SparkContext(config: SparkConf) extends Logging { key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (master == "yarn" && deployMode == "cluster") { - // In order for this to work in yarn cluster mode the user must specify the - // --addJars option to the client to upload the file into the distributed cache - // of the AM to make it show up in the current working directory. - val fileName = new Path(uri.getPath).getName() - try { - env.rpcEnv.fileServer.addJar(new File(fileName)) - } catch { - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. - logError("Error adding jar (" + e + "), was the --addJars option used?") - null - } - } else { - try { - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") - null - } + try { + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null } // A JAR file which exists locally on every worker node case "local" => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 55e4a833b6707..053a78617d4e0 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1202,7 +1202,10 @@ private object Client extends Logging { // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf - + // SparkSubmit would use yarn cache to distribute files & jars in yarn mode, + // so remove them from sparkConf here for yarn mode. + sparkConf.remove("spark.jars") + sparkConf.remove("spark.files") val args = new ClientArguments(argStrings) new Client(args, sparkConf).run() } From 1eef8e5cd09dfb8b77044ef9864321618e8ea8c8 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 2 Nov 2016 11:52:29 -0700 Subject: [PATCH 013/534] [SPARK-17058][BUILD] Add maven snapshots-and-staging profile to build/test against staging artifacts ## What changes were proposed in this pull request? Adds a `snapshots-and-staging profile` so that RCs of projects like Hadoop and HBase can be used in developer-only build and test runs. There's a comment above the profile telling people not to use this in production. There's no attempt to do the same for SBT, as Ivy is different. ## How was this patch tested? Tested by building against the Hadoop 2.7.3 RC 1 JARs without the profile (and without any local copy of the 2.7.3 artifacts), the build failed ``` mvn install -DskipTests -Pyarn,hadoop-2.7,hive -Dhadoop.version=2.7.3 ... 
[INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Launcher 2.1.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ Downloading: https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-client/2.7.3/hadoop-client-2.7.3.pom [WARNING] The POM for org.apache.hadoop:hadoop-client:jar:2.7.3 is missing, no dependency information available Downloading: https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-client/2.7.3/hadoop-client-2.7.3.jar [INFO] ------------------------------------------------------------------------ [INFO] Reactor Summary: [INFO] [INFO] Spark Project Parent POM ........................... SUCCESS [ 4.482 s] [INFO] Spark Project Tags ................................. SUCCESS [ 17.402 s] [INFO] Spark Project Sketch ............................... SUCCESS [ 11.252 s] [INFO] Spark Project Networking ........................... SUCCESS [ 13.458 s] [INFO] Spark Project Shuffle Streaming Service ............ SUCCESS [ 9.043 s] [INFO] Spark Project Unsafe ............................... SUCCESS [ 16.027 s] [INFO] Spark Project Launcher ............................. FAILURE [ 1.653 s] [INFO] Spark Project Core ................................. SKIPPED ... ``` With the profile, the build completed ``` mvn install -DskipTests -Pyarn,hadoop-2.7,hive,snapshots-and-staging -Dhadoop.version=2.7.3 ``` Author: Steve Loughran Closes #14646 from steveloughran/stevel/SPARK-17058-support-asf-snapshots. (cherry picked from commit 37d95227a21de602b939dae84943ba007f434513) Signed-off-by: Reynold Xin --- pom.xml | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/pom.xml b/pom.xml index aaf7cfa7eb2ad..04d2eaa1d3bac 100644 --- a/pom.xml +++ b/pom.xml @@ -2693,6 +2693,54 @@ + + + snapshots-and-staging + + + https://repository.apache.org/content/groups/staging/ + https://repository.apache.org/content/repositories/snapshots/ + + + + + ASF Staging + ${asf.staging} + + + ASF Snapshots + ${asf.snapshots} + + true + + + false + + + + + + + ASF Staging + ${asf.staging} + + + ASF Snapshots + ${asf.snapshots} + + true + + + false + + + + + + + org.json + json + From be3933ddfa3b6b6cf458c0fc4865a61fef40e76a Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Thu, 10 Nov 2016 13:41:13 -0800 Subject: [PATCH 084/534] [SPARK-17993][SQL] Fix Parquet log output redirection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (Link to Jira issue: https://issues.apache.org/jira/browse/SPARK-17993) ## What changes were proposed in this pull request? PR #14690 broke parquet log output redirection for converted partitioned Hive tables. 
For example, when querying parquet files written by Parquet-mr 1.6.0 Spark prints a torrent of (harmless) warning messages from the Parquet reader: ``` Oct 18, 2016 7:42:18 PM WARNING: org.apache.parquet.CorruptStatistics: Ignoring statistics because created_by could not be parsed (see PARQUET-251): parquet-mr version 1.6.0 org.apache.parquet.VersionParser$VersionParseException: Could not parse created_by: parquet-mr version 1.6.0 using format: (.+) version ((.*) )?\(build ?(.*)\) at org.apache.parquet.VersionParser.parse(VersionParser.java:112) at org.apache.parquet.CorruptStatistics.shouldIgnoreStatistics(CorruptStatistics.java:60) at org.apache.parquet.format.converter.ParquetMetadataConverter.fromParquetStatistics(ParquetMetadataConverter.java:263) at org.apache.parquet.hadoop.ParquetFileReader$Chunk.readAllPages(ParquetFileReader.java:583) at org.apache.parquet.hadoop.ParquetFileReader.readNextRowGroup(ParquetFileReader.java:513) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.checkEndOfRowGroup(VectorizedParquetRecordReader.java:270) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.nextBatch(VectorizedParquetRecordReader.java:225) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.nextKeyValue(VectorizedParquetRecordReader.java:137) at org.apache.spark.sql.execution.datasources.RecordReaderIterator.hasNext(RecordReaderIterator.scala:39) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:102) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:162) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:102) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.scan_nextBatch$(Unknown Source) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:372) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:231) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:225) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319) at org.apache.spark.rdd.RDD.iterator(RDD.scala:283) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:99) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` This only happens during execution, not planning, and it doesn't matter what log level the `SparkContext` is set to. That's because Parquet (versions < 1.9) doesn't use slf4j for logging. Note, you can tell that log redirection is not working here because the log message format does not conform to the default Spark log message format. 
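The redirection this patch restores amounts to pointing Parquet's java.util.logging loggers at the SLF4J bridge. A condensed sketch of that pattern follows; the complete logic lives in the new `ParquetLogRedirector` class in the diff below, which additionally keeps strong references to the loggers so they cannot be garbage collected:

```
import java.util.logging.{Logger => JLogger}
import org.slf4j.bridge.SLF4JBridgeHandler

def redirect(logger: JLogger): Unit = {
  // Remove JUL's default console handler so Parquet stops writing directly to
  // stdout/stderr, then route its records through SLF4J (and hence log4j).
  logger.getHandlers.foreach(logger.removeHandler)
  logger.setUseParentHandlers(false)
  logger.addHandler(new SLF4JBridgeHandler)
}

// parquet-mr 1.6 and earlier log under "parquet"; 1.7/1.8 log under
// "org.apache.parquet".
redirect(JLogger.getLogger("parquet"))
redirect(JLogger.getLogger("org.apache.parquet"))
```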
This is a regression I noted as something we needed to fix as a follow up. It appears that the problem arose because we removed the call to `inferSchema` during Hive table conversion. That call is what triggered the output redirection. ## How was this patch tested? I tested this manually in four ways: 1. Executing `spark.sqlContext.range(10).selectExpr("id as a").write.mode("overwrite").parquet("test")`. 2. Executing `spark.read.format("parquet").load(legacyParquetFile).show` for a Parquet file `legacyParquetFile` written using Parquet-mr 1.6.0. 3. Executing `select * from legacy_parquet_table limit 1` for some unpartitioned Parquet-based Hive table written using Parquet-mr 1.6.0. 4. Executing `select * from legacy_partitioned_parquet_table where partcol=x limit 1` for some partitioned Parquet-based Hive table written using Parquet-mr 1.6.0. I ran each test with a new instance of `spark-shell` or `spark-sql`. Incidentally, I found that test case 3 was not a regression—redirection was not occurring in the master codebase prior to #14690. I spent some time working on a unit test, but based on my experience working on this ticket I feel that automated testing here is far from feasible. cc ericl dongjoon-hyun Author: Michael Allman Closes #15538 from mallman/spark-17993-fix_parquet_log_redirection. (cherry picked from commit b533fa2b205544b42dcebe0a6fee9d8275f6da7d) Signed-off-by: Reynold Xin --- .../parquet/ParquetLogRedirector.java | 72 +++++++++++++++++++ .../parquet/ParquetFileFormat.scala | 58 ++++----------- sql/core/src/test/resources/log4j.properties | 4 +- sql/hive/src/test/resources/log4j.properties | 4 ++ 4 files changed, 90 insertions(+), 48 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java new file mode 100644 index 0000000000000..7a7f32ee1e87b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.Serializable; +import java.util.logging.Handler; +import java.util.logging.Logger; + +import org.apache.parquet.Log; +import org.slf4j.bridge.SLF4JBridgeHandler; + +// Redirects the JUL logging for parquet-mr versions <= 1.8 to SLF4J logging using +// SLF4JBridgeHandler. 
Parquet-mr versions >= 1.9 use SLF4J directly +final class ParquetLogRedirector implements Serializable { + // Client classes should hold a reference to INSTANCE to ensure redirection occurs. This is + // especially important for Serializable classes where fields are set but constructors are + // ignored + static final ParquetLogRedirector INSTANCE = new ParquetLogRedirector(); + + // JUL loggers must be held by a strong reference, otherwise they may get destroyed by GC. + // However, the root JUL logger used by Parquet isn't properly referenced. Here we keep + // references to loggers in both parquet-mr <= 1.6 and 1.7/1.8 + private static final Logger apacheParquetLogger = + Logger.getLogger(Log.class.getPackage().getName()); + private static final Logger parquetLogger = Logger.getLogger("parquet"); + + static { + // For parquet-mr 1.7 and 1.8, which are under `org.apache.parquet` namespace. + try { + Class.forName(Log.class.getName()); + redirect(Logger.getLogger(Log.class.getPackage().getName())); + } catch (ClassNotFoundException ex) { + throw new RuntimeException(ex); + } + + // For parquet-mr 1.6.0 and lower versions bundled with Hive, which are under `parquet` + // namespace. + try { + Class.forName("parquet.Log"); + redirect(Logger.getLogger("parquet")); + } catch (Throwable t) { + // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly + // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block + // should be removed after this issue is fixed. + } + } + + private ParquetLogRedirector() { + } + + private static void redirect(Logger logger) { + for (Handler handler : logger.getHandlers()) { + logger.removeHandler(handler); + } + logger.setUseParentHandlers(false); + logger.addHandler(new SLF4JBridgeHandler()); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index b8ea7f40c4ab3..031a0fe57893f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -29,14 +28,12 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType -import org.slf4j.bridge.SLF4JBridgeHandler import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -56,6 +53,11 @@ class ParquetFileFormat with DataSourceRegister with Logging with Serializable { + // Hold a reference to the (serializable) singleton instance of ParquetLogRedirector. This + // ensures the ParquetLogRedirector class is initialized whether an instance of ParquetFileFormat + // is constructed or deserialized. 
Do not heed the Scala compiler's warning about an unused field + // here. + private val parquetLogRedirector = ParquetLogRedirector.INSTANCE override def shortName(): String = "parquet" @@ -129,10 +131,14 @@ class ParquetFileFormat conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) } - ParquetFileFormat.redirectParquetLogs() - new OutputWriterFactory { - override def newInstance( + // This OutputWriterFactory instance is deserialized when writing Parquet files on the + // executor side without constructing or deserializing ParquetFileFormat. Therefore, we hold + // another reference to ParquetLogRedirector.INSTANCE here to ensure the latter class is + // initialized. + private val parquetLogRedirector = ParquetLogRedirector.INSTANCE + + override def newInstance( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { @@ -673,44 +679,4 @@ object ParquetFileFormat extends Logging { Failure(cause) }.toOption } - - // JUL loggers must be held by a strong reference, otherwise they may get destroyed by GC. - // However, the root JUL logger used by Parquet isn't properly referenced. Here we keep - // references to loggers in both parquet-mr <= 1.6 and >= 1.7 - val apacheParquetLogger: JLogger = JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName) - val parquetLogger: JLogger = JLogger.getLogger("parquet") - - // Parquet initializes its own JUL logger in a static block which always prints to stdout. Here - // we redirect the JUL logger via SLF4J JUL bridge handler. - val redirectParquetLogsViaSLF4J: Unit = { - def redirect(logger: JLogger): Unit = { - logger.getHandlers.foreach(logger.removeHandler) - logger.setUseParentHandlers(false) - logger.addHandler(new SLF4JBridgeHandler) - } - - // For parquet-mr 1.7.0 and above versions, which are under `org.apache.parquet` namespace. - // scalastyle:off classforname - Class.forName(classOf[ApacheParquetLog].getName) - // scalastyle:on classforname - redirect(JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName)) - - // For parquet-mr 1.6.0 and lower versions bundled with Hive, which are under `parquet` - // namespace. - try { - // scalastyle:off classforname - Class.forName("parquet.Log") - // scalastyle:on classforname - redirect(JLogger.getLogger("parquet")) - } catch { case _: Throwable => - // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly - // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block - // should be removed after this issue is fixed. - } - } - - /** - * ParquetFileFormat.prepareWrite calls this function to initialize `redirectParquetLogsViaSLF4J`. 
- */ - def redirectParquetLogs(): Unit = {} } diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 33b9ecf1e2826..25b817382195a 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -53,5 +53,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.org.apache.parquet.hadoop=WARN -log4j.logger.org.apache.spark.sql.parquet=INFO +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties index fea3404769d9d..072bb25d30a87 100644 --- a/sql/hive/src/test/resources/log4j.properties +++ b/sql/hive/src/test/resources/log4j.properties @@ -59,3 +59,7 @@ log4j.logger.hive.ql.metadata.Hive=OFF log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR + +# Parquet related logging +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR From c602894f25bf9e61b759815674008471858cc71e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 10 Nov 2016 13:42:48 -0800 Subject: [PATCH 085/534] [SPARK-17990][SPARK-18302][SQL] correct several partition related behaviours of ExternalCatalog ## What changes were proposed in this pull request? This PR corrects several partition related behaviors of `ExternalCatalog`: 1. default partition location should not always lower case the partition column names in path string(fix `HiveExternalCatalog`) 2. rename partition should not always lower case the partition column names in updated partition path string(fix `HiveExternalCatalog`) 3. rename partition should update the partition location only for managed table(fix `InMemoryCatalog`) 4. create partition with existing directory should be fine(fix `InMemoryCatalog`) 5. create partition with non-existing directory should create that directory(fix `InMemoryCatalog`) 6. drop partition from external table should not delete the directory(fix `InMemoryCatalog`) ## How was this patch tested? new tests in `ExternalCatalogSuite` Author: Wenchen Fan Closes #15797 from cloud-fan/partition. 
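To make the default partition location behaviour (item 1 above) concrete, here is a hedged usage sketch of the `ExternalCatalogUtils` helper introduced in the diff below; the table location and partition values are made up, but the expected layout matches the new `ExternalCatalogSuite` assertions:

```
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils

val tablePath = new Path("/warehouse/db1.db/tbl")        // hypothetical table location
val spec = Map("partCol1" -> "1", "partCol2" -> "2")     // mixed-case column names

// Column names keep their case and values are escaped; a null value becomes
// the Hive default partition name.
val partPath = ExternalCatalogUtils.generatePartitionPath(
  spec, Seq("partCol1", "partCol2"), tablePath)
// partPath: /warehouse/db1.db/tbl/partCol1=1/partCol2=2

ExternalCatalogUtils.escapePathName("2016/11/10")        // "2016%2F11%2F10"
ExternalCatalogUtils.DEFAULT_PARTITION_NAME              // "__HIVE_DEFAULT_PARTITION__"
```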
(cherry picked from commit 2f7461f31331cfc37f6cfa3586b7bbefb3af5547) Signed-off-by: Reynold Xin --- .../catalog/ExternalCatalogUtils.scala | 121 ++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 92 +++++------ .../sql/catalyst/catalog/interface.scala | 11 ++ .../catalog/ExternalCatalogSuite.scala | 150 ++++++++++++++---- .../catalog/SessionCatalogSuite.scala | 24 ++- .../spark/sql/execution/command/ddl.scala | 8 +- .../spark/sql/execution/command/tables.scala | 3 +- .../datasources/CatalogFileIndex.scala | 2 +- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileFormatWriter.scala | 6 +- .../PartitioningAwareFileIndex.scala | 2 - .../datasources/PartitioningUtils.scala | 94 +---------- .../sql/execution/command/DDLSuite.scala | 8 +- .../ParquetPartitionDiscoverySuite.scala | 21 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 51 +++++- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 4 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 19 files changed, 397 insertions(+), 208 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala new file mode 100644 index 0000000000000..b1442eec164d8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.Shell + +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec + +object ExternalCatalogUtils { + // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't + // depend on Hive. + val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" + + ////////////////////////////////////////////////////////////////////////////////////////////////// + // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). + ////////////////////////////////////////////////////////////////////////////////////////////////// + + val charToEscape = { + val bitSet = new java.util.BitSet(128) + + /** + * ASCII 01-1F are HTTP control characters that need to be escaped. + * \u000A and \u000D are \n and \r, respectively. 
+ */ + val clist = Array( + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^') + + clist.foreach(bitSet.set(_)) + + if (Shell.WINDOWS) { + Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + } + + bitSet + } + + def needsEscaping(c: Char): Boolean = { + c >= 0 && c < charToEscape.size() && charToEscape.get(c) + } + + def escapePathName(path: String): String = { + val builder = new StringBuilder() + path.foreach { c => + if (needsEscaping(c)) { + builder.append('%') + builder.append(f"${c.asInstanceOf[Int]}%02X") + } else { + builder.append(c) + } + } + + builder.toString() + } + + + def unescapePathName(path: String): String = { + val sb = new StringBuilder + var i = 0 + + while (i < path.length) { + val c = path.charAt(i) + if (c == '%' && i + 2 < path.length) { + val code: Int = try { + Integer.parseInt(path.substring(i + 1, i + 3), 16) + } catch { + case _: Exception => -1 + } + if (code >= 0) { + sb.append(code.asInstanceOf[Char]) + i += 3 + } else { + sb.append(c) + i += 1 + } + } else { + sb.append(c) + i += 1 + } + } + + sb.toString() + } + + def generatePartitionPath( + spec: TablePartitionSpec, + partitionColumnNames: Seq[String], + tablePath: Path): Path = { + val partitionPathStrings = partitionColumnNames.map { col => + val partitionValue = spec(col) + val partitionString = if (partitionValue == null) { + DEFAULT_PARTITION_NAME + } else { + escapePathName(partitionValue) + } + escapePathName(col) + "=" + partitionString + } + partitionPathStrings.foldLeft(tablePath) { (totalPath, nextPartPath) => + new Path(totalPath, nextPartPath) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 20db81e6f9060..a3ffeaa63f690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -231,7 +231,7 @@ class InMemoryCatalog( assert(tableMeta.storage.locationUri.isDefined, "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") - val dir = new Path(tableMeta.storage.locationUri.get) + val dir = new Path(tableMeta.location) try { val fs = dir.getFileSystem(hadoopConfig) fs.delete(dir, true) @@ -259,7 +259,7 @@ class InMemoryCatalog( assert(oldDesc.table.storage.locationUri.isDefined, "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") - val oldDir = new Path(oldDesc.table.storage.locationUri.get) + val oldDir = new Path(oldDesc.table.location) val newDir = new Path(catalog(db).db.locationUri, newName) try { val fs = oldDir.getFileSystem(hadoopConfig) @@ -355,25 +355,28 @@ class InMemoryCatalog( } } - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + val tablePath = new Path(tableMeta.location) // TODO: we should follow hive to roll back if one partition path 
failed to create. parts.foreach { p => - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. - if (p.storage.locationUri.isEmpty) { - val partitionPath = partitionColumnNames.flatMap { col => - p.spec.get(col).map(col + "=" + _) - }.mkString("/") - try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.mkdirs(new Path(tableDir, partitionPath)) - } catch { - case e: IOException => - throw new SparkException(s"Unable to create partition path $partitionPath", e) + val partitionPath = p.storage.locationUri.map(new Path(_)).getOrElse { + ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath) + } + + try { + val fs = tablePath.getFileSystem(hadoopConfig) + if (!fs.exists(partitionPath)) { + fs.mkdirs(partitionPath) } + } catch { + case e: IOException => + throw new SparkException(s"Unable to create partition path $partitionPath", e) } - existingParts.put(p.spec, p) + + existingParts.put( + p.spec, + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString)))) } } @@ -392,19 +395,15 @@ class InMemoryCatalog( } } - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames - // TODO: we should follow hive to roll back if one partition path failed to delete. + val shouldRemovePartitionLocation = getTable(db, table).tableType == CatalogTableType.MANAGED + // TODO: we should follow hive to roll back if one partition path failed to delete, and support + // partial partition spec. partSpecs.foreach { p => - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. - if (existingParts.contains(p) && existingParts(p).storage.locationUri.isEmpty) { - val partitionPath = partitionColumnNames.flatMap { col => - p.get(col).map(col + "=" + _) - }.mkString("/") + if (existingParts.contains(p) && shouldRemovePartitionLocation) { + val partitionPath = new Path(existingParts(p).location) try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.delete(new Path(tableDir, partitionPath), true) + val fs = partitionPath.getFileSystem(hadoopConfig) + fs.delete(partitionPath, true) } catch { case e: IOException => throw new SparkException(s"Unable to delete partition path $partitionPath", e) @@ -423,33 +422,34 @@ class InMemoryCatalog( requirePartitionsExist(db, table, specs) requirePartitionsNotExist(db, table, newSpecs) - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + val tablePath = new Path(tableMeta.location) + val shouldUpdatePartitionLocation = getTable(db, table).tableType == CatalogTableType.MANAGED + val existingParts = catalog(db).tables(table).partitions // TODO: we should follow hive to roll back if one partition path failed to rename. specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => - val newPart = getPartition(db, table, oldSpec).copy(spec = newSpec) - val existingParts = catalog(db).tables(table).partitions - - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. 
- if (newPart.storage.locationUri.isEmpty) { - val oldPath = partitionColumnNames.flatMap { col => - oldSpec.get(col).map(col + "=" + _) - }.mkString("/") - val newPath = partitionColumnNames.flatMap { col => - newSpec.get(col).map(col + "=" + _) - }.mkString("/") + val oldPartition = getPartition(db, table, oldSpec) + val newPartition = if (shouldUpdatePartitionLocation) { + val oldPartPath = new Path(oldPartition.location) + val newPartPath = ExternalCatalogUtils.generatePartitionPath( + newSpec, partitionColumnNames, tablePath) try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.rename(new Path(tableDir, oldPath), new Path(tableDir, newPath)) + val fs = tablePath.getFileSystem(hadoopConfig) + fs.rename(oldPartPath, newPartPath) } catch { case e: IOException => - throw new SparkException(s"Unable to rename partition path $oldPath", e) + throw new SparkException(s"Unable to rename partition path $oldPartPath", e) } + oldPartition.copy( + spec = newSpec, + storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toString))) + } else { + oldPartition.copy(spec = newSpec) } existingParts.remove(oldSpec) - existingParts.put(newSpec, newPart) + existingParts.put(newSpec, newPartition) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 34748a04859ad..93c70de18ae7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -99,6 +99,12 @@ case class CatalogTablePartition( output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") } + /** Return the partition location, assuming it is specified. */ + def location: String = storage.locationUri.getOrElse { + val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") + throw new AnalysisException(s"Partition [$specString] did not specify locationUri") + } + /** * Given the partition schema, returns a row with that schema holding the partition values. */ @@ -171,6 +177,11 @@ case class CatalogTable( throw new AnalysisException(s"table $identifier did not specify database") } + /** Return the table location, assuming it is specified. */ + def location: String = storage.locationUri.getOrElse { + throw new AnalysisException(s"table $identifier did not specify locationUri") + } + /** Return the fully qualified name of this table, assuming the database was specified. 
*/ def qualifiedName: String = identifier.unquotedString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 34bdfc8a98710..303a8662d3f4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.catalog -import java.io.File -import java.net.URI - +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite @@ -320,6 +319,33 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = true) } + test("create partitions without location") { + val catalog = newBasicCatalog() + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat(None, None, None, None, false, Map.empty), + schema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string"), + provider = Some("hive"), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val partition = CatalogTablePartition(Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + catalog.createPartitions("db1", "tbl", Seq(partition), ignoreIfExists = false) + + val partitionLocation = catalog.getPartition( + "db1", + "tbl", + Map("partCol1" -> "1", "partCol2" -> "2")).location + val tableLocation = catalog.getTable("db1", "tbl").location + val defaultPartitionLocation = new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2") + assert(new Path(partitionLocation) == defaultPartitionLocation) + } + test("list partitions with partial partition spec") { val catalog = newBasicCatalog() val parts = catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "1"))) @@ -399,6 +425,46 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac intercept[AnalysisException] { catalog.getPartition("db2", "tbl2", part2.spec) } } + test("rename partitions should update the location for managed table") { + val catalog = newBasicCatalog() + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat(None, None, None, None, false, Map.empty), + schema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string"), + provider = Some("hive"), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val tableLocation = catalog.getTable("db1", "tbl").location + + val mixedCasePart1 = CatalogTablePartition( + Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + val mixedCasePart2 = CatalogTablePartition( + Map("partCol1" -> "3", "partCol2" -> "4"), storageFormat) + + catalog.createPartitions("db1", "tbl", Seq(mixedCasePart1), ignoreIfExists = false) + assert( + new Path(catalog.getPartition("db1", "tbl", mixedCasePart1.spec).location) == + new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2")) + + catalog.renamePartitions("db1", "tbl", Seq(mixedCasePart1.spec), Seq(mixedCasePart2.spec)) + 
assert( + new Path(catalog.getPartition("db1", "tbl", mixedCasePart2.spec).location) == + new Path(new Path(tableLocation, "partCol1=3"), "partCol2=4")) + + // For external tables, RENAME PARTITION should not update the partition location. + val existingPartLoc = catalog.getPartition("db2", "tbl2", part1.spec).location + catalog.renamePartitions("db2", "tbl2", Seq(part1.spec), Seq(part3.spec)) + assert( + new Path(catalog.getPartition("db2", "tbl2", part3.spec).location) == + new Path(existingPartLoc)) + } + test("rename partitions when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { @@ -419,11 +485,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("alter partitions") { val catalog = newBasicCatalog() try { - // Note: Before altering table partitions in Hive, you *must* set the current database - // to the one that contains the table of interest. Otherwise you will end up with the - // most helpful error message ever: "Unable to alter partition. alter is not possible." - // See HIVE-2742 for more detail. - catalog.setCurrentDatabase("db2") val newLocation = newUriForDatabase() val newSerde = "com.sparkbricks.text.EasySerde" val newSerdeProps = Map("spark" -> "bricks", "compressed" -> "false") @@ -571,10 +632,11 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac // -------------------------------------------------------------------------- private def exists(uri: String, children: String*): Boolean = { - val base = new File(new URI(uri)) - children.foldLeft(base) { - case (parent, child) => new File(parent, child) - }.exists() + val base = new Path(uri) + val finalPath = children.foldLeft(base) { + case (parent, child) => new Path(parent, child) + } + base.getFileSystem(new Configuration()).exists(finalPath) } test("create/drop database should create/delete the directory") { @@ -623,7 +685,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("create/drop/rename partitions should create/delete/rename the directory") { val catalog = newBasicCatalog() - val databaseDir = catalog.getDatabase("db1").locationUri val table = CatalogTable( identifier = TableIdentifier("tbl", Some("db1")), tableType = CatalogTableType.MANAGED, @@ -631,34 +692,61 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac schema = new StructType() .add("col1", "int") .add("col2", "string") - .add("a", "int") - .add("b", "string"), + .add("partCol1", "int") + .add("partCol2", "string"), provider = Some("hive"), - partitionColumnNames = Seq("a", "b") - ) + partitionColumnNames = Seq("partCol1", "partCol2")) catalog.createTable(table, ignoreIfExists = false) + val tableLocation = catalog.getTable("db1", "tbl").location + + val part1 = CatalogTablePartition(Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + val part2 = CatalogTablePartition(Map("partCol1" -> "3", "partCol2" -> "4"), storageFormat) + val part3 = CatalogTablePartition(Map("partCol1" -> "5", "partCol2" -> "6"), storageFormat) + catalog.createPartitions("db1", "tbl", Seq(part1, part2), ignoreIfExists = false) - assert(exists(databaseDir, "tbl", "a=1", "b=2")) - assert(exists(databaseDir, "tbl", "a=3", "b=4")) + assert(exists(tableLocation, "partCol1=1", "partCol2=2")) + assert(exists(tableLocation, "partCol1=3", "partCol2=4")) catalog.renamePartitions("db1", "tbl", Seq(part1.spec), Seq(part3.spec)) - assert(!exists(databaseDir, "tbl", "a=1", "b=2")) - 
assert(exists(databaseDir, "tbl", "a=5", "b=6")) + assert(!exists(tableLocation, "partCol1=1", "partCol2=2")) + assert(exists(tableLocation, "partCol1=5", "partCol2=6")) catalog.dropPartitions("db1", "tbl", Seq(part2.spec, part3.spec), ignoreIfNotExists = false, purge = false) - assert(!exists(databaseDir, "tbl", "a=3", "b=4")) - assert(!exists(databaseDir, "tbl", "a=5", "b=6")) + assert(!exists(tableLocation, "partCol1=3", "partCol2=4")) + assert(!exists(tableLocation, "partCol1=5", "partCol2=6")) - val externalPartition = CatalogTablePartition( - Map("a" -> "7", "b" -> "8"), + val tempPath = Utils.createTempDir() + // create partition with existing directory is OK. + val partWithExistingDir = CatalogTablePartition( + Map("partCol1" -> "7", "partCol2" -> "8"), CatalogStorageFormat( - Some(Utils.createTempDir().getAbsolutePath), - None, None, None, false, Map.empty) - ) - catalog.createPartitions("db1", "tbl", Seq(externalPartition), ignoreIfExists = false) - assert(!exists(databaseDir, "tbl", "a=7", "b=8")) + Some(tempPath.getAbsolutePath), + None, None, None, false, Map.empty)) + catalog.createPartitions("db1", "tbl", Seq(partWithExistingDir), ignoreIfExists = false) + + tempPath.delete() + // create partition with non-existing directory will create that directory. + val partWithNonExistingDir = CatalogTablePartition( + Map("partCol1" -> "9", "partCol2" -> "10"), + CatalogStorageFormat( + Some(tempPath.getAbsolutePath), + None, None, None, false, Map.empty)) + catalog.createPartitions("db1", "tbl", Seq(partWithNonExistingDir), ignoreIfExists = false) + assert(tempPath.exists()) + } + + test("drop partition from external table should not delete the directory") { + val catalog = newBasicCatalog() + catalog.createPartitions("db2", "tbl1", Seq(part1), ignoreIfExists = false) + + val partPath = new Path(catalog.getPartition("db2", "tbl1", part1.spec).location) + val fs = partPath.getFileSystem(new Configuration) + assert(fs.exists(partPath)) + + catalog.dropPartitions("db2", "tbl1", Seq(part1.spec), ignoreIfNotExists = false, purge = false) + assert(fs.exists(partPath)) } } @@ -731,7 +819,7 @@ abstract class CatalogTestUtils { CatalogTable( identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL, - storage = storageFormat, + storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().getAbsolutePath)), schema = new StructType() .add("col1", "int") .add("col2", "string") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 001d9c47785d2..52385de50db6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -527,13 +527,13 @@ class SessionCatalogSuite extends SparkFunSuite { sessionCatalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) sessionCatalog.createPartitions( TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog, "mydb", "tbl", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) // Create partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("mydb") sessionCatalog.createPartitions( TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) 
assert(catalogPartitionsEqual( - externalCatalog, "mydb", "tbl", Seq(part1, part2, partWithMixedOrder))) + externalCatalog.listPartitions("mydb", "tbl"), part1, part2, partWithMixedOrder)) } test("create partitions when database/table does not exist") { @@ -586,13 +586,13 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop partitions") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) sessionCatalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1.spec), ignoreIfNotExists = false, purge = false) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part2)) // Drop partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") sessionCatalog.dropPartitions( @@ -604,7 +604,7 @@ class SessionCatalogSuite extends SparkFunSuite { // Drop multiple partitions at once sessionCatalog.createPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) sessionCatalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), @@ -844,10 +844,11 @@ class SessionCatalogSuite extends SparkFunSuite { test("list partitions") { val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))).toSet == Set(part1, part2)) + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) // List partitions without explicitly specifying database catalog.setCurrentDatabase("db2") - assert(catalog.listPartitions(TableIdentifier("tbl2")).toSet == Set(part1, part2)) + assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) } test("list partitions when database/table does not exist") { @@ -860,6 +861,15 @@ class SessionCatalogSuite extends SparkFunSuite { } } + private def catalogPartitionsEqual( + actualParts: Seq[CatalogTablePartition], + expectedParts: CatalogTablePartition*): Boolean = { + // ExternalCatalog may set a default location for partitions, here we ignore the partition + // location when comparing them. 
+ actualParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet == + expectedParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet + } + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 8500ab460a1b6..84a63fdb9f36f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils} @@ -500,7 +500,7 @@ case class AlterTableRecoverPartitionsCommand( s"location provided: $tableIdentWithDB") } - val root = new Path(table.storage.locationUri.get) + val root = new Path(table.location) logInfo(s"Recover all the partitions in $root") val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) @@ -558,9 +558,9 @@ case class AlterTableRecoverPartitionsCommand( val name = st.getPath.getName if (st.isDirectory && name.contains("=")) { val ps = name.split("=", 2) - val columnName = PartitioningUtils.unescapePathName(ps(0)) + val columnName = ExternalCatalogUtils.unescapePathName(ps(0)) // TODO: Validate the value - val value = PartitioningUtils.unescapePathName(ps(1)) + val value = ExternalCatalogUtils.unescapePathName(ps(1)) if (resolver(columnName, partitionNames.head)) { scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), partitionNames.drop(1), threshold, resolver) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e49a1f5acd0c9..119e732d0202c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -710,7 +710,8 @@ case class ShowPartitionsCommand( private def getPartName(spec: TablePartitionSpec, partColNames: Seq[String]): String = { partColNames.map { name => - PartitioningUtils.escapePathName(name) + "=" + PartitioningUtils.escapePathName(spec(name)) + ExternalCatalogUtils.escapePathName(name) + "=" + + ExternalCatalogUtils.escapePathName(spec(name)) }.mkString(File.separator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 443a2ec033a98..4ad91dcceb432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -67,7 +67,7 @@ class 
CatalogFileIndex( val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => - val path = new Path(p.storage.locationUri.get) + val path = new Path(p.location) val fs = path.getFileSystem(hadoopConf) PartitionPath( p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d43a6ad098ed..739aeac877b99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -190,7 +190,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { val effectiveOutputPath = if (overwritingSinglePartition) { val partition = t.sparkSession.sessionState.catalog.getPartition( l.catalogTable.get.identifier, overwrite.specificPartition.get) - new Path(partition.storage.locationUri.get) + new Path(partition.location) } else { outputPath } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index e404dcd5452b9..0f8ed9e23fe3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning @@ -281,11 +281,11 @@ object FileFormatWriter extends Logging { private def partitionStringExpression: Seq[Expression] = { description.partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( - PartitioningUtils.escapePathName _, + ExternalCatalogUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) - val str = If(IsNull(c), Literal(PartitioningUtils.DEFAULT_PARTITION_NAME), escaped) + val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index a8a722dd3c620..3740caa22c37e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -128,7 +128,6 @@ abstract class PartitioningAwareFileIndex( case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( leafDirs, - 
PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = false, basePaths = basePaths) @@ -148,7 +147,6 @@ abstract class PartitioningAwareFileIndex( case _ => PartitioningUtils.parsePartitions( leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, basePaths = basePaths) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index b51b41869bf06..a28b04ca3fb5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -25,7 +25,6 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.fs.Path -import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow @@ -56,15 +55,15 @@ object PartitionSpec { } object PartitioningUtils { - // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't - // depend on Hive. - val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { require(columnNames.size == literals.size) } + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName + /** * Given a group of qualified paths, tries to parse them and returns a partition specification. * For example, given: @@ -90,12 +89,11 @@ object PartitioningUtils { */ private[datasources] def parsePartitions( paths: Seq[Path], - defaultPartitionName: String, typeInference: Boolean, basePaths: Set[Path]): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. val (partitionValues, optDiscoveredBasePaths) = paths.map { path => - parsePartition(path, defaultPartitionName, typeInference, basePaths) + parsePartition(path, typeInference, basePaths) }.unzip // We create pairs of (path -> path's partition value) here @@ -173,7 +171,6 @@ object PartitioningUtils { */ private[datasources] def parsePartition( path: Path, - defaultPartitionName: String, typeInference: Boolean, basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] @@ -196,7 +193,7 @@ object PartitioningUtils { // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = - parsePartitionColumn(currentPath.getName, defaultPartitionName, typeInference) + parsePartitionColumn(currentPath.getName, typeInference) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. 
@@ -228,7 +225,6 @@ object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - defaultPartitionName: String, typeInference: Boolean): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { @@ -240,7 +236,7 @@ object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName, typeInference) + val literal = inferPartitionColumnValue(rawColumnValue, typeInference) Some(columnName -> literal) } } @@ -355,7 +351,6 @@ object PartitioningUtils { */ private[datasources] def inferPartitionColumnValue( raw: String, - defaultPartitionName: String, typeInference: Boolean): Literal = { val decimalTry = Try { // `BigDecimal` conversion can fail when the `field` is not a form of number. @@ -380,14 +375,14 @@ object PartitioningUtils { .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw))))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) { + if (raw == DEFAULT_PARTITION_NAME) { Literal.create(null, NullType) } else { Literal.create(unescapePathName(raw), StringType) } } } else { - if (raw == defaultPartitionName) { + if (raw == DEFAULT_PARTITION_NAME) { Literal.create(null, NullType) } else { Literal.create(unescapePathName(raw), StringType) @@ -450,77 +445,4 @@ object PartitioningUtils { Literal.create(Cast(l, desiredType).eval(), desiredType) } } - - ////////////////////////////////////////////////////////////////////////////////////////////////// - // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). - ////////////////////////////////////////////////////////////////////////////////////////////////// - - val charToEscape = { - val bitSet = new java.util.BitSet(128) - - /** - * ASCII 01-1F are HTTP control characters that need to be escaped. - * \u000A and \u000D are \n and \r, respectively. 
- */ - val clist = Array( - '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', - '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', - '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', - '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', - '{', '[', ']', '^') - - clist.foreach(bitSet.set(_)) - - if (Shell.WINDOWS) { - Array(' ', '<', '>', '|').foreach(bitSet.set(_)) - } - - bitSet - } - - def needsEscaping(c: Char): Boolean = { - c >= 0 && c < charToEscape.size() && charToEscape.get(c) - } - - def escapePathName(path: String): String = { - val builder = new StringBuilder() - path.foreach { c => - if (needsEscaping(c)) { - builder.append('%') - builder.append(f"${c.asInstanceOf[Int]}%02X") - } else { - builder.append(c) - } - } - - builder.toString() - } - - def unescapePathName(path: String): String = { - val sb = new StringBuilder - var i = 0 - - while (i < path.length) { - val c = path.charAt(i) - if (c == '%' && i + 2 < path.length) { - val code: Int = try { - Integer.parseInt(path.substring(i + 1, i + 3), 16) - } catch { - case _: Exception => -1 - } - if (code >= 0) { - sb.append(code.asInstanceOf[Char]) - i += 3 - } else { - sb.append(c) - i += 1 - } - } else { - sb.append(c) - i += 1 - } - } - - sb.toString() - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index df3a3c34c39a0..363715c6d2249 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -875,7 +875,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) val part2 = Map("a" -> "2", "b" -> "6") - val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) @@ -1133,7 +1133,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) - assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty) // Verify that the location is set to the expected string def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { @@ -1296,9 +1296,9 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined) assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) - assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + 
assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined) // add partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 120a3a2ef33aa..22e35a1bc0b1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,6 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} @@ -48,11 +49,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha import PartitioningUtils._ import testImplicits._ - val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" + val defaultPartitionName = ExternalCatalogUtils.DEFAULT_PARTITION_NAME test("column type inference") { def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal) + assert(inferPartitionColumnValue(raw, true) === literal) } check("10", Literal.create(10, IntegerType)) @@ -76,7 +77,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), true, Set.empty[Path]) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -88,7 +89,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/"))) @@ -101,7 +101,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/something=true/table"))) @@ -114,7 +113,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/table=true"))) @@ -127,7 +125,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha exception = intercept[AssertionError] { parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/"))) } @@ -147,7 +144,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha exception = intercept[AssertionError] { parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/tmp/tables/"))) } @@ -156,13 +152,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - val actual = parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path])._1 + val actual = 
parsePartition(new Path(path), true, Set.empty[Path])._1 assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path]) + parsePartition(new Path(path), true, Set.empty[Path]) }.getMessage assert(message.contains(expected)) @@ -204,7 +200,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // when the basePaths is the same as the path to a leaf directory val partitionSpec1: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), - defaultPartitionName = defaultPartitionName, typeInference = true, basePaths = Set(new Path("file://path/a=10")))._1 @@ -213,7 +208,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // when the basePaths is the path to a base directory of leaf directories val partitionSpec2: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), - defaultPartitionName = defaultPartitionName, typeInference = true, basePaths = Set(new Path("file://path")))._1 @@ -231,7 +225,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val actualSpec = parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, rootPaths) assert(actualSpec === spec) @@ -314,7 +307,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { val actualSpec = - parsePartitions(paths.map(new Path(_)), defaultPartitionName, false, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), false, Set.empty[Path]) assert(actualSpec === spec) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b537061d0d221..42ce1a88a2b67 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import java.io.IOException import java.util import scala.util.control.NonFatal @@ -26,7 +27,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.thrift.TException -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier @@ -255,7 +256,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // compatible format, which means the data source is file-based and must have a `path`. 
require(tableDefinition.storage.locationUri.isDefined, "External file-based data source table must have a `path` entry in storage properties.") - Some(new Path(tableDefinition.storage.locationUri.get).toUri.toString) + Some(new Path(tableDefinition.location).toUri.toString) } else { None } @@ -789,7 +790,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = withClient { requireTableExists(db, table) - val lowerCasedParts = parts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + val tablePath = new Path(tableMeta.location) + val partsWithLocation = parts.map { p => + // Ideally we can leave the partition location empty and let Hive metastore to set it. + // However, Hive metastore is not case preserving and will generate wrong partition location + // with lower cased partition column names. Here we set the default partition location + // manually to avoid this problem. + val partitionPath = p.storage.locationUri.map(new Path(_)).getOrElse { + ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath) + } + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString))) + } + val lowerCasedParts = partsWithLocation.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) client.createPartitions(db, table, lowerCasedParts, ignoreIfExists) } @@ -810,6 +825,31 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat newSpecs: Seq[TablePartitionSpec]): Unit = withClient { client.renamePartitions( db, table, specs.map(lowerCasePartitionSpec), newSpecs.map(lowerCasePartitionSpec)) + + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + // Hive metastore is not case preserving and keeps partition columns with lower cased names. + // When Hive rename partition for managed tables, it will create the partition location with + // a default path generate by the new spec with lower cased partition column names. This is + // unexpected and we need to rename them manually and alter the partition location. + val hasUpperCasePartitionColumn = partitionColumnNames.exists(col => col.toLowerCase != col) + if (tableMeta.tableType == MANAGED && hasUpperCasePartitionColumn) { + val tablePath = new Path(tableMeta.location) + val newParts = newSpecs.map { spec => + val partition = client.getPartition(db, table, lowerCasePartitionSpec(spec)) + val wrongPath = new Path(partition.location) + val rightPath = ExternalCatalogUtils.generatePartitionPath( + spec, partitionColumnNames, tablePath) + try { + tablePath.getFileSystem(hadoopConf).rename(wrongPath, rightPath) + } catch { + case e: IOException => throw new SparkException( + s"Unable to rename partition path from $wrongPath to $rightPath", e) + } + partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toString))) + } + alterPartitions(db, table, newParts) + } } override def alterPartitions( @@ -817,6 +857,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, newParts: Seq[CatalogTablePartition]): Unit = withClient { val lowerCasedParts = newParts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. 
Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + client.setCurrentDatabase(db) client.alterPartitions(db, table, lowerCasedParts) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index d3873cf6c8231..fbd705172cae6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -445,7 +445,7 @@ object SetWarehouseLocationTest extends Logging { catalog.getTableMetadata(TableIdentifier("testLocation", Some("default"))) val expectedLocation = "file:" + expectedWarehouseLocation.toString + "/testlocation" - val actualLocation = tableMetadata.storage.locationUri.get + val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( s"Expected table location is $expectedLocation. But, it is actually $actualLocation") @@ -461,7 +461,7 @@ object SetWarehouseLocationTest extends Logging { catalog.getTableMetadata(TableIdentifier("testLocation", Some("testLocationDB"))) val expectedLocation = "file:" + expectedWarehouseLocation.toString + "/testlocationdb.db/testlocation" - val actualLocation = tableMetadata.storage.locationUri.get + val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( s"Expected table location is $expectedLocation. But, it is actually $actualLocation") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index cfc1d81d544eb..9f4401ae22560 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -29,7 +29,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle val expectedPath = spark.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName - assert(metastoreTable.storage.locationUri.get === expectedPath) + assert(metastoreTable.location === expectedPath) } private def getTableNames(dbName: Option[String] = None): Array[String] = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 0076a778683ca..6efae13ddf69d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -425,7 +425,7 @@ class HiveDDLSuite sql("CREATE TABLE tab1 (height INT, length INT) PARTITIONED BY (a INT, b INT)") val part1 = Map("a" -> "1", "b" -> "5") val part2 = Map("a" -> "2", "b" -> "6") - val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c21db3595fa19..e607af67f93e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala 
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -542,7 +542,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } userSpecifiedLocation match { case Some(location) => - assert(r.catalogTable.storage.locationUri.get === location) + assert(r.catalogTable.location === location) case None => // OK. } // Also make sure that the format and serde are as desired. From 064d4315f246450043a52882fcf59e95d79701e8 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 10 Nov 2016 17:00:43 -0800 Subject: [PATCH 086/534] [SPARK-18185] Fix all forms of INSERT / OVERWRITE TABLE for Datasource tables ## What changes were proposed in this pull request? As of current 2.1, INSERT OVERWRITE with dynamic partitions against a Datasource table will overwrite the entire table instead of only the partitions matching the static keys, as in Hive. It also doesn't respect custom partition locations. This PR adds support for all these operations to Datasource tables managed by the Hive metastore. It is implemented as follows - During planning time, the full set of partitions affected by an INSERT or OVERWRITE command is read from the Hive metastore. - The planner identifies any partitions with custom locations and includes this in the write task metadata. - FileFormatWriter tasks refer to this custom locations map when determining where to write for dynamic partition output. - When the write job finishes, the set of written partitions is compared against the initial set of matched partitions, and the Hive metastore is updated to reflect the newly added / removed partitions. It was necessary to introduce a method for staging files with absolute output paths to `FileCommitProtocol`. These files are not handled by the Hadoop output committer but are moved to their final locations when the job commits. The overwrite behavior of legacy Datasource tables is also changed: no longer will the entire table be overwritten if a partial partition spec is present. cc cloud-fan yhuai ## How was this patch tested? Unit tests, existing tests. Author: Eric Liang Author: Wenchen Fan Closes #15814 from ericl/sc-5027. 
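
For illustration only (not part of this patch): a minimal sketch of the overwrite scope described above, assuming a hypothetical Datasource table `sales` partitioned by `region` and `day`, a hypothetical source table `staging`, and an existing `SparkSession` named `spark`. The static prefix of the PARTITION spec bounds which existing partitions may be dropped; dynamic partition columns are resolved per output row.

```scala
// Sketch only -- table, column, and source names here are hypothetical, not from this patch.
// Assumes `spark` is a SparkSession with Hive support and a populated `staging` table.
spark.sql(
  "CREATE TABLE sales (value INT, region STRING, day STRING) USING parquet " +
    "PARTITIONED BY (region, day)")

// Static prefix (region='US') with a dynamic `day` column: only existing partitions under
// region=US that the metastore matches are eligible to be dropped; other regions are untouched.
spark.sql(
  """INSERT OVERWRITE TABLE sales PARTITION (region='US', day)
    |SELECT value, day FROM staging""".stripMargin)

// Empty partition spec: every existing partition of the table is eligible for overwrite.
spark.sql("INSERT OVERWRITE TABLE sales SELECT value, region, day FROM staging")
```

Partitions that were assigned custom locations (e.g. via ALTER TABLE ... PARTITION ... SET LOCATION) keep writing to those locations: the planner passes the custom-locations map to the write tasks, which stage such files through the new absolute-path hook on FileCommitProtocol.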
(cherry picked from commit a3356343cbf58b930326f45721fb4ecade6f8029) Signed-off-by: Reynold Xin --- .../internal/io/FileCommitProtocol.scala | 15 ++ .../io/HadoopMapReduceCommitProtocol.scala | 63 ++++++- .../sql/catalyst/parser/AstBuilder.scala | 12 +- .../plans/logical/basicLogicalOperators.scala | 10 +- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +- .../execution/datasources/DataSource.scala | 20 ++- .../datasources/DataSourceStrategy.scala | 94 +++++++--- .../datasources/FileFormatWriter.scala | 26 ++- .../InsertIntoHadoopFsRelationCommand.scala | 61 ++++++- .../datasources/PartitioningUtils.scala | 10 ++ .../execution/streaming/FileStreamSink.scala | 2 +- .../ManifestFileCommitProtocol.scala | 6 + .../PartitionProviderCompatibilitySuite.scala | 161 +++++++++++++++++- 13 files changed, 411 insertions(+), 73 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index fb8020585cf89..afd2250c93a8a 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -82,9 +82,24 @@ abstract class FileCommitProtocol { * * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest * are left to the commit protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. */ def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + /** + * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + /** * Commits a task after the writes succeed. Must be called on the executors when running tasks. */ diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 66ccb6d437708..c99b75e52325e 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -17,7 +17,9 @@ package org.apache.spark.internal.io -import java.util.Date +import java.util.{Date, UUID} + +import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -42,17 +44,26 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) /** OutputCommitter from Hadoop is not serializable so marking it transient. */ @transient private var committer: OutputCommitter = _ + /** + * Tracks files staged by this task for absolute output paths. These outputs are not managed by + * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. 
+ * + * The mapping is from the temp output path to the final desired output path of the file. + */ + @transient private var addedAbsPathFiles: mutable.Map[String, String] = null + + /** + * The staging directory for all files committed with absolute output paths. + */ + private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { context.getOutputFormatClass.newInstance().getOutputCommitter(context) } override def newTaskTempFile( taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet - // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, - // the file name is fine and won't overflow. - val split = taskContext.getTaskAttemptID.getTaskID.getId - val filename = f"part-$split%05d-$jobId$ext" + val filename = getFilename(taskContext, ext) val stagingDir: String = committer match { // For FileOutputCommitter it has its own staging path called "work path". @@ -67,6 +78,28 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) } } + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + val filename = getFilename(taskContext, ext) + val absOutputPath = new Path(absoluteDir, filename).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path( + absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + + private def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. 
+ val split = taskContext.getTaskAttemptID.getTaskID.getId + f"part-$split%05d-$jobId$ext" + } + override def setupJob(jobContext: JobContext): Unit = { // Setup IDs val jobId = SparkHadoopWriter.createJobID(new Date, 0) @@ -87,25 +120,41 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) + val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) + .foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) } override def setupTask(taskContext: TaskAttemptContext): Unit = { committer = setupCommitter(taskContext) committer.setupTask(taskContext) + addedAbsPathFiles = mutable.Map[String, String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - EmptyTaskCommitMessage + new TaskCommitMessage(addedAbsPathFiles.toMap) } override def abortTask(taskContext: TaskAttemptContext): Unit = { committer.abortTask(taskContext) + // best effort cleanup of other staged files + for ((src, _) <- addedAbsPathFiles) { + val tmp = new Path(src) + tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4b151c81d8f8b..2006844923cf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -172,24 +172,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val tableIdent = visitTableIdentifier(ctx.tableIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) - val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty) + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. 
Specified " + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) } val overwrite = ctx.OVERWRITE != null - val overwritePartition = - if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) { - Some(partitionKeys.map(t => (t._1, t._2.get))) - } else { - None - } + val staticPartitionKeys: Map[String, String] = + partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get)) InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, query, - OverwriteOptions(overwrite, overwritePartition), + OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty), ctx.EXISTS != null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 65ceab2ce27b1..574caf039d3d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -350,13 +350,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { * Options for writing new data into a table. * * @param enabled whether to overwrite existing data in the table. - * @param specificPartition only data in the specified partition will be overwritten. + * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions + * that match this partial partition spec. If empty, all partitions + * will be overwritten. */ case class OverwriteOptions( enabled: Boolean, - specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) { - if (specificPartition.isDefined) { - assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.") + staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) { + if (staticPartitionKeys.nonEmpty) { + assert(enabled, "Overwrite must be enabled when specifying specific partitions.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 7400f3430e99c..e5f1f7b3bd4cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -185,9 +185,9 @@ class PlanParserSuite extends PlanTest { OverwriteOptions( overwrite, if (overwrite && partition.nonEmpty) { - Some(partition.map(kv => (kv._1, kv._2.get))) + partition.map(kv => (kv._1, kv._2.get)) } else { - None + Map.empty }), ifNotExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 5d663949df6b5..65422f1495f03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -417,15 +417,17 @@ case class DataSource( // will be adjusted within InsertIntoHadoopFsRelation. val plan = InsertIntoHadoopFsRelationCommand( - outputPath, - columns, - bucketSpec, - format, - _ => Unit, // No existing table needs to be refreshed. 
- options, - data.logicalPlan, - mode, - catalogTable) + outputPath = outputPath, + staticPartitionKeys = Map.empty, + customPartitionLocations = Map.empty, + partitionColumns = columns, + bucketSpec = bucketSpec, + fileFormat = format, + refreshFunction = _ => Unit, // No existing table needs to be refreshed. + options = options, + query = data.logicalPlan, + mode = mode, + catalogTable = catalogTable) sparkSession.sessionState.executePlan(plan).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 739aeac877b99..4f19a2d00b0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,10 +24,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec} +import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -182,41 +182,53 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } - val overwritingSinglePartition = - overwrite.specificPartition.isDefined && + val partitionSchema = query.resolve( + t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + val partitionsTrackedByCatalog = t.sparkSession.sessionState.conf.manageFilesourcePartitions && + l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty && l.catalogTable.get.tracksPartitionsInCatalog - val effectiveOutputPath = if (overwritingSinglePartition) { - val partition = t.sparkSession.sessionState.catalog.getPartition( - l.catalogTable.get.identifier, overwrite.specificPartition.get) - new Path(partition.location) - } else { - outputPath - } - - val effectivePartitionSchema = if (overwritingSinglePartition) { - Nil - } else { - query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil 
+ var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + + // When partitions are tracked by the catalog, compute all custom partition locations that + // may be relevant to the insertion job. + if (partitionsTrackedByCatalog) { + val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions( + l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys)) + initialMatchingPartitions = matchingPartitions.map(_.spec) + customPartitionLocations = getCustomPartitionLocations( + t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions) } + // Callback for updating metastore partition metadata after the insertion job completes. + // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { - if (l.catalogTable.isDefined && updatedPartitions.nonEmpty && - l.catalogTable.get.partitionColumnNames.nonEmpty && - l.catalogTable.get.tracksPartitionsInCatalog) { - val metastoreUpdater = AlterTableAddPartitionCommand( - l.catalogTable.get.identifier, - updatedPartitions.map(p => (p, None)), - ifNotExists = true) - metastoreUpdater.run(t.sparkSession) + if (partitionsTrackedByCatalog) { + val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions + if (newPartitions.nonEmpty) { + AlterTableAddPartitionCommand( + l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), + ifNotExists = true).run(t.sparkSession) + } + if (overwrite.enabled) { + val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions + if (deletedPartitions.nonEmpty) { + AlterTableDropPartitionCommand( + l.catalogTable.get.identifier, deletedPartitions.toSeq, + ifExists = true, purge = true).run(t.sparkSession) + } + } } t.location.refresh() } val insertCmd = InsertIntoHadoopFsRelationCommand( - effectiveOutputPath, - effectivePartitionSchema, + outputPath, + if (overwrite.enabled) overwrite.staticPartitionKeys else Map.empty, + customPartitionLocations, + partitionSchema, t.bucketSpec, t.fileFormat, refreshPartitionsCallback, @@ -227,6 +239,34 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { insertCmd } + + /** + * Given a set of input partitions, returns those that have locations that differ from the + * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by + * the user. 
+ * + * @return a mapping from partition specs to their custom locations + */ + private def getCustomPartitionLocations( + spark: SparkSession, + table: CatalogTable, + basePath: Path, + partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = { + val hadoopConf = spark.sessionState.newHadoopConf + val fs = basePath.getFileSystem(hadoopConf) + val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory) + partitions.flatMap { p => + val defaultLocation = qualifiedBasePath.suffix( + "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString + val catalogLocation = new Path(p.location).makeQualified( + fs.getUri, fs.getWorkingDirectory).toString + if (catalogLocation != defaultLocation) { + Some(p.spec -> catalogLocation) + } else { + None + } + }.toMap + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 0f8ed9e23fe3b..edcce103d0963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -47,6 +47,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { + /** Describes how output files should be placed in the filesystem. */ + case class OutputSpec( + outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String]) + /** A shared job description for all the write tasks. */ private class WriteJobDescription( val uuid: String, // prevent collision between different (appending) write jobs @@ -56,7 +60,8 @@ object FileFormatWriter extends Logging { val partitionColumns: Seq[Attribute], val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], - val path: String) + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String]) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), @@ -83,7 +88,7 @@ object FileFormatWriter extends Logging { plan: LogicalPlan, fileFormat: FileFormat, committer: FileCommitProtocol, - outputPath: String, + outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], @@ -93,7 +98,7 @@ object FileFormatWriter extends Logging { val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, new Path(outputPath)) + FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) val partitionSet = AttributeSet(partitionColumns) val dataColumns = plan.output.filterNot(partitionSet.contains) @@ -111,7 +116,8 @@ object FileFormatWriter extends Logging { partitionColumns = partitionColumns, nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, - path = outputPath) + path = outputSpec.outputPath, + customPartitionLocations = outputSpec.customPartitionLocations) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and @@ -308,7 +314,17 @@ object FileFormatWriter extends Logging { } val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext) - val path = committer.newTaskTempFile(taskAttemptContext, 
partDir, ext) + val customPath = partDir match { + case Some(dir) => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + case _ => + None + } + val path = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } val newWriter = description.outputWriterFactory.newInstance( path = path, dataSchema = description.nonPartitionColumns.toStructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index a0a8cb5024c33..28975e1546e79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.io.IOException -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ @@ -32,19 +32,32 @@ import org.apache.spark.sql.execution.command.RunnableCommand /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. * Writing to dynamic partitions is also supported. + * + * @param staticPartitionKeys partial partitioning spec for write. This defines the scope of + * partition overwrites: when the spec is empty, all partitions are + * overwritten. When it covers a prefix of the partition keys, only + * partitions matching the prefix are overwritten. + * @param customPartitionLocations mapping of partition specs to their custom locations. The + * caller should guarantee that exactly those table partitions + * falling under the specified static partition keys are contained + * in this map, and that no other partitions are. 
*/ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, + staticPartitionKeys: TablePartitionSpec, + customPartitionLocations: Map[TablePartitionSpec, String], partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - refreshFunction: (Seq[TablePartitionSpec]) => Unit, + refreshFunction: Seq[TablePartitionSpec] => Unit, options: Map[String, String], @transient query: LogicalPlan, mode: SaveMode, catalogTable: Option[CatalogTable]) extends RunnableCommand { + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil override def run(sparkSession: SparkSession): Seq[Row] = { @@ -66,10 +79,7 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => - if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { - throw new IOException(s"Unable to clear output " + - s"directory $qualifiedOutputPath prior to writing to it") - } + deleteMatchingPartitions(fs, qualifiedOutputPath) true case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => true @@ -93,7 +103,8 @@ case class InsertIntoHadoopFsRelationCommand( plan = query, fileFormat = fileFormat, committer = committer, - outputPath = qualifiedOutputPath.toString, + outputSpec = FileFormatWriter.OutputSpec( + qualifiedOutputPath.toString, customPartitionLocations), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -105,4 +116,40 @@ case class InsertIntoHadoopFsRelationCommand( Seq.empty[Row] } + + /** + * Deletes all partition files that match the specified static prefix. Partitions with custom + * locations are also cleared based on the custom locations map given to this class. + */ + private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = { + val staticPartitionPrefix = if (staticPartitionKeys.nonEmpty) { + "/" + partitionColumns.flatMap { p => + staticPartitionKeys.get(p.name) match { + case Some(value) => + Some(escapePathName(p.name) + "=" + escapePathName(value)) + case None => + None + } + }.mkString("/") + } else { + "" + } + // first clear the path determined by the static partition keys (e.g. /table/foo=1) + val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix) + if (fs.exists(staticPrefixPath) && !fs.delete(staticPrefixPath, true /* recursively */)) { + throw new IOException(s"Unable to clear output " + + s"directory $staticPrefixPath prior to writing to it") + } + // now clear all custom partition locations (e.g. 
/custom/dir/where/foo=2/bar=4) + for ((spec, customLoc) <- customPartitionLocations) { + assert( + (staticPartitionKeys.toSet -- spec).isEmpty, + "Custom partition location did not match static partitioning keys") + val path = new Path(customLoc) + if (fs.exists(path) && !fs.delete(path, true)) { + throw new IOException(s"Unable to clear partition " + + s"directory $path prior to writing to it") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index a28b04ca3fb5a..bf9f318780ec2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -62,6 +62,7 @@ object PartitioningUtils { } import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName /** @@ -252,6 +253,15 @@ object PartitioningUtils { }.toMap } + /** + * This is the inverse of parsePathFragment(). + */ + def getPathFragment(spec: TablePartitionSpec, partitionSchema: StructType): String = { + partitionSchema.map { field => + escapePathName(field.name) + "=" + escapePathName(spec(field.name)) + }.mkString("/") + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index e849cafef4184..f1c5f9ab5067d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -80,7 +80,7 @@ class FileStreamSink( plan = data.logicalPlan, fileFormat = fileFormat, committer = committer, - outputPath = path, + outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala index 1fe13fa1623fc..92191c8b64b72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -96,6 +96,12 @@ class ManifestFileCommitProtocol(jobId: String, path: String) file } + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + throw new UnsupportedOperationException( + s"$this does not support adding files with an absolute path") + } + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { if (addedFiles.nonEmpty) { val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 
ac435bf6195b0..a1aa07456fd36 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils class PartitionProviderCompatibilitySuite extends QueryTest with TestHiveSingleton with SQLTestUtils { @@ -135,7 +136,7 @@ class PartitionProviderCompatibilitySuite } } - test("insert overwrite partition of legacy datasource table overwrites entire table") { + test("insert overwrite partition of legacy datasource table") { withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { withTable("test") { withTempDir { dir => @@ -144,9 +145,9 @@ class PartitionProviderCompatibilitySuite """insert overwrite table test |partition (partCol=1) |select * from range(100)""".stripMargin) - assert(spark.sql("select * from test").count() == 100) + assert(spark.sql("select * from test").count() == 104) - // Dynamic partitions case + // Overwriting entire table spark.sql("insert overwrite table test select id, id from range(10)".stripMargin) assert(spark.sql("select * from test").count() == 10) } @@ -186,4 +187,158 @@ class PartitionProviderCompatibilitySuite } } } + + /** + * Runs a test against a multi-level partitioned table, then validates that the custom locations + * were respected by the output writer. + * + * The initial partitioning structure is: + * /P1=0/P2=0 -- custom location a + * /P1=0/P2=1 -- custom location b + * /P1=1/P2=0 -- custom location c + * /P1=1/P2=1 -- default location + */ + private def testCustomLocations(testFn: => Unit): Unit = { + val base = Utils.createTempDir(namePrefix = "base") + val a = Utils.createTempDir(namePrefix = "a") + val b = Utils.createTempDir(namePrefix = "b") + val c = Utils.createTempDir(namePrefix = "c") + try { + spark.sql(s""" + |create table test (id long, P1 int, P2 int) + |using parquet + |options (path "${base.getAbsolutePath}") + |partitioned by (P1, P2)""".stripMargin) + spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=0, P2=1) location '${b.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=1, P2=0) location '${c.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=1, P2=1)") + + testFn + + // Now validate the partition custom locations were respected + val initialCount = spark.sql("select * from test").count() + val numA = spark.sql("select * from test where P1=0 and P2=0").count() + val numB = spark.sql("select * from test where P1=0 and P2=1").count() + val numC = spark.sql("select * from test where P1=1 and P2=0").count() + Utils.deleteRecursively(a) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=0 and P2=0").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA) + Utils.deleteRecursively(b) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=0 and P2=1").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA - numB) + Utils.deleteRecursively(c) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=1 and P2=0").count() == 0) + assert(spark.sql("select * from test").count() 
== initialCount - numA - numB - numC) + } finally { + Utils.deleteRecursively(base) + Utils.deleteRecursively(a) + Utils.deleteRecursively(b) + Utils.deleteRecursively(c) + spark.sql("drop table test") + } + } + + test("sanity check table setup") { + testCustomLocations { + assert(spark.sql("select * from test").count() == 0) + assert(spark.sql("show partitions test").count() == 4) + } + } + + test("insert into partial dynamic partitions") { + testCustomLocations { + spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1=1, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 30) + assert(spark.sql("show partitions test").count() == 20) + spark.sql("insert into test partition (P1=2, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 40) + assert(spark.sql("show partitions test").count() == 30) + } + } + + test("insert into fully dynamic partitions") { + testCustomLocations { + spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 12) + } + } + + test("insert into static partition") { + testCustomLocations { + spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert into test partition (P1=1, P2=1) select id from range(10)") + assert(spark.sql("select * from test").count() == 30) + assert(spark.sql("show partitions test").count() == 4) + } + } + + test("overwrite partial dynamic partitions") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 7) + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(1)") + assert(spark.sql("select * from test").count() == 1) + assert(spark.sql("show partitions test").count() == 3) + spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 11) + assert(spark.sql("show partitions test").count() == 11) + spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(1)") + assert(spark.sql("select * from test").count() == 2) + assert(spark.sql("show partitions test").count() == 2) + spark.sql("insert overwrite table test partition (P1=3, 
P2) select id, id from range(100)") + assert(spark.sql("select * from test").count() == 102) + assert(spark.sql("show partitions test").count() == 102) + } + } + + test("overwrite fully dynamic partitions") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 10) + spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 5) + } + } + + test("overwrite static partition") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=1, P2=1) select id from range(5)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=1, P2=2) select id from range(5)") + assert(spark.sql("select * from test").count() == 15) + assert(spark.sql("show partitions test").count() == 5) + } + } } From 51dca6143670ec1c1cb090047c3941becaf41fa9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 10 Nov 2016 17:13:10 -0800 Subject: [PATCH 087/534] [SPARK-18401][SPARKR][ML] SparkR random forest should support output original label. ## What changes were proposed in this pull request? SparkR ```spark.randomForest``` classification prediction should output original label rather than the indexed label. This issue is very similar with [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291). ## How was this patch tested? Add unit tests. Author: Yanbo Liang Closes #15842 from yanboliang/spark-18401. 
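For context, the change below maps the classifier's numeric predictions back to the original string labels by appending an `IndexToString` stage to the fitted pipeline. A minimal standalone sketch of that pattern follows; the column names are illustrative placeholders, not the wrapper's actual constants:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

// Hypothetical column names, for illustration only.
val labelIndexer = new StringIndexer()
  .setInputCol("species")        // original string label
  .setOutputCol("indexedLabel")  // numeric index the classifier trains on

val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setPredictionCol("predictedIndex")

// Map the predicted index back to the original label string, which is what
// the R user ultimately sees in the "prediction" column.
val labelConverter = new IndexToString()
  .setInputCol("predictedIndex")
  .setOutputCol("prediction")
  // .setLabels(...) would normally be populated from the fitted indexer's labels.

val pipeline = new Pipeline().setStages(Array(labelIndexer, rf, labelConverter))
```

In the wrapper itself, the labels are read from the `NominalAttribute` metadata that `RFormula` attaches to the indexed label column, as the diff below shows.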
(cherry picked from commit 5ddf69470b93c0b8a28bb4ac905e7670d9c50a95) Signed-off-by: Yanbo Liang --- R/pkg/inst/tests/testthat/test_mllib.R | 24 ++++++++++++++++ .../r/RandomForestClassificationWrapper.scala | 28 ++++++++++++++++--- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 1e456ef5c6b16..33e85b78de4fe 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numTrees, 20) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) @@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numClasses, stats2$numClasses) unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) }) test_that("spark.gbt", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 6947ba7e7597a..31f846dc6cfec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -35,6 +35,8 @@ private[r] class RandomForestClassifierWrapper private ( val formula: String, val features: Array[String]) extends MLWritable { + import RandomForestClassifierWrapper._ + private val rfcModel: RandomForestClassificationModel = pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] @@ -46,7 +48,9 @@ private[r] class RandomForestClassifierWrapper private ( def summary: String = rfcModel.toDebugString def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(rfcModel.getFeaturesCol) + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(rfcModel.getFeaturesCol) } override def 
write: MLWriter = new @@ -54,6 +58,10 @@ private[r] class RandomForestClassifierWrapper private ( } private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + def fit( // scalastyle:ignore data: DataFrame, formula: String, @@ -73,6 +81,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC val rFormula = new RFormula() .setFormula(formula) + .setForceIndexLabel(true) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) @@ -82,6 +91,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .attributes.get val features = featureAttrs.map(_.name.get) + // get label names from output schema + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + // assemble and fit the pipeline val rfc = new RandomForestClassifier() .setMaxDepth(maxDepth) @@ -97,10 +111,16 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .setCacheNodeIds(cacheNodeIds) .setProbabilityCol(probabilityCol) .setFeaturesCol(rFormula.getFeaturesCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + val pipeline = new Pipeline() - .setStages(Array(rFormulaModel, rfc)) + .setStages(Array(rFormulaModel, rfc, idxToStr)) .fit(data) new RandomForestClassifierWrapper(pipeline, formula, features) From 00c9c7d96489778dfe38a36675d3162bf8844880 Mon Sep 17 00:00:00 2001 From: Vinayak Date: Fri, 11 Nov 2016 12:54:16 -0600 Subject: [PATCH 088/534] [SPARK-17843][WEB UI] Indicate event logs pending for processing on history server UI ## What changes were proposed in this pull request? History Server UI's application listing to display information on currently under process event logs so a user knows that pending this processing an application may not list on the UI. When there are no event logs under process, the application list page has a "Last Updated" date-time at the top indicating the date-time of the last _completed_ scan of the event logs. The value is displayed to the user in his/her local time zone. ## How was this patch tested? All unit tests pass. Particularly all the suites under org.apache.spark.deploy.history.\* were run to test changes. - Very first startup - Pending logs - no logs processed yet: screen shot 2016-10-24 at 3 07 04 pm - Very first startup - Pending logs - some logs processed: screen shot 2016-10-24 at 3 18 42 pm - Last updated - No currently pending logs: screen shot 2016-10-17 at 8 34 37 pm - Last updated - With some currently pending logs: screen shot 2016-10-24 at 3 09 31 pm - No applications found and No currently pending logs: screen shot 2016-10-24 at 3 24 26 pm Author: Vinayak Closes #15410 from vijoshi/SAAS-608_master. 
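The provider changes below come down to two pieces of thread-safe bookkeeping that the UI polls: a counter of event logs still being replayed and the timestamp of the last completed scan. A minimal sketch of that idea, with illustrative names rather than the actual `FsHistoryProvider` members:

```scala
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

// Illustrative stand-in for the provider-side bookkeeping: a scanning thread
// updates these values while the UI thread reads them.
class ScanStatus {
  private val pendingTasks = new AtomicInteger(0)
  private val lastScanTime = new AtomicLong(-1L)

  def getEventLogsUnderProcess(): Int = pendingTasks.get()
  def getLastUpdatedTime(): Long = lastScanTime.get()

  def runScan(tasks: Seq[() => Unit]): Unit = {
    pendingTasks.addAndGet(tasks.size)
    tasks.foreach { task =>
      // Decrement even if a task fails, so the pending count cannot drift upward.
      try task() finally pendingTasks.decrementAndGet()
    }
    lastScanTime.set(System.currentTimeMillis())
  }
}
```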
(cherry picked from commit a531fe1a82ec515314f2db2e2305283fef24067f) Signed-off-by: Tom Graves --- .../spark/ui/static/historypage-common.js | 24 ++++++++ .../history/ApplicationHistoryProvider.scala | 24 ++++++++ .../deploy/history/FsHistoryProvider.scala | 59 +++++++++++++------ .../spark/deploy/history/HistoryPage.scala | 19 ++++++ .../spark/deploy/history/HistoryServer.scala | 8 +++ 5 files changed, 116 insertions(+), 18 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage-common.js diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js new file mode 100644 index 0000000000000..55d540d8317a0 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +$(document).ready(function() { + if ($('#last-updated').length) { + var lastUpdatedMillis = Number($('#last-updated').text()); + var updatedDate = new Date(lastUpdatedMillis); + $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString()) + } +}); diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 06530ff836466..d7d82800b8b55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -74,6 +74,30 @@ private[history] case class LoadedAppUI( private[history] abstract class ApplicationHistoryProvider { + /** + * Returns the count of application event logs that the provider is currently still processing. + * History Server UI can use this to indicate to a user that the application listing on the UI + * can be expected to list additional known applications once the processing of these + * application event logs completes. + * + * A History Provider that does not have a notion of count of event logs that may be pending + * for processing need not override this method. + * + * @return Count of application event logs that are currently under process + */ + def getEventLogsUnderProcess(): Int = { + return 0; + } + + /** + * Returns the time the history provider last updated the application history information + * + * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis + */ + def getLastUpdatedTime(): Long = { + return 0; + } + /** * Returns a list of applications available for the history server to show. 
* diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index dfc1aad64c818..ca38a47639422 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -108,7 +108,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) - private var lastScanTime = -1L + private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -120,6 +120,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -226,6 +228,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) applications.get(appId) } + override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() + + override def getLastUpdatedTime(): Long = lastScanTime.get() + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => @@ -329,26 +335,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") } - logInfos.map { file => - replayExecutor.submit(new Runnable { + + var tasks = mutable.ListBuffer[Future[_]]() + + try { + for (file <- logInfos) { + tasks += replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(file) }) } - .foreach { task => - try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. - task.get() - } catch { - case e: InterruptedException => - throw e - case e: Exception => - logError("Exception while merging application listings", e) - } + } catch { + // let the iteration over logInfos break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + } + + pendingReplayTasksCount.addAndGet(tasks.size) + + tasks.foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } finally { + pendingReplayTasksCount.decrementAndGet() } + } - lastScanTime = newLastScanTime + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } @@ -365,7 +388,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } catch { case e: Exception => logError("Exception encountered when attempting to update last scan time", e) - lastScanTime + lastScanTime.get() } finally { if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 96b9ecf43b14c..0e7a6c24d4fa5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -30,13 +30,30 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) + val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() + val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = +
      <script src={UIUtils.prependBaseUri("/static/historypage-common.js")}></script>
       <div>
           <div class="span12">
             <ul class="unstyled">
               {providerConfig.map { case (k, v) => <li><strong>{k}:</strong> {v}</li> }}
             </ul>
+            {
+            if (eventLogsUnderProcessCount > 0) {
+              <p>There are {eventLogsUnderProcessCount} event log(s) currently being
+                processed which may result in additional applications getting listed on this page.
+                Refresh the page to view updates.</p>
+            }
+            }
+
+            {
+            if (lastUpdatedTime > 0) {
+              <p>Last updated: <span id="last-updated">{lastUpdatedTime}</span></p>
+            }
+            }
+
             {
             if (allAppsSize > 0) {
               <script src={UIUtils.prependBaseUri("/static/dataTables.rowsGroup.js")}></script> ++
@@ -46,6 +63,8 @@
             } else if (requestedIncomplete) {
               <h4>No incomplete applications found!</h4>
+            } else if (eventLogsUnderProcessCount > 0) {
+              <h4>No completed applications found!</h4>
             } else {
               <h4>No completed applications found!</h4>
++ parent.emptyListingHtml } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 3175b36b3e56f..7e21fa681aa1e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -179,6 +179,14 @@ class HistoryServer( provider.getListing() } + def getEventLogsUnderProcess(): Int = { + provider.getEventLogsUnderProcess() + } + + def getLastUpdatedTime(): Long = { + provider.getLastUpdatedTime() + } + def getApplicationInfoList: Iterator[ApplicationInfo] = { getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } From 465e4b40b3b7760bfcd0f03a14b805029ed599f1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 11 Nov 2016 13:28:18 -0800 Subject: [PATCH 089/534] [SPARK-17982][SQL] SQLBuilder should wrap the generated SQL with parenthesis for LIMIT ## What changes were proposed in this pull request? Currently, `SQLBuilder` handles `LIMIT` by always adding `LIMIT` at the end of the generated subSQL. It makes `RuntimeException`s like the following. This PR adds a parenthesis always except `SubqueryAlias` is used together with `LIMIT`. **Before** ``` scala scala> sql("CREATE TABLE tbl(id INT)") scala> sql("CREATE VIEW v1(id2) AS SELECT id FROM tbl LIMIT 2") java.lang.RuntimeException: Failed to analyze the canonicalized SQL: ... ``` **After** ``` scala scala> sql("CREATE TABLE tbl(id INT)") scala> sql("CREATE VIEW v1(id2) AS SELECT id FROM tbl LIMIT 2") scala> sql("SELECT id2 FROM v1") res4: org.apache.spark.sql.DataFrame = [id2: int] ``` **Fixed cases in this PR** The following two cases are the detail query plans having problematic SQL generations. 1. `SELECT * FROM (SELECT id FROM tbl LIMIT 2)` Please note that **FROM SELECT** part of the generated SQL in the below. When we don't use '()' for limit, this fails. ```scala # Original logical plan: Project [id#1] +- GlobalLimit 2 +- LocalLimit 2 +- Project [id#1] +- MetastoreRelation default, tbl # Canonicalized logical plan: Project [gen_attr_0#1 AS id#4] +- SubqueryAlias tbl +- Project [gen_attr_0#1] +- GlobalLimit 2 +- LocalLimit 2 +- Project [gen_attr_0#1] +- SubqueryAlias gen_subquery_0 +- Project [id#1 AS gen_attr_0#1] +- SQLTable default, tbl, [id#1] # Generated SQL: SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2) AS tbl ``` 2. `SELECT * FROM (SELECT id FROM tbl TABLESAMPLE (2 ROWS))` Please note that **((~~~) AS gen_subquery_0 LIMIT 2)** in the below. When we use '()' for limit on `SubqueryAlias`, this fails. ```scala # Original logical plan: Project [id#1] +- Project [id#1] +- GlobalLimit 2 +- LocalLimit 2 +- MetastoreRelation default, tbl # Canonicalized logical plan: Project [gen_attr_0#1 AS id#4] +- SubqueryAlias tbl +- Project [gen_attr_0#1] +- GlobalLimit 2 +- LocalLimit 2 +- SubqueryAlias gen_subquery_0 +- Project [id#1 AS gen_attr_0#1] +- SQLTable default, tbl, [id#1] # Generated SQL: SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM ((SELECT `id` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2)) AS tbl ``` ## How was this patch tested? Pass the Jenkins test with a newly added test case. Author: Dongjoon Hyun Closes #15546 from dongjoon-hyun/SPARK-17982. 
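The fix in the diff below is a single extra pattern-match case: a `LIMIT` that sits directly on a `SubqueryAlias` is emitted as before, while any other `LIMIT` is wrapped in parentheses so the generated query can be nested inside a FROM clause. A self-contained toy model of that rule (not the real `SQLBuilder`, just a sketch):

```scala
// Illustrative, heavily simplified model of the two cases the fix distinguishes.
sealed trait Plan
case class Table(name: String) extends Plan
case class SubqueryAlias(alias: String, child: Plan) extends Plan
case class Limit(n: Int, child: Plan) extends Plan

def toSQL(plan: Plan): String = plan match {
  case Table(name)                 => s"SELECT * FROM `$name`"
  case SubqueryAlias(alias, child) => s"(${toSQL(child)}) AS $alias"
  // A LIMIT directly on top of an aliased subquery (e.g. TABLESAMPLE (n ROWS))
  // must not be wrapped again, otherwise "((...) AS x LIMIT n)" is produced.
  case Limit(n, child: SubqueryAlias) => s"${toSQL(child)} LIMIT $n"
  // Everywhere else the limited query must be parenthesized so it can be
  // nested inside a FROM clause.
  case Limit(n, child) => s"(${toSQL(child)} LIMIT $n)"
}

// toSQL(SubqueryAlias("tbl", Limit(2, Table("tbl"))))
//   => "((SELECT * FROM `tbl` LIMIT 2)) AS tbl"   (usable inside FROM)
// toSQL(Limit(2, SubqueryAlias("g", Table("tbl"))))
//   => "(SELECT * FROM `tbl`) AS g LIMIT 2"       (no extra wrapping)
```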
(cherry picked from commit d42bb7cc4e32c173769bd7da5b9b5eafb510860c) Signed-off-by: gatorsmile --- .../org/apache/spark/sql/catalyst/SQLBuilder.scala | 7 ++++++- .../test/resources/sqlgen/generate_with_other_1.sql | 2 +- .../test/resources/sqlgen/generate_with_other_2.sql | 2 +- sql/hive/src/test/resources/sqlgen/limit.sql | 4 ++++ .../spark/sql/catalyst/LogicalPlanToSQLSuite.scala | 10 ++++++++++ 5 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 sql/hive/src/test/resources/sqlgen/limit.sql diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala index 6f821f80cc4c5..380454267eaf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala @@ -138,9 +138,14 @@ class SQLBuilder private ( case g: Generate => generateToSQL(g) - case Limit(limitExpr, child) => + // This prevents a pattern of `((...) AS gen_subquery_0 LIMIT 1)` which does not work. + // For example, `SELECT * FROM (SELECT id FROM tbl TABLESAMPLE (2 ROWS))` makes this plan. + case Limit(limitExpr, child: SubqueryAlias) => s"${toSQL(child)} LIMIT ${limitExpr.sql}" + case Limit(limitExpr, child) => + s"(${toSQL(child)} LIMIT ${limitExpr.sql})" + case Filter(condition, child) => val whereOrHaving = child match { case _: Aggregate => "HAVING" diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql index ab444d0c70936..0739f8fff5467 100644 --- a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql @@ -5,4 +5,4 @@ WHERE id > 2 ORDER BY val, id LIMIT 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5) AS parquet_t3 +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM ((SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5)) AS parquet_t3 diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql index 42a2369f34d1c..c4b344ee238a5 100644 --- a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql @@ -7,4 +7,4 @@ WHERE val > 2 ORDER BY val, id LIMIT 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `val`, 
`gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5) AS gen_subquery_1 +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM ((SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5)) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/limit.sql b/sql/hive/src/test/resources/sqlgen/limit.sql new file mode 100644 index 0000000000000..7a6b060fbf505 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/limit.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT * FROM (SELECT id FROM tbl LIMIT 2) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0`, `name` AS `gen_attr_1` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2)) AS tbl diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 8696337b9dc8a..557ea44d1c80b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -1173,4 +1173,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { ) } } + + test("SPARK-17982 - limit") { + withTable("tbl") { + sql("CREATE TABLE tbl(id INT, name STRING)") + checkSQL( + "SELECT * FROM (SELECT id FROM tbl LIMIT 2)", + "limit" + ) + } + } } From 87820da782fd2d08078227a2ce5c363c3e1cb0f0 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 11 Nov 2016 13:52:10 -0800 Subject: [PATCH 090/534] [SPARK-18387][SQL] Add serialization to checkEvaluation. ## What changes were proposed in this pull request? This removes the serialization test from RegexpExpressionsSuite and replaces it by serializing all expressions in checkEvaluation. This also fixes math constant expressions by making LeafMathExpression Serializable and fixes NumberFormat values that are null or invalid after serialization. ## How was this patch tested? This patch is to tests. Author: Ryan Blue Closes #15847 from rdblue/SPARK-18387-fix-serializable-expressions. 
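The essence of the change below is that `checkEvaluation` now round-trips every expression through `JavaSerializer` before evaluating it, so non-serializable expression state is caught by every existing test rather than by one dedicated regex test. A minimal sketch of that round-trip helper, using only the public serializer API that the diff itself relies on:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.catalyst.expressions.Expression

// Round-trip an expression through Java serialization before evaluating it, so a
// test fails fast if the expression holds non-serializable state (the motivation
// for making LeafMathExpression Serializable and FormatNumber's fields lazy).
def roundTrip(expression: Expression): Expression = {
  val serializer = new JavaSerializer(new SparkConf()).newInstance()
  serializer.deserialize[Expression](serializer.serialize(expression))
}
```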
(cherry picked from commit 6e95325fc3726d260054bd6e7c0717b3c139917e) Signed-off-by: Reynold Xin --- .../expressions/mathExpressions.scala | 2 +- .../expressions/stringExpressions.scala | 44 +++++++++++-------- .../expressions/ExpressionEvalHelper.scala | 15 ++++--- .../expressions/RegexpExpressionsSuite.scala | 16 +------ 4 files changed, 36 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a60494a5bb69d..65273a77b1054 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.unsafe.types.UTF8String * @param name The short name of the function */ abstract class LeafMathExpression(c: Double, name: String) - extends LeafExpression with CodegenFallback { + extends LeafExpression with CodegenFallback with Serializable { override def dataType: DataType = DoubleType override def foldable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5f533fecf8d07..e74ef9a08750e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1431,18 +1431,20 @@ case class FormatNumber(x: Expression, d: Expression) // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. + // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after + // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Int = -100 + private var lastDValue: Option[Int] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. @transient - private val pattern: StringBuffer = new StringBuffer() + private lazy val pattern: StringBuffer = new StringBuffer() // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. 
@transient - private val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) + private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { val dValue = dObject.asInstanceOf[Int] @@ -1450,24 +1452,28 @@ case class FormatNumber(x: Expression, d: Expression) return null } - if (dValue != lastDValue) { - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") + lastDValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append("#,###,###,###,###,###,##0") + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } } - } - lastDValue = dValue - numberFormat.applyLocalizedPattern(pattern.toString) + lastDValue = Some(dValue) + + numberFormat.applyLocalizedPattern(pattern.toString) } x.dataType match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 9ceb709185417..f83650424a964 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -22,7 +22,8 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.exceptions.TestFailedException import org.scalatest.prop.GeneratorDrivenPropertyChecks -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer @@ -43,13 +44,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val serializer = new JavaSerializer(new SparkConf()).newInstance + val expr: Expression = serializer.deserialize(serializer.serialize(expression)) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) - checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) - checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) - if (GenerateUnsafeProjection.canSupport(expression.dataType)) { - checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) + checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow) } - checkEvaluationWithOptimization(expression, catalystValue, inputRow) + checkEvaluationWithOptimization(expr, catalystValue, inputRow) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index d0d1aaa9d299d..5299549e7b4da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.StringType @@ -192,17 +191,4 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringSplit(s1, s2), null, row3) } - test("RegExpReplace serialization") { - val serializer = new JavaSerializer(new SparkConf()).newInstance - - val row = create_row("abc", "b", "") - - val s = 's.string.at(0) - val p = 'p.string.at(1) - val r = 'r.string.at(2) - - val expr: RegExpReplace = serializer.deserialize(serializer.serialize(RegExpReplace(s, p, r))) - checkEvaluation(expr, "ac", row) - } - } From c2ebda443b2678e554d859d866af53e2e94822f2 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 11 Nov 2016 15:49:55 -0800 Subject: [PATCH 091/534] [SPARK-18264][SPARKR] build vignettes with package, update vignettes for CRAN release build and add info on release ## What changes were proposed in this pull request? Changes to DESCRIPTION to build vignettes. Changes the metadata for vignettes to generate the recommended format (which is about <10% of size before). Unfortunately it does not look as nice (before - left, after - right) ![image](https://cloud.githubusercontent.com/assets/8969467/20040492/b75883e6-a40d-11e6-9534-25cdd5d59a8b.png) ![image](https://cloud.githubusercontent.com/assets/8969467/20040490/a40f4d42-a40d-11e6-8c91-af00ddcbdad9.png) Also add information on how to run build/release to CRAN later. ## How was this patch tested? manually, unit tests shivaram We need this for branch-2.1 Author: Felix Cheung Closes #15790 from felixcheung/rpkgvignettes. (cherry picked from commit ba23f768f7419039df85530b84258ec31f0c22b4) Signed-off-by: Shivaram Venkataraman --- R/CRAN_RELEASE.md | 91 ++++++++++++++++++++++++++++ R/README.md | 8 +-- R/check-cran.sh | 33 ++++++++-- R/create-docs.sh | 19 +----- R/pkg/DESCRIPTION | 9 ++- R/pkg/vignettes/sparkr-vignettes.Rmd | 9 +-- 6 files changed, 134 insertions(+), 35 deletions(-) create mode 100644 R/CRAN_RELEASE.md diff --git a/R/CRAN_RELEASE.md b/R/CRAN_RELEASE.md new file mode 100644 index 0000000000000..bea8f9fbe4eec --- /dev/null +++ b/R/CRAN_RELEASE.md @@ -0,0 +1,91 @@ +# SparkR CRAN Release + +To release SparkR as a package to CRAN, we would use the `devtools` package. Please work with the +`dev@spark.apache.org` community and R package maintainer on this. + +### Release + +First, check that the `Version:` field in the `pkg/DESCRIPTION` file is updated. Also, check for stale files not under source control. + +Note that while `check-cran.sh` is running `R CMD check`, it is doing so with `--no-manual --no-vignettes`, which skips a few vignettes or PDF checks - therefore it will be preferred to run `R CMD check` on the source package built manually before uploading a release. + +To upload a release, we would need to update the `cran-comments.md`. 
This should generally contain the results from running the `check-cran.sh` script along with comments on status of all `WARNING` (should not be any) or `NOTE`. As a part of `check-cran.sh` and the release process, the vignettes is build - make sure `SPARK_HOME` is set and Spark jars are accessible. + +Once everything is in place, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::release(); .libPaths(paths) +``` + +For more information please refer to http://r-pkgs.had.co.nz/release.html#release-check + +### Testing: build package manually + +To build package manually such as to inspect the resulting `.tar.gz` file content, we would also use the `devtools` package. + +Source package is what get released to CRAN. CRAN would then build platform-specific binary packages from the source package. + +#### Build source package + +To build source package locally without releasing to CRAN, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg"); .libPaths(paths) +``` + +(http://r-pkgs.had.co.nz/vignettes.html#vignette-workflow-2) + +Similarly, the source package is also created by `check-cran.sh` with `R CMD build pkg`. + +For example, this should be the content of the source package: + +```sh +DESCRIPTION R inst tests +NAMESPACE build man vignettes + +inst/doc/ +sparkr-vignettes.html +sparkr-vignettes.Rmd +sparkr-vignettes.Rman + +build/ +vignette.rds + +man/ + *.Rd files... + +vignettes/ +sparkr-vignettes.Rmd +``` + +#### Test source package + +To install, run this: + +```sh +R CMD INSTALL SparkR_2.1.0.tar.gz +``` + +With "2.1.0" replaced with the version of SparkR. + +This command installs SparkR to the default libPaths. Once that is done, you should be able to start R and run: + +```R +library(SparkR) +vignette("sparkr-vignettes", package="SparkR") +``` + +#### Build binary package + +To build binary package locally, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg", binary = TRUE); .libPaths(paths) +``` + +For example, this should be the content of the binary package: + +```sh +DESCRIPTION Meta R html tests +INDEX NAMESPACE help profile worker +``` diff --git a/R/README.md b/R/README.md index 932d5272d0b4f..47f9a86dfde11 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` the full path of the base directory where R is installed, before running install-dev.sh script. 
-Example: +Example: ```bash # where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript export R_HOME=/home/username/R @@ -46,7 +46,7 @@ Sys.setenv(SPARK_HOME="/Users/username/spark") # This line loads SparkR from the installed directory .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) library(SparkR) -sc <- sparkR.init(master="local") +sparkR.session() ``` #### Making changes to SparkR @@ -54,11 +54,11 @@ sc <- sparkR.init(master="local") The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. - + #### Generating documentation The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. Also, you may need to install these [prerequisites](https://github.com/apache/spark/tree/master/docs#prerequisites). See also, `R/DOCUMENTATION.md` - + ### Examples, Unit tests SparkR comes with several sample programs in the `examples/src/main/r` directory. diff --git a/R/check-cran.sh b/R/check-cran.sh index bb331466ae931..c5f042848c90c 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -36,11 +36,27 @@ if [ ! -z "$R_HOME" ] fi echo "USING R_HOME = $R_HOME" -# Build the latest docs +# Build the latest docs, but not vignettes, which is built with the package next $FWDIR/create-docs.sh -# Build a zip file containing the source package -"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg +# Build source package with vignettes +SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" +. "${SPARK_HOME}"/bin/load-spark-env.sh +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ -d "$SPARK_JARS_DIR" ]; then + # Build a zip file containing the source package with vignettes + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Error Spark JARs not found in $SPARK_HOME" + exit 1 +fi # Run check as-cran. 
VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` @@ -54,11 +70,16 @@ fi if [ -n "$NO_MANUAL" ] then - CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual" + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual --no-vignettes" fi echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" -"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz - +if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] +then + "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +else + # This will run tests and/or build vignettes, and require SPARK_HOME + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 69ffc5f678c36..84e6aa928cb0f 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -20,7 +20,7 @@ # Script to create API docs and vignettes for SparkR # This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. -# After running this script the html docs can be found in +# After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html # The vignettes can be found in # $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html @@ -52,21 +52,4 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd -# Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then - SPARK_JARS_DIR="${SPARK_HOME}/jars" -else - SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -# Only create vignettes if Spark JARs exist -if [ -d "$SPARK_JARS_DIR" ]; then - # render creates SparkR vignettes - Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' - - find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete -else - echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" -fi - popd diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 5a83883089e0e..fe41a9e7dabbd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package Title: R Frontend for Apache Spark -Version: 2.0.0 -Date: 2016-08-27 +Version: 2.1.0 +Date: 2016-11-06 Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", @@ -18,7 +18,9 @@ Depends: Suggests: testthat, e1071, - survival + survival, + knitr, + rmarkdown Description: The SparkR package provides an R frontend for Apache Spark. 
License: Apache License (== 2.0) Collate: @@ -48,3 +50,4 @@ Collate: 'utils.R' 'window.R' RoxygenNote: 5.0.1 +VignetteBuilder: knitr diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 80e876027bddb..73a5e26a3ba9c 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1,12 +1,13 @@ --- title: "SparkR - Practical Guide" output: - html_document: - theme: united + rmarkdown::html_vignette: toc: true toc_depth: 4 - toc_float: true - highlight: textmate +vignette: > + %\VignetteIndexEntry{SparkR - Practical Guide} + %\VignetteEngine{knitr::rmarkdown} + \usepackage[utf8]{inputenc} --- ## Overview From 56859c029476bc41b2d2e05043c119146b287bce Mon Sep 17 00:00:00 2001 From: sethah Date: Sat, 12 Nov 2016 01:38:26 +0000 Subject: [PATCH 092/534] [SPARK-18060][ML] Avoid unnecessary computation for MLOR ## What changes were proposed in this pull request? Before this patch, the gradient updates for multinomial logistic regression were computed by an outer loop over the number of classes and an inner loop over the number of features. Inside the inner loop, we standardized the feature value (`value / featuresStd(index)`), which means we performed the computation `numFeatures * numClasses` times. We only need to perform that computation `numFeatures` times, however. If we re-order the inner and outer loop, we can avoid this, but then we lose sequential memory access. In this patch, we instead lay out the coefficients in column major order while we train, so that we can avoid the extra computation and retain sequential memory access. We convert back to row-major order when we create the model. ## How was this patch tested? This is an implementation detail only, so the original behavior should be maintained. All tests pass. I ran some performance tests to verify speedups. The results are below, and show significant speedups. ## Performance Tests **Setup** 3 node bare-metal cluster 120 cores total 384 gb RAM total **Results** NOTE: The `currentMasterTime` and `thisPatchTime` are times in seconds for a single iteration of L-BFGS or OWL-QN. | | numPoints | numFeatures | numClasses | regParam | elasticNetParam | currentMasterTime (sec) | thisPatchTime (sec) | pctSpeedup | |----|-------------|---------------|--------------|------------|-------------------|---------------------------|-----------------------|--------------| | 0 | 1e+07 | 100 | 500 | 0.5 | 0 | 90 | 18 | 80 | | 1 | 1e+08 | 100 | 50 | 0.5 | 0 | 90 | 19 | 78 | | 2 | 1e+08 | 100 | 50 | 0.05 | 1 | 72 | 19 | 73 | | 3 | 1e+06 | 100 | 5000 | 0.5 | 0 | 93 | 53 | 43 | | 4 | 1e+07 | 100 | 5000 | 0.5 | 0 | 900 | 390 | 56 | | 5 | 1e+08 | 100 | 500 | 0.5 | 0 | 840 | 174 | 79 | | 6 | 1e+08 | 100 | 200 | 0.5 | 0 | 360 | 72 | 80 | | 7 | 1e+08 | 1000 | 5 | 0.5 | 0 | 9 | 3 | 66 | Author: sethah Closes #15593 from sethah/MLOR_PERF_COL_MAJOR_COEF. 
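For illustration, here is a minimal, self-contained sketch of the index mapping described above (the object and helper names are hypothetical, not code from this patch): during training the flat coefficient array is column major, so all classes for a given feature sit next to each other and the feature value only needs to be standardized once, while the model stores coefficients row major.

```scala
// Illustrative sketch only: column-major vs. row-major indexing of a flat
// coefficient array for multinomial logistic regression.
object CoefficientLayoutSketch {
  def main(args: Array[String]): Unit = {
    val numClasses = 3   // hypothetical k
    val numFeatures = 2  // hypothetical number of features

    // Column-major (training): beta_11, beta_21, beta_31, beta_12, beta_22, beta_32,
    // followed by intercept_1, intercept_2, intercept_3.
    def colMajorIndex(classIdx: Int, featureIdx: Int): Int =
      featureIdx * numClasses + classIdx
    def colMajorInterceptIndex(classIdx: Int): Int =
      numClasses * numFeatures + classIdx

    // Row-major (model): beta_11, beta_12, beta_21, beta_22, beta_31, beta_32.
    def rowMajorIndex(classIdx: Int, featureIdx: Int): Int =
      classIdx * numFeatures + featureIdx

    // Walking features in the outer loop touches contiguous column-major slots,
    // so each standardized feature value is reused for every class.
    for (featureIdx <- 0 until numFeatures; classIdx <- 0 until numClasses) {
      println(s"beta(class=$classIdx, feature=$featureIdx): " +
        s"colMajor=${colMajorIndex(classIdx, featureIdx)}, " +
        s"rowMajor=${rowMajorIndex(classIdx, featureIdx)}")
    }
    (0 until numClasses).foreach { c =>
      println(s"intercept(class=$c): colMajor=${colMajorInterceptIndex(c)}")
    }
  }
}
```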
(cherry picked from commit 46b2550bcd3690a260b995fd4d024a73b92a0299) Signed-off-by: DB Tsai --- .../classification/LogisticRegression.scala | 125 +++++++++++------- 1 file changed, 74 insertions(+), 51 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index c4651054fd765..18b9b3043db8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -438,18 +438,14 @@ class LogisticRegression @Since("1.2.0") ( val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { // Remove the L1 penalization on the intercept - val isIntercept = $(fitIntercept) && ((index + 1) % numFeaturesPlusIntercept == 0) + val isIntercept = $(fitIntercept) && index >= numFeatures * numCoefficientSets if (isIntercept) { 0.0 } else { if (standardizationParam) { regParamL1 } else { - val featureIndex = if ($(fitIntercept)) { - index % numFeaturesPlusIntercept - } else { - index % numFeatures - } + val featureIndex = index / numCoefficientSets // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to // perform this reverse standardization by penalizing each component @@ -466,6 +462,15 @@ class LogisticRegression @Since("1.2.0") ( new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } + /* + The coefficients are laid out in column major order during training. e.g. for + `numClasses = 3` and `numFeatures = 2` and `fitIntercept = true` the layout is: + + Array(beta_11, beta_21, beta_31, beta_12, beta_22, beta_32, intercept_1, intercept_2, + intercept_3) + + where beta_jk corresponds to the coefficient for class `j` and feature `k`. + */ val initialCoefficientsWithIntercept = Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept) @@ -489,13 +494,14 @@ class LogisticRegression @Since("1.2.0") ( val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray val providedCoef = optInitialModel.get.coefficientMatrix providedCoef.foreachActive { (row, col, value) => - val flatIndex = row * numFeaturesPlusIntercept + col + // convert matrix to column major for training + val flatIndex = col * numCoefficientSets + row // We need to scale the coefficients since they will be trained in the scaled space initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col) } if ($(fitIntercept)) { optInitialModel.get.interceptVector.foreachActive { (index, value) => - val coefIndex = (index + 1) * numFeaturesPlusIntercept - 1 + val coefIndex = numCoefficientSets * numFeatures + index initialCoefWithInterceptArray(coefIndex) = value } } @@ -526,7 +532,7 @@ class LogisticRegression @Since("1.2.0") ( val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing val rawMean = rawIntercepts.sum / rawIntercepts.length rawIntercepts.indices.foreach { i => - initialCoefficientsWithIntercept.toArray(i * numFeaturesPlusIntercept + numFeatures) = + initialCoefficientsWithIntercept.toArray(numClasses * numFeatures + i) = rawIntercepts(i) - rawMean } } else if ($(fitIntercept)) { @@ -572,16 +578,20 @@ class LogisticRegression @Since("1.2.0") ( /* The coefficients are trained in the scaled space; we're converting them back to the original space. 
+ + Additionally, since the coefficients were laid out in column major order during training + to avoid extra computation, we convert them back to row major before passing them to the + model. + Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. */ val rawCoefficients = state.x.toArray.clone() val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i => - // flatIndex will loop though rawCoefficients, and skip the intercept terms. - val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i + val colMajorIndex = (i % numFeatures) * numCoefficientSets + i / numFeatures val featureIndex = i % numFeatures if (featuresStd(featureIndex) != 0.0) { - rawCoefficients(flatIndex) / featuresStd(featureIndex) + rawCoefficients(colMajorIndex) / featuresStd(featureIndex) } else { 0.0 } @@ -618,7 +628,7 @@ class LogisticRegression @Since("1.2.0") ( val interceptsArray: Array[Double] = if ($(fitIntercept)) { Array.tabulate(numCoefficientSets) { i => - val coefIndex = (i + 1) * numFeaturesPlusIntercept - 1 + val coefIndex = numFeatures * numCoefficientSets + i rawCoefficients(coefIndex) } } else { @@ -697,6 +707,7 @@ class LogisticRegressionModel private[spark] ( /** * A vector of model coefficients for "binomial" logistic regression. If this model was trained * using the "multinomial" family then an exception is thrown. + * * @return Vector */ @Since("2.0.0") @@ -720,6 +731,7 @@ class LogisticRegressionModel private[spark] ( /** * The model intercept for "binomial" logistic regression. If this model was fit with the * "multinomial" family then an exception is thrown. + * * @return Double */ @Since("1.3.0") @@ -1389,6 +1401,12 @@ class BinaryLogisticRegressionSummary private[classification] ( * $$ *

* + * @note In order to avoid unnecessary computation during calculation of the gradient updates + * we lay out the coefficients in column major order during training. This allows us to + * perform feature standardization once, while still retaining sequential memory access + * for speed. We convert back to row major order when we create the model, + * since this form is optimal for the matrix operations used for prediction. + * * @param bcCoefficients The broadcast coefficients corresponding to the features. * @param bcFeaturesStd The broadcast standard deviation values of the features. * @param numClasses the number of possible outcomes for k classes classification problem in @@ -1486,23 +1504,25 @@ private class LogisticAggregator( var marginOfLabel = 0.0 var maxMargin = Double.NegativeInfinity - val margins = Array.tabulate(numClasses) { i => - var margin = 0.0 - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - margin += localCoefficients(i * numFeaturesPlusIntercept + index) * - value / localFeaturesStd(index) - } + val margins = new Array[Double](numClasses) + features.foreachActive { (index, value) => + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + margins(j) += localCoefficients(index * numClasses + j) * stdValue + j += 1 } - + } + var i = 0 + while (i < numClasses) { if (fitIntercept) { - margin += localCoefficients(i * numFeaturesPlusIntercept + numFeatures) + margins(i) += localCoefficients(numClasses * numFeatures + i) } - if (i == label.toInt) marginOfLabel = margin - if (margin > maxMargin) { - maxMargin = margin + if (i == label.toInt) marginOfLabel = margins(i) + if (margins(i) > maxMargin) { + maxMargin = margins(i) } - margin + i += 1 } /** @@ -1510,33 +1530,39 @@ private class LogisticAggregator( * We address this by subtracting maxMargin from all the margins, so it's guaranteed * that all of the new margins will be smaller than zero to prevent arithmetic overflow. 
*/ + val multipliers = new Array[Double](numClasses) val sum = { var temp = 0.0 - if (maxMargin > 0) { - for (i <- 0 until numClasses) { - margins(i) -= maxMargin - temp += math.exp(margins(i)) - } - } else { - for (i <- 0 until numClasses) { - temp += math.exp(margins(i)) - } + var i = 0 + while (i < numClasses) { + if (maxMargin > 0) margins(i) -= maxMargin + val exp = math.exp(margins(i)) + temp += exp + multipliers(i) = exp + i += 1 } temp } - for (i <- 0 until numClasses) { - val multiplier = math.exp(margins(i)) / sum - { - if (label == i) 1.0 else 0.0 - } - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - localGradientArray(i * numFeaturesPlusIntercept + index) += - weight * multiplier * value / localFeaturesStd(index) + margins.indices.foreach { i => + multipliers(i) = multipliers(i) / sum - (if (label == i) 1.0 else 0.0) + } + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + localGradientArray(index * numClasses + j) += + weight * multipliers(j) * stdValue + j += 1 } } - if (fitIntercept) { - localGradientArray(i * numFeaturesPlusIntercept + numFeatures) += weight * multiplier + } + if (fitIntercept) { + var i = 0 + while (i < numClasses) { + localGradientArray(numFeatures * numClasses + i) += weight * multipliers(i) + i += 1 } } @@ -1637,6 +1663,7 @@ private class LogisticCostFun( val bcCoeffs = instances.context.broadcast(coeffs) val featuresStd = bcFeaturesStd.value val numFeatures = featuresStd.length + val numCoefficientSets = if (multinomial) numClasses else 1 val logisticAggregator = { val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) @@ -1656,7 +1683,7 @@ private class LogisticCostFun( var sum = 0.0 coeffs.foreachActive { case (index, value) => // We do not apply regularization to the intercepts - val isIntercept = fitIntercept && ((index + 1) % (numFeatures + 1) == 0) + val isIntercept = fitIntercept && index >= numCoefficientSets * numFeatures if (!isIntercept) { // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. @@ -1665,11 +1692,7 @@ private class LogisticCostFun( totalGradientArray(index) += regParamL2 * value value * value } else { - val featureIndex = if (fitIntercept) { - index % (numFeatures + 1) - } else { - index % numFeatures - } + val featureIndex = index / numCoefficientSets if (featuresStd(featureIndex) != 0.0) { // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to From 893355143a177f1fea1d2fb6f6e617574e5c5e52 Mon Sep 17 00:00:00 2001 From: Guoqiang Li Date: Sat, 12 Nov 2016 09:49:14 +0000 Subject: [PATCH 093/534] [SPARK-18375][SPARK-18383][BUILD][CORE] Upgrade netty to 4.0.42.Final ## What changes were proposed in this pull request? One of the important changes for 4.0.42.Final is "Support any FileRegion implementation when using epoll transport netty/netty#5825". In 4.0.42.Final, `MessageWithHeader` can work properly when `spark.[shuffle|rpc].io.mode` is set to epoll ## How was this patch tested? Existing tests Author: Guoqiang Li Closes #15830 from witgo/SPARK-18375_netty-4.0.42. 
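For reference, a minimal sketch of the configuration the message above refers to (the application name is hypothetical, and the epoll transport is only available on Linux):

```scala
import org.apache.spark.SparkConf

// Illustrative sketch: switching the shuffle and RPC transports to epoll,
// the mode that benefits from the FileRegion fix in netty 4.0.42.Final.
val conf = new SparkConf()
  .setAppName("epoll-transport-example") // hypothetical app name
  .set("spark.shuffle.io.mode", "epoll") // shuffle transport
  .set("spark.rpc.io.mode", "epoll")     // RPC transport
```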
(cherry picked from commit bc41d997ea287080f549219722b6d9049adef4e2) Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/util/Utils.scala | 4 ++++ dev/deps/spark-deps-hadoop-2.2 | 2 +- dev/deps/spark-deps-hadoop-2.3 | 2 +- dev/deps/spark-deps-hadoop-2.4 | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 7 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1de66af632a8a..892e112e18f85 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -39,6 +39,7 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses @@ -2222,6 +2223,9 @@ private[spark] object Utils extends Logging { isBindCollision(e.getCause) case e: MultiException => e.getThrowables.asScala.exists(isBindCollision) + case e: NativeIoException => + (e.getMessage != null && e.getMessage.startsWith("bind() failed: ")) || + isBindCollision(e.getCause) case e: Exception => isBindCollision(e.getCause) case _ => false } diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 6e749ac16cac0..bbdea069f9496 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -123,7 +123,7 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 515995a0a46bd..a2dec41d64519 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -130,7 +130,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index d2139fd952406..c1f02b93d751c 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -130,7 +130,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index b5cecf72ec35f..4f04636be712b 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -138,7 +138,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a5e03a78e7ea8..da3af9ffa155b 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -139,7 +139,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/pom.xml b/pom.xml index 8aa0a6c3caab9..650b4cd965b66 100644 --- a/pom.xml +++ b/pom.xml @@ -552,7 +552,7 @@ io.netty netty-all - 
4.0.41.Final + 4.0.42.Final io.netty From b2ba83d10ac06614c0126f4b0d913f6979051682 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 12 Nov 2016 06:13:22 -0800 Subject: [PATCH 094/534] [SPARK-14077][ML][FOLLOW-UP] Minor refactor and cleanup for NaiveBayes ## What changes were proposed in this pull request? * Refactor out ```trainWithLabelCheck``` and make ```mllib.NaiveBayes``` call into it. * Avoid capturing the outer object for ```modelType```. * Move ```requireNonnegativeValues``` and ```requireZeroOneBernoulliValues``` to companion object. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #15826 from yanboliang/spark-14077-2. (cherry picked from commit 22cb3a060a440205281b71686637679645454ca6) Signed-off-by: Yanbo Liang --- .../spark/ml/classification/NaiveBayes.scala | 72 +++++++++---------- .../mllib/classification/NaiveBayes.scala | 6 +- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index b03a07a6bc1e7..f1a7676c74b0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") ( extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { - import NaiveBayes.{Bernoulli, Multinomial} + import NaiveBayes._ @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) @@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") ( @Since("2.1.0") def setWeightCol(value: String): this.type = set(weightCol, value) + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + trainWithLabelCheck(dataset, positiveLabel = true) + } + /** * ml assumes input labels in range [0, numClasses). But this implementation * is also called by mllib NaiveBayes which allows other kinds of input labels - * such as {-1, +1}. Here we use this parameter to switch between different processing logic. - * It should be removed when we remove mllib NaiveBayes. + * such as {-1, +1}. `positiveLabel` is used to determine whether the label + * should be checked and it should be removed when we remove mllib NaiveBayes. 
*/ - private[spark] var isML: Boolean = true - - private[spark] def setIsML(isML: Boolean): this.type = { - this.isML = isML - this - } - - override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - if (isML) { + private[spark] def trainWithLabelCheck( + dataset: Dataset[_], + positiveLabel: Boolean): NaiveBayesModel = { + if (positiveLabel) { val numClasses = getNumClasses(dataset) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") ( } } - val requireNonnegativeValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - - require(values.forall(_ >= 0.0), - s"Naive Bayes requires nonnegative feature values but found $v.") - } - - val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - - require(values.forall(v => v == 0.0 || v == 1.0), - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") - } - + val modelTypeValue = $(modelType) val requireValues: Vector => Unit = { - $(modelType) match { + modelTypeValue match { case Multinomial => requireNonnegativeValues case Bernoulli => @@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") ( @Since("1.6.0") object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { /** String name for multinomial model type. */ - private[spark] val Multinomial: String = "multinomial" + private[classification] val Multinomial: String = "multinomial" /** String name for Bernoulli model type. */ - private[spark] val Bernoulli: String = "bernoulli" + private[classification] val Bernoulli: String = "bernoulli" /* Set of modelTypes that NaiveBayes supports */ - private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) + private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) + + private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(_ >= 0.0), + s"Naive Bayes requires nonnegative feature values but found $v.") + } + + private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(v => v == 0.0 || v == 1.0), + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } @Since("1.6.0") override def load(path: String): NaiveBayes = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 33561be4b5bc1..767d056861a8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -364,12 +364,12 @@ class NaiveBayes private ( val nb = new NewNaiveBayes() .setModelType(modelType) .setSmoothing(lambda) - .setIsML(false) val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) } .toDF("label", "features") - val newModel = nb.fit(dataset) + // mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` as false. 
+ val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false) val pi = newModel.pi.toArray val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0) @@ -378,7 +378,7 @@ class NaiveBayes private ( theta(i)(j) = v } - require(newModel.oldLabels != null, + assert(newModel.oldLabels != null, "The underlying ML NaiveBayes training does not produce labels.") new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType) } From 6fae4241f281638d52071102c7f0ee6c2c73a8c7 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 12 Nov 2016 14:50:37 -0800 Subject: [PATCH 095/534] [SPARK-18418] Fix flags for make_binary_release for hadoop profile ## What changes were proposed in this pull request? Fix the flags used to specify the hadoop version ## How was this patch tested? Manually tested as part of https://github.com/apache/spark/pull/15659 by having the build succeed. cc joshrosen Author: Holden Karau Closes #15860 from holdenk/minor-fix-release-build-script. (cherry picked from commit 1386fd28daf798bf152606f4da30a36223d75d18) Signed-off-by: Josh Rosen --- dev/create-release/release-build.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 96f9b5714ebb8..81f0d63054e29 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -187,10 +187,10 @@ if [[ "$1" == "package" ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. FLAGS="-Psparkr -Phive -Phive-thriftserver -Pyarn -Pmesos" - make_binary_release "hadoop2.3" "-Phadoop2.3 $FLAGS" "3033" & - make_binary_release "hadoop2.4" "-Phadoop2.4 $FLAGS" "3034" & - make_binary_release "hadoop2.6" "-Phadoop2.6 $FLAGS" "3035" & - make_binary_release "hadoop2.7" "-Phadoop2.7 $FLAGS" "3036" & + make_binary_release "hadoop2.3" "-Phadoop-2.3 $FLAGS" "3033" & + make_binary_release "hadoop2.4" "-Phadoop-2.4 $FLAGS" "3034" & + make_binary_release "hadoop2.6" "-Phadoop-2.6 $FLAGS" "3035" & + make_binary_release "hadoop2.7" "-Phadoop-2.7 $FLAGS" "3036" & make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn -Pmesos" "3037" & make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn -Pmesos" "3038" & wait From 0c69224ed752c25be1545cfe8ba0db8487a70bf2 Mon Sep 17 00:00:00 2001 From: Denny Lee Date: Sun, 13 Nov 2016 18:10:06 -0800 Subject: [PATCH 096/534] [SPARK-18426][STRUCTURED STREAMING] Python Documentation Fix for Structured Streaming Programming Guide ## What changes were proposed in this pull request? Update the python section of the Structured Streaming Guide from .builder() to .builder ## How was this patch tested? Validated documentation and successfully running the test example. Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. 'Builder' object is not callable object hence changed .builder() to .builder Author: Denny Lee Closes #15872 from dennyglee/master. 
(cherry picked from commit b91a51bb231af321860415075a7f404bc46e0a74) Signed-off-by: Reynold Xin --- docs/structured-streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index d838ed35a14fd..d2545584ae3b0 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -58,7 +58,7 @@ from pyspark.sql.functions import explode from pyspark.sql.functions import split spark = SparkSession \ - .builder() \ + .builder \ .appName("StructuredNetworkWordCount") \ .getOrCreate() {% endhighlight %} From 8fc6455c0b77f81be79908bb65e6264bf61c90e7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 13 Nov 2016 20:25:12 -0800 Subject: [PATCH 097/534] [SPARK-18412][SPARKR][ML] Fix exception for some SparkR ML algorithms training on libsvm data ## What changes were proposed in this pull request? * Fix the following exceptions which throws when ```spark.randomForest```(classification), ```spark.gbt```(classification), ```spark.naiveBayes``` and ```spark.glm```(binomial family) were fitted on libsvm data. ``` java.lang.IllegalArgumentException: requirement failed: If label column already exists, forceIndexLabel can not be set with true. ``` See [SPARK-18412](https://issues.apache.org/jira/browse/SPARK-18412) for more detail about how to reproduce this bug. * Refactor out ```getFeaturesAndLabels``` to RWrapperUtils, since lots of ML algorithm wrappers use this function. * Drop some unwanted columns when making prediction. ## How was this patch tested? Add unit test. Author: Yanbo Liang Closes #15851 from yanboliang/spark-18412. (cherry picked from commit 07be232ea12dfc8dc3701ca948814be7dbebf4ee) Signed-off-by: Yanbo Liang --- R/pkg/inst/tests/testthat/test_mllib.R | 18 ++++++++-- .../spark/ml/r/GBTClassificationWrapper.scala | 18 ++++------ .../GeneralizedLinearRegressionWrapper.scala | 5 ++- .../apache/spark/ml/r/NaiveBayesWrapper.scala | 14 +++----- .../org/apache/spark/ml/r/RWrapperUtils.scala | 36 ++++++++++++++++--- .../r/RandomForestClassificationWrapper.scala | 18 ++++------ 6 files changed, 68 insertions(+), 41 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 33e85b78de4fe..4831ce27bec8a 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -881,7 +881,8 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) -test_that("spark.randomForest Regression", { +test_that("spark.randomForest", { + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 1) @@ -923,9 +924,8 @@ test_that("spark.randomForest Regression", { expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) -}) -test_that("spark.randomForest Classification", { + # classification data <- suppressWarnings(createDataFrame(iris)) model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", maxDepth = 5, maxBins = 16) @@ -971,6 +971,12 @@ test_that("spark.randomForest Classification", { predictions <- collect(predict(model, data))$prediction expect_equal(length(grep("1.0", predictions)), 50) expect_equal(length(grep("2.0", predictions)), 50) + + # spark.randomForest classification can work on libsvm data + data <- 
read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) }) test_that("spark.gbt", { @@ -1039,6 +1045,12 @@ test_that("spark.gbt", { expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) expect_equal(s$numFeatures, 5) expect_equal(s$numTrees, 20) + + # spark.gbt classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) }) sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index 8946025032200..aacb41ee2659b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -23,10 +23,10 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -51,6 +51,7 @@ private[r] class GBTClassifierWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(gbtcModel.getFeaturesCol) + .drop(gbtcModel.getLabelCol) } override def write: MLWriter = new @@ -81,19 +82,11 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) - // get feature names from output schema - val schema = rFormulaModel.transform(data).schema - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) - - // get label names from output schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val rfc = new GBTClassifier() @@ -109,6 +102,7 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] .setMaxMemoryInMB(maxMemoryInMB) .setCacheNodeIds(cacheNodeIds) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 995b1ef03bcec..add4d49110d16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ 
-29,6 +29,7 @@ import org.apache.spark.ml.regression._ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -64,6 +65,7 @@ private[r] class GeneralizedLinearRegressionWrapper private ( .drop(PREDICTED_LABEL_PROB_COL) .drop(PREDICTED_LABEL_INDEX_COL) .drop(glm.getFeaturesCol) + .drop(glm.getLabelCol) } else { pipeline.transform(dataset) .drop(glm.getFeaturesCol) @@ -92,7 +94,7 @@ private[r] object GeneralizedLinearRegressionWrapper regParam: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) if (family == "binomial") rFormula.setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema val schema = rFormulaModel.transform(data).schema @@ -109,6 +111,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) val pipeline = if (family == "binomial") { // Convert prediction from probability to label index. val probToPred = new ProbabilityToPrediction() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 4fdab2dd94655..0afea4be3d1dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -46,6 +46,7 @@ private[r] class NaiveBayesWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(naiveBayesModel.getFeaturesCol) + .drop(naiveBayesModel.getLabelCol) } override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this) @@ -60,21 +61,16 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema - val schema = rFormulaModel.transform(data).schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val naiveBayes = new NaiveBayes() .setSmoothing(smoothing) .setModelType("bernoulli") .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) val idxToStr = new IndexToString() 
.setInputCol(PREDICTED_LABEL_INDEX_COL) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 379007c4d948d..665e50af67d46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -18,11 +18,12 @@ package org.apache.spark.ml.r import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.feature.{RFormula, RFormulaModel} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.Dataset -object RWrapperUtils extends Logging { +private[r] object RWrapperUtils extends Logging { /** * DataFrame column check. @@ -32,14 +33,41 @@ object RWrapperUtils extends Logging { * * @param rFormula RFormula instance * @param data Input dataset - * @return Unit */ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}" - logWarning(s"data containing ${rFormula.getFeaturesCol} column, " + + logInfo(s"data containing ${rFormula.getFeaturesCol} column, " + s"using new name $newFeaturesName instead") rFormula.setFeaturesCol(newFeaturesName) } + + if (rFormula.getForceIndexLabel && data.schema.fieldNames.contains(rFormula.getLabelCol)) { + val newLabelName = s"${Identifiable.randomUID(rFormula.getLabelCol)}" + logInfo(s"data containing ${rFormula.getLabelCol} column and we force to index label, " + + s"using new name $newLabelName instead") + rFormula.setLabelCol(newLabelName) + } + } + + /** + * Get the feature names and original labels from the schema + * of DataFrame transformed by RFormulaModel. + * + * @param rFormulaModel The RFormulaModel instance. + * @param data Input dataset. + * @return The feature names and original labels. 
+ */ + def getFeaturesAndLabels( + rFormulaModel: RFormulaModel, + data: Dataset[_]): (Array[String], Array[String]) = { + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + (features, labels) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 31f846dc6cfec..0b860e5af96e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -23,10 +23,10 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -51,6 +51,7 @@ private[r] class RandomForestClassifierWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(rfcModel.getFeaturesCol) + .drop(rfcModel.getLabelCol) } override def write: MLWriter = new @@ -82,19 +83,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) - // get feature names from output schema - val schema = rFormulaModel.transform(data).schema - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) - - // get label names from output schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val rfc = new RandomForestClassifier() @@ -111,6 +104,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .setCacheNodeIds(cacheNodeIds) .setProbabilityCol(probabilityCol) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) From 12bde11ca0613dbd7d917c81a8b480d5a9355da5 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 14 Nov 2016 16:52:07 +0900 Subject: [PATCH 098/534] [SPARK-18382][WEBUI] "run at null:-1" in UI when no file/line info in call site info ## What changes were proposed in this pull request? Avoid reporting null/-1 file / line number in call sites if encountering StackTraceElement without this info ## How was this patch tested? Existing tests Author: Sean Owen Closes #15862 from srowen/SPARK-18382. 
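For illustration, a small sketch of why the guard is needed: the JDK allows a `StackTraceElement` to carry no source information at all, for example frames from generated or native code (the class and method names below are hypothetical):

```scala
// Illustrative sketch: a frame without file or line information.
val noSourceInfo = new StackTraceElement("com.example.Generated", "apply", null, -1)
assert(noSourceInfo.getFileName == null)   // no file name available
assert(noSourceInfo.getLineNumber < 0)     // no line number available
println(noSourceInfo)                      // com.example.Generated.apply(Unknown Source)
```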
(cherry picked from commit f95b124c68ccc2e318f6ac30685aa47770eea8f3) Signed-off-by: Kousuke Saruta --- core/src/main/scala/org/apache/spark/util/Utils.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 892e112e18f85..a2386d6b9e12f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1419,8 +1419,12 @@ private[spark] object Utils extends Logging { } callStack(0) = ste.toString // Put last Spark method on top of the stack trace. } else { - firstUserLine = ste.getLineNumber - firstUserFile = ste.getFileName + if (ste.getFileName != null) { + firstUserFile = ste.getFileName + if (ste.getLineNumber >= 0) { + firstUserLine = ste.getLineNumber + } + } callStack += ste.toString insideSpark = false } From d554c02f4f50d3d58661d5f87aacf34152545c24 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Mon, 14 Nov 2016 12:08:06 +0100 Subject: [PATCH 099/534] [SPARK-18166][MLLIB] Fix Poisson GLM bug due to wrong requirement of response values ## What changes were proposed in this pull request? The current implementation of Poisson GLM seems to allow only positive values. This is incorrect since the support of Poisson includes the origin. The bug is easily fixed by changing the test of the Poisson variable from 'require(y **>** 0.0' to 'require(y **>=** 0.0'. mengxr srowen Author: actuaryzhang Author: actuaryzhang Closes #15683 from actuaryzhang/master. (cherry picked from commit ae6cddb78742be94aa0851ce719f293e0a64ce4f) Signed-off-by: Sean Owen --- .../GeneralizedLinearRegression.scala | 4 +- .../GeneralizedLinearRegressionSuite.scala | 45 +++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 1938e8ecc513d..1d2961e0277f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -501,8 +501,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine val defaultLink: Link = Log override def initialize(y: Double, weight: Double): Double = { - require(y > 0.0, "The response variable of Poisson family " + - s"should be positive, but got $y") + require(y >= 0.0, "The response variable of Poisson family " + + s"should be non-negative, but got $y") y } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 111bc974642d9..6a4ac1735b2cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -44,6 +44,7 @@ class GeneralizedLinearRegressionSuite @transient var datasetGaussianInverse: DataFrame = _ @transient var datasetBinomial: DataFrame = _ @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonLogWithZero: DataFrame = _ @transient var datasetPoissonIdentity: DataFrame = _ @transient var datasetPoissonSqrt: DataFrame = _ @transient var datasetGammaInverse: DataFrame = _ @@ -88,6 +89,12 @@ class 
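For illustration, a minimal sketch (not code from this patch) of why zero must be accepted: the Poisson probability mass at zero is strictly positive for any rate, so zero counts are legitimate observations and only negative responses should be rejected.

```scala
// Illustrative sketch: P(Y = 0) = exp(-lambda) > 0 for any lambda > 0.
def poissonPmf(k: Int, lambda: Double): Double = {
  require(k >= 0 && lambda > 0.0)
  // k! computed directly; fine for the small k used in this illustration.
  val kFactorial = (1 to k).foldLeft(1.0)(_ * _)
  math.exp(-lambda) * math.pow(lambda, k) / kFactorial
}

println(poissonPmf(0, lambda = 2.0)) // ~0.135, strictly positive
```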
GeneralizedLinearRegressionSuite xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "log").toDF() + datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput( + intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01, + family = "poisson", link = "log") + .map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF() + datasetPoissonIdentity = generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, @@ -139,6 +146,10 @@ class GeneralizedLinearRegressionSuite label + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile( "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLog") + datasetPoissonLogWithZero.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLogWithZero") datasetPoissonIdentity.rdd.map { case Row(label: Double, features: Vector) => label + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile( @@ -456,6 +467,40 @@ class GeneralizedLinearRegressionSuite } } + test("generalized linear regression: poisson family against glm (with zero values)") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="poisson", data=data) + print(as.vector(coef(model))) + } + [1] 0.4272661 -0.1565423 + [1] -3.6911354 0.6214301 0.1295814 + */ + val expected = Seq( + Vectors.dense(0.0, 0.4272661, -0.1565423), + Vectors.dense(-3.6911354, 0.6214301, 0.1295814)) + + import GeneralizedLinearRegression._ + + var idx = 0 + val link = "log" + val dataset = datasetPoissonLogWithZero + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + + s"$link link and fitIntercept = $fitIntercept (with zero values).") + idx += 1 + } + } + test("generalized linear regression: gamma family against glm") { /* R code: From 518dc1e1e63a8955b16a3f2ca7592264fd637ae6 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Mon, 14 Nov 2016 12:22:36 +0100 Subject: [PATCH 100/534] [SPARK-18396][HISTORYSERVER] Duration" column makes search result confused, maybe we should make it unsearchable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When we search data in History Server, it will check if any columns contains the search string. Duration is represented as long value in table, so if we search simple string like "003", "111", the duration containing "003", ‘111“ will be showed, which make not much sense to users. We cannot simply transfer the long value to meaning format like "1 h", "3.2 min" because they are also used for sorting. Better way to handle it is ban "Duration" columns from searching. ## How was this patch tested manually tests. 
Before("local-1478225166651" pass the filter because its duration in long value, which is "257244245" contains search string "244"): ![before](https://cloud.githubusercontent.com/assets/5276001/20203166/f851ffc6-a7ff-11e6-8fe6-91a90ca92b23.jpg) After: ![after](https://cloud.githubusercontent.com/assets/5276001/20178646/2129fbb0-a78d-11e6-9edb-39f885ce3ed0.jpg) Author: WangTaoTheTonic Closes #15838 from WangTaoTheTonic/duration. (cherry picked from commit 637a0bb88f74712001f32a53ff66fd0b8cb67e4a) Signed-off-by: Sean Owen --- .../main/resources/org/apache/spark/ui/static/historypage.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 6c0ec8d5fce54..8fd91865b0429 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -139,6 +139,9 @@ $(document).ready(function() { {name: 'eighth'}, {name: 'ninth'}, ], + "columnDefs": [ + {"searchable": false, "targets": [5]} + ], "autoWidth": false, "order": [[ 4, "desc" ]] }; From c07fe1c5924e167fb569427e5e6b78adcfde648e Mon Sep 17 00:00:00 2001 From: Noritaka Sekiyama Date: Mon, 14 Nov 2016 21:07:59 +0900 Subject: [PATCH 101/534] [SPARK-18432][DOC] Changed HDFS default block size from 64MB to 128MB Changed HDFS default block size from 64MB to 128MB. https://issues.apache.org/jira/browse/SPARK-18432 Author: Noritaka Sekiyama Closes #15879 from moomindani/SPARK-18432. (cherry picked from commit 9d07ceee7860921eafb55b47852f1b51089c98da) Signed-off-by: Kousuke Saruta --- docs/programming-guide.md | 6 +++--- docs/tuning.md | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index b9a2110b602a0..58bf17b4a84ef 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -343,7 +343,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Scala API also supports several other data formats: @@ -375,7 +375,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. 
By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Java API also supports several other data formats: @@ -407,7 +407,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Python API also supports several other data formats: diff --git a/docs/tuning.md b/docs/tuning.md index 9c43b315bbb9e..0de303a3bd9bf 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -224,8 +224,8 @@ temporary objects created during task execution. Some steps which may be useful * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the - size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 64 MB, - we can estimate size of Eden to be `4*3*64MB`. + size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 128 MB, + we can estimate size of Eden to be `4*3*128MB`. * Monitor how the frequency and time taken by garbage collection changes with the new settings. From 3c623d226a0c495c36c86d199879b9e922d1ece2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Nov 2016 10:03:01 -0800 Subject: [PATCH 102/534] [SPARK-18416][STRUCTURED STREAMING] Fixed temp file leak in state store ## What changes were proposed in this pull request? StateStore.get() causes temporary files to be created immediately, even if the store is not used to make updates for new version. The temp file is not closed as store.commit() is not called in those cases, thus keeping the output stream to temp file open forever. This PR fixes it by opening the temp file only when there are updates being made. ## How was this patch tested? New unit test Author: Tathagata Das Closes #15859 from tdas/SPARK-18416. 
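The core of the fix described above is simply deferring creation of the temp delta file until the store actually receives an update, by turning the output stream into a `lazy val`. Below is a minimal standalone sketch of that pattern; the `UpdateLog` class, method names, and file path are illustrative, not Spark's actual `HDFSBackedStateStoreProvider`:

```scala
import java.io.{File, FileOutputStream}
import scala.util.Random

// Illustrative stand-in for the provider, not Spark's actual class.
class UpdateLog(dir: File) {
  private val tempFile = new File(dir, s"temp-${Random.nextLong()}")

  // With a plain `val` the temp file (and an open stream) would be created as soon as
  // the object is constructed, even if nothing is ever written. As a `lazy val`, the
  // file is only created when the first write forces it.
  private lazy val out = new FileOutputStream(tempFile)

  def append(bytes: Array[Byte]): Unit = out.write(bytes)
  def commit(): Unit = out.close()
}

val readOnly = new UpdateLog(new File("/tmp"))
// No append/commit here: `out` is never forced, so no "temp-*" file appears on disk.
```

A store that is opened only for reads therefore never touches the filesystem, which is exactly the leak the patch closes.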
(cherry picked from commit bdfe60ac921172be0fb77de2f075cc7904a3b238) Signed-off-by: Shixiong Zhu --- .../state/HDFSBackedStateStoreProvider.scala | 10 +-- .../streaming/state/StateStoreSuite.scala | 63 +++++++++++++++++++ 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 808713161c316..f07feaad5dc71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -87,8 +87,7 @@ private[state] class HDFSBackedStateStoreProvider( private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) - + private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() @volatile private var state: STATE = UPDATING @@ -101,7 +100,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or aborted") + verify(state == UPDATING, "Cannot put after already committed or aborted") val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) @@ -125,6 +124,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Remove keys that match the following condition */ override def remove(condition: UnsafeRow => Boolean): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") + val keyIter = mapToUpdate.keySet().iterator() while (keyIter.hasNext) { val key = keyIter.next @@ -154,7 +154,7 @@ private[state] class HDFSBackedStateStoreProvider( finalizeDeltaFile(tempDeltaFileStream) finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) state = COMMITTED - logInfo(s"Committed version $newVersion for $this") + logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion } catch { case NonFatal(e) => @@ -174,7 +174,7 @@ private[state] class HDFSBackedStateStoreProvider( if (tempDeltaFile != null) { fs.delete(tempDeltaFile, true) } - logInfo("Aborted") + logInfo(s"Aborted version $newVersion for $this") } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 504a26516107f..533cd0cd2a2ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -468,6 +468,69 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(e.getCause.getMessage.contains("Failed to rename")) } + test("SPARK-18416: do not create temp delta file until the store is updated") { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, 0, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + val deltaFileDir = new File(s"$dir/0/0/") + + def numTempFiles: Int = { + if (deltaFileDir.exists) { + 
deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + } else 0 + } + + def numDeltaFiles: Int = { + if (deltaFileDir.exists) { + deltaFileDir.listFiles.map(_.getName).count(n => n.contains(".delta") && !n.startsWith(".")) + } else 0 + } + + def shouldNotCreateTempFile[T](body: => T): T = { + val before = numTempFiles + val result = body + assert(numTempFiles === before) + result + } + + // Getting the store should not create temp file + val store0 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + } + + // Put should create a temp file + put(store0, "a", 1) + assert(numTempFiles === 1) + assert(numDeltaFiles === 0) + + // Commit should remove temp file and create a delta file + store0.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 1) + + // Remove should create a temp file + val store1 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + } + remove(store1, _ == "a") + assert(numTempFiles === 1) + assert(numDeltaFiles === 1) + + // Commit should remove temp file and create a delta file + store1.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 2) + + // Commit without any updates should create a delta file + val store2 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf) + } + store2.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 3) + } + def getDataFromFiles( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { From db691f05cec9e03f507c5ed544bcc6edefb3842d Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Mon, 14 Nov 2016 11:10:37 -0800 Subject: [PATCH 103/534] [SPARK-17510][STREAMING][KAFKA] config max rate on a per-partition basis ## What changes were proposed in this pull request? Allow configuration of max rate on a per-topicpartition basis. ## How was this patch tested? Unit tests. The reporter (Jeff Nadler) said he could test on his workload, so let's wait on that report. Author: cody koeninger Closes #15132 from koeninger/SPARK-17510. 
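For reference, a minimal sketch of how the new `PerPartitionConfig` overload of `createDirectStream` added by this patch might be used from user code. The topic name, rates, and broker address are made-up values, and `ssc`/`kafkaParams` are set up only to keep the snippet self-contained:

```scala
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka010._

object PerPartitionRateExample {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(new SparkConf().setAppName("ppc-example"), Seconds(5))

    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> "localhost:9092",   // assumption: a local broker
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "ppc-example")

    // Cap a known hot partition at a lower rate than the rest of the topic.
    val ppc = new PerPartitionConfig {
      override def maxRatePerPartition(tp: TopicPartition): Long =
        if (tp.topic == "events" && tp.partition == 0) 50L else 500L
    }

    val stream = KafkaUtils.createDirectStream[String, String](
      ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](Seq("events"), kafkaParams),
      ppc)

    stream.map(r => (r.key, r.value)).print()
    ssc.start()
    ssc.awaitTermination()
  }
}
```

When no `PerPartitionConfig` is passed, the existing three-argument overload builds a `DefaultPerPartitionConfig` from `spark.streaming.kafka.maxRatePerPartition`, so current jobs keep their behavior.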
(cherry picked from commit 89d1fa58dbe88560b1f2b0362fcc3035ccc888be) Signed-off-by: Shixiong Zhu --- .../kafka010/DirectKafkaInputDStream.scala | 11 ++-- .../spark/streaming/kafka010/KafkaUtils.scala | 53 ++++++++++++++++++- .../kafka010/PerPartitionConfig.scala | 47 ++++++++++++++++ .../kafka010/DirectKafkaStreamSuite.scala | 34 ++++++++---- .../kafka/DirectKafkaInputDStream.scala | 4 +- 5 files changed, 131 insertions(+), 18 deletions(-) create mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 7e57bb18cbd50..794f53c5abfd0 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -57,7 +57,8 @@ import org.apache.spark.streaming.scheduler.rate.RateEstimator private[spark] class DirectKafkaInputDStream[K, V]( _ssc: StreamingContext, locationStrategy: LocationStrategy, - consumerStrategy: ConsumerStrategy[K, V] + consumerStrategy: ConsumerStrategy[K, V], + ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { val executorKafkaParams = { @@ -128,12 +129,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( } } - private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( - "spark.streaming.kafka.maxRatePerPartition", 0) - protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val estimatedRateLimit = rateController.map(_.getLatestRate()) // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { @@ -144,11 +142,12 @@ private[spark] class DirectKafkaInputDStream[K, V]( val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => + val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) val backpressureRate = Math.round(lag / totalLag.toFloat * rate) tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } } if (effectiveRateLimitPerPartition.values.sum > 0) { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala index b2190bfa05a3a..c11917f59d5b8 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala @@ -123,7 +123,31 @@ object KafkaUtils extends Logging { locationStrategy: LocationStrategy, consumerStrategy: ConsumerStrategy[K, V] ): InputDStream[ConsumerRecord[K, V]] = { - new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy) + val ppc = new DefaultPerPartitionConfig(ssc.sparkContext.getConf) + createDirectStream[K, V](ssc, locationStrategy, consumerStrategy, ppc) + } + + /** + * :: Experimental :: + * Scala 
constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details. + * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis. + * see [[PerPartitionConfig]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + ssc: StreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V], + perPartitionConfig: PerPartitionConfig + ): InputDStream[ConsumerRecord[K, V]] = { + new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy, perPartitionConfig) } /** @@ -150,6 +174,33 @@ object KafkaUtils extends Logging { jssc.ssc, locationStrategy, consumerStrategy)) } + /** + * :: Experimental :: + * Java constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details + * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis. + * see [[PerPartitionConfig]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + jssc: JavaStreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V], + perPartitionConfig: PerPartitionConfig + ): JavaInputDStream[ConsumerRecord[K, V]] = { + new JavaInputDStream( + createDirectStream[K, V]( + jssc.ssc, locationStrategy, consumerStrategy, perPartitionConfig)) + } + /** * Tweak kafka params to prevent issues on executors */ diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala new file mode 100644 index 0000000000000..4792f2a955110 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Interface for user-supplied configurations that can't otherwise be set via Spark properties, + * because they need tweaking on a per-partition basis, + */ +@Experimental +abstract class PerPartitionConfig extends Serializable { + /** + * Maximum rate (number of records per second) at which data will be read + * from each Kafka partition. + */ + def maxRatePerPartition(topicPartition: TopicPartition): Long +} + +/** + * Default per-partition configuration + */ +private class DefaultPerPartitionConfig(conf: SparkConf) + extends PerPartitionConfig { + val maxRate = conf.getLong("spark.streaming.kafka.maxRatePerPartition", 0) + + def maxRatePerPartition(topicPartition: TopicPartition): Long = maxRate +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 02aec43c3b34f..f36e0a901f7b0 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -252,7 +252,8 @@ class DirectKafkaStreamSuite val s = new DirectKafkaInputDStream[String, String]( ssc, preferredHosts, - ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf)) s.consumer.poll(0) assert( s.consumer.position(topicPartition) >= offsetBeforeStart, @@ -306,7 +307,8 @@ class DirectKafkaStreamSuite ConsumerStrategies.Assign[String, String]( List(topicPartition), kafkaParams.asScala, - Map(topicPartition -> 11L))) + Map(topicPartition -> 11L)), + new DefaultPerPartitionConfig(sparkConf)) s.consumer.poll(0) assert( s.consumer.position(topicPartition) >= offsetBeforeStart, @@ -518,7 +520,7 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition with backpressure disabled") { val topic = "maxMessagesPerPartition" - val kafkaStream = getDirectKafkaStream(topic, None) + val kafkaStream = getDirectKafkaStream(topic, None, None) val input = Map(new TopicPartition(topic, 0) -> 50L, new TopicPartition(topic, 1) -> 50L) assert(kafkaStream.maxMessagesPerPartition(input).get == @@ -528,7 +530,7 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition with no lag") { val topic = "maxMessagesPerPartition" val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100)) - val kafkaStream = getDirectKafkaStream(topic, rateController) + val kafkaStream = getDirectKafkaStream(topic, rateController, None) val input = Map(new TopicPartition(topic, 0) -> 0L, new TopicPartition(topic, 1) -> 0L) assert(kafkaStream.maxMessagesPerPartition(input).isEmpty) @@ -537,11 +539,19 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition respects max rate") { val topic = "maxMessagesPerPartition" val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000)) - val kafkaStream = getDirectKafkaStream(topic, rateController) + val ppc = Some(new PerPartitionConfig { + def maxRatePerPartition(tp: TopicPartition) = + if (tp.topic == topic && tp.partition == 0) { + 50 + } else { + 100 + } + }) + val kafkaStream = 
getDirectKafkaStream(topic, rateController, ppc) val input = Map(new TopicPartition(topic, 0) -> 1000L, new TopicPartition(topic, 1) -> 1000L) assert(kafkaStream.maxMessagesPerPartition(input).get == - Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L)) + Map(new TopicPartition(topic, 0) -> 5L, new TopicPartition(topic, 1) -> 10L)) } test("using rate controller") { @@ -570,7 +580,9 @@ class DirectKafkaStreamSuite new DirectKafkaInputDStream[String, String]( ssc, preferredHosts, - ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) { + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { override protected[streaming] val rateController = Some(new DirectKafkaRateController(id, estimator)) }.map(r => (r.key, r.value)) @@ -616,7 +628,10 @@ class DirectKafkaStreamSuite }.toSeq.sortBy { _._1 } } - private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = { + private def getDirectKafkaStream( + topic: String, + mockRateController: Option[RateController], + ppc: Option[PerPartitionConfig]) = { val batchIntervalMilliseconds = 100 val sparkConf = new SparkConf() @@ -643,7 +658,8 @@ class DirectKafkaStreamSuite tps.foreach(tp => consumer.seek(tp, 0)) consumer } - } + }, + ppc.getOrElse(new DefaultPerPartitionConfig(sparkConf)) ) { override protected[streaming] val rateController = mockRateController } diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index c3c799375bbeb..d52c230eb7849 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -88,12 +88,12 @@ class DirectKafkaInputDStream[ protected val kc = new KafkaCluster(kafkaParams) - private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( + private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val estimatedRateLimit = rateController.map(_.getLatestRate()) // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { From cff7a70b59c3ac2cb1fab2216e9e6dcf2a6ac89a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 14 Nov 2016 19:42:00 +0000 Subject: [PATCH 104/534] [SPARK-11496][GRAPHX][FOLLOWUP] Add param checking for runParallelPersonalizedPageRank ## What changes were proposed in this pull request? add the param checking to keep in line with other algos ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15876 from zhengruifeng/param_check_runParallelPersonalizedPageRank. 
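A small illustrative snippet of the effect of the new checks, assuming an active `SparkContext` named `sc` (the toy graph is made up): invalid arguments now fail fast with an `IllegalArgumentException` instead of surfacing as a confusing error deep inside the algorithm.

```scala
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.graphx.lib.PageRank

// A tiny two-edge graph, just enough to call the API.
val graph = Graph.fromEdges(
  sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1))), defaultValue = 1)

// An empty `sources` array (likewise numIter <= 0 or resetProb outside [0, 1])
// is now rejected up front by the added require() checks.
try {
  PageRank.runParallelPersonalizedPageRank(
    graph, numIter = 10, resetProb = 0.15, sources = Array.empty[Long])
} catch {
  case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}")
}
```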
(cherry picked from commit 75934457d75996be71ffd0d4b448497d656c0d40) Signed-off-by: DB Tsai --- .../main/scala/org/apache/spark/graphx/lib/PageRank.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index f4b00757a8b54..c0c3c73463aab 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -185,6 +185,13 @@ object PageRank extends Logging { def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, sources: Array[VertexId]): Graph[Vector, Double] = { + require(numIter > 0, s"Number of iterations must be greater than 0," + + s" but got ${numIter}") + require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" + + s" to [0, 1], but got ${resetProb}") + require(sources.nonEmpty, s"The list of sources must be non-empty," + + s" but got ${sources.mkString("[", ",", "]")}") + // TODO if one sources vertex id is outside of the int range // we won't be able to store its activations in a sparse vector val zero = Vectors.sparse(sources.size, List()).asBreeze From ae66799feec895751f49418885da58f35fc2aaa6 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Mon, 14 Nov 2016 20:59:15 +0100 Subject: [PATCH 105/534] [SPARK-17348][SQL] Incorrect results from subquery transformation ## What changes were proposed in this pull request? Return an Analysis exception when there is a correlated non-equality predicate in a subquery and the correlated column from the outer reference is not from the immediate parent operator of the subquery. This PR prevents incorrect results from subquery transformation in such case. Test cases, both positive and negative tests, are added. ## How was this patch tested? sql/test, catalyst/test, hive/test, and scenarios that will produce incorrect results without this PR and product correct results when subquery transformation does happen. Author: Nattavut Sutyanyong Closes #15763 from nsyca/spark-17348. (cherry picked from commit bd85603ba5f9e61e1aa8326d3e4d5703b5977a4c) Signed-off-by: Herman van Hovell --- .../sql/catalyst/analysis/Analyzer.scala | 44 +++++++++ .../sql/catalyst/analysis/CheckAnalysis.scala | 7 -- .../org/apache/spark/sql/SubquerySuite.scala | 95 ++++++++++++++++++- 3 files changed, 137 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8dbec408002f1..dcee2e4b1fe73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -972,6 +972,37 @@ class Analyzer( } } + // SPARK-17348: A potential incorrect result case. + // When a correlated predicate is a non-equality predicate, + // certain operators are not permitted from the operator + // hosting the correlated predicate up to the operator on the outer table. + // Otherwise, the pull up of the correlated predicate + // will generate a plan with a different semantics + // which could return incorrect result. + // Currently we check for Aggregate and Window operators + // + // Below shows an example of a Logical Plan during Analyzer phase that + // show this problem. 
Pulling the correlated predicate [outer(c2#77) >= ..] + // through the Aggregate (or Window) operator could alter the result of + // the Aggregate. + // + // Project [c1#76] + // +- Project [c1#87, c2#88] + // : (Aggregate or Window operator) + // : +- Filter [outer(c2#77) >= c2#88)] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] + // +- SubqueryAlias t1, `t1` + // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] + // +- LocalRelation [_1#73, _2#74] + def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { + if (found) { + // Report a non-supported case as an exception + failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + } + } + /** Determine which correlated predicate references are missing from this plan. */ def missingReferences(p: LogicalPlan): AttributeSet = { val localPredicateReferences = p.collect(predicateMap) @@ -982,12 +1013,20 @@ class Analyzer( localPredicateReferences -- p.outputSet } + var foundNonEqualCorrelatedPred : Boolean = false + // Simplify the predicates before pulling them out. val transformed = BooleanSimplification(sub) transformUp { case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + // Find any non-equality correlated predicates + foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { + case _: EqualTo | _: EqualNullSafe => false + case _ => true + } + // Rewrite the filter without the correlated predicates if any. correlated match { case Nil => f @@ -1009,12 +1048,17 @@ class Analyzer( } case a @ Aggregate(grouping, expressions, child) => failOnOuterReference(a) + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + val referencesToAdd = missingReferences(a) if (referencesToAdd.nonEmpty) { Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) } else { a } + case w : Window => + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w) + w case j @ Join(left, _, RightOuter, _) => failOnOuterReference(j) failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 3455a567b7786..7b75c1f70974b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -119,13 +119,6 @@ trait CheckAnalysis extends PredicateHelper { } case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => - // Make sure we are using equi-joins. - conditions.foreach { - case _: EqualTo | _: EqualNullSafe => // ok - case e => failAnalysis( - s"The correlated scalar subquery can only contain equality predicates: $e") - } - // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates which contain exactly one aggregate expressions. 
// The analyzer has already checked that subquery contained only one output column, and diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 89348668340be..c84a6f161893c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -498,10 +498,10 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("non-equal correlated scalar subquery") { val msg1 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1") + sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") } assert(msg1.getMessage.contains( - "The correlated scalar subquery can only contain equality predicates")) + "Correlated column is not allowed in a non-equality predicate:")) } test("disjunctive correlated scalar subquery") { @@ -639,6 +639,97 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | from t1 left join t2 on t1.c1=t2.c2) t3 | where c3 not in (select c2 from t2)""".stripMargin), Row(2) :: Nil) + } + } + + test("SPARK-17348: Correlated subqueries with non-equality predicate (good case)") { + withTempView("t1", "t2") { + Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2") + + // Simple case + checkAnswer( + sql( + """ + | select c1 + | from t1 + | where c1 in (select t2.c1 + | from t2 + | where t1.c2 >= t2.c2)""".stripMargin), + Row(1) :: Nil) + + // More complex case with OR predicate + checkAnswer( + sql( + """ + | select t1.c1 + | from t1, t1 as t3 + | where t1.c1 = t3.c1 + | and (t1.c1 in (select t2.c1 + | from t2 + | where t1.c2 >= t2.c2 + | or t3.c2 < t2.c2) + | or t1.c2 >= 0)""".stripMargin), + Row(1) :: Nil) + } + } + + test("SPARK-17348: Correlated subqueries with non-equality predicate (error case)") { + withTempView("t1", "t2", "t3", "t4") { + Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((2, 1)).toDF("c1", "c2").createOrReplaceTempView("t3") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t4") + + // Simplest case + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 >= t2.c2)""".stripMargin).collect() + } + + // Add a HAVING on top and augmented within an OR predicate + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 >= t2.c2 + | having count(*) > 0 ) + | or t1.c2 >= 0""".stripMargin).collect() + } + + // Add a HAVING on top and augmented within an OR predicate + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1, t1 as t3 + | where t1.c1 = t3.c1 + | and (t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 = t2.c2 + | or t3.c2 = t2.c2) + | )""".stripMargin).collect() + } + + // In Window expression: changing the data set to + // demonstrate if this query ran, it would return incorrect result. 
+ intercept[AnalysisException] { + sql( + """ + | select c1 + | from t3 + | where c1 in (select max(t4.c1) over () + | from t4 + | where t3.c2 >= t4.c2)""".stripMargin).collect() + } } } } From 27999b3661481c0232135dbe021787afe963d812 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 14 Nov 2016 16:46:26 -0800 Subject: [PATCH 106/534] [SPARK-18124] Observed delay based Event Time Watermarks This PR adds a new method `withWatermark` to the `Dataset` API, which can be used specify an _event time watermark_. An event time watermark allows the streaming engine to reason about the point in time after which we no longer expect to see late data. This PR also has augmented `StreamExecution` to use this watermark for several purposes: - To know when a given time window aggregation is finalized and thus results can be emitted when using output modes that do not allow updates (e.g. `Append` mode). - To minimize the amount of state that we need to keep for on-going aggregations, by evicting state for groups that are no longer expected to change. Although, we do still maintain all state if the query requires (i.e. if the event time is not present in the `groupBy` or when running in `Complete` mode). An example that emits windowed counts of records, waiting up to 5 minutes for late data to arrive. ```scala df.withWatermark("eventTime", "5 minutes") .groupBy(window($"eventTime", "1 minute") as 'window) .count() .writeStream .format("console") .mode("append") // In append mode, we only output finalized aggregations. .start() ``` ### Calculating the watermark. The current event time is computed by looking at the `MAX(eventTime)` seen this epoch across all of the partitions in the query minus some user defined _delayThreshold_. An additional constraint is that the watermark must increase monotonically. Note that since we must coordinate this value across partitions occasionally, the actual watermark used is only guaranteed to be at least `delay` behind the actual event time. In some cases we may still process records that arrive more than delay late. This mechanism was chosen for the initial implementation over processing time for two reasons: - it is robust to downtime that could affect processing delay - it does not require syncing of time or timezones between the producer and the processing engine. ### Other notable implementation details - A new trigger metric `eventTimeWatermark` outputs the current value of the watermark. - We mark the event time column in the `Attribute` metadata using the key `spark.watermarkDelay`. This allows downstream operations to know which column holds the event time. Operations like `window` propagate this metadata. - `explain()` marks the watermark with a suffix of `-T${delayMs}` to ease debugging of how this information is propagated. - Currently, we don't filter out late records, but instead rely on the state store to avoid emitting records that are both added and filtered in the same epoch. ### Remaining in this PR - [ ] The test for recovery is currently failing as we don't record the watermark used in the offset log. We will need to do so to ensure determinism, but this is deferred until #15626 is merged. ### Other follow-ups There are some natural additional features that we should consider for future work: - Ability to write records that arrive too late to some external store in case any out-of-band remediation is required. - `Update` mode so you can get partial results before a group is evicted. - Other mechanisms for calculating the watermark. 
In particular a watermark based on quantiles would be more robust to outliers. Author: Michael Armbrust Closes #15702 from marmbrus/watermarks. (cherry picked from commit c07187823a98f0d1a0f58c06e28a27e1abed157a) Signed-off-by: Tathagata Das --- .../spark/unsafe/types/CalendarInterval.java | 4 + .../apache/spark/sql/AnalysisException.scala | 3 +- .../sql/catalyst/analysis/Analyzer.scala | 8 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 + .../UnsupportedOperationChecker.scala | 18 +- .../sql/catalyst/analysis/unresolved.scala | 3 +- .../expressions/namedExpressions.scala | 17 +- .../plans/logical/EventTimeWatermark.scala | 51 +++++ .../scala/org/apache/spark/sql/Dataset.scala | 40 +++- .../spark/sql/execution/SparkStrategies.scala | 12 +- .../sql/execution/aggregate/AggUtils.scala | 9 +- .../sql/execution/command/commands.scala | 2 +- .../streaming/EventTimeWatermarkExec.scala | 93 +++++++++ .../sql/execution/streaming/ForeachSink.scala | 3 +- .../streaming/IncrementalExecution.scala | 12 +- .../streaming/StatefulAggregate.scala | 170 +++++++++------- .../execution/streaming/StreamExecution.scala | 25 ++- .../execution/streaming/StreamMetrics.scala | 1 + .../state/HDFSBackedStateStoreProvider.scala | 23 ++- .../streaming/state/StateStore.scala | 7 +- .../streaming/state/StateStoreSuite.scala | 6 +- .../spark/sql/streaming/WatermarkSuite.scala | 191 ++++++++++++++++++ 22 files changed, 597 insertions(+), 111 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 518ed6470a753..a7b0e6f80c2b6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -252,6 +252,10 @@ public static long parseSecondNano(String secondNano) throws IllegalArgumentExce public final int months; public final long microseconds; + public final long milliseconds() { + return this.microseconds / MICROS_PER_MILLI; + } + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 7defb9df862c0..ff8576157305b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -31,7 +31,8 @@ class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, val startPosition: Option[Int] = None, - val plan: Option[LogicalPlan] = None, + // Some plans fail to serialize due to bugs in scala collections. 
+ @transient val plan: Option[LogicalPlan] = None, val cause: Option[Throwable] = None) extends Exception(message, cause.orNull) with Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dcee2e4b1fe73..b7e167557c559 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2213,7 +2213,13 @@ object TimeWindowing extends Rule[LogicalPlan] { windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { val window = windowExpressions.head - val windowAttr = AttributeReference("window", window.dataType)() + + val metadata = window.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + val windowAttr = + AttributeReference("window", window.dataType, metadata = metadata)() val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt val windows = Seq.tabulate(maxNumOverlapping + 1) { i => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b75c1f70974b..98e50d0d3c674 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -148,6 +148,16 @@ trait CheckAnalysis extends PredicateHelper { } operator match { + case etw: EventTimeWatermark => + etw.eventTime.dataType match { + case s: StructType + if s.find(_.name == "end").map(_.dataType) == Some(TimestampType) => + case _: TimestampType => + case _ => + failAnalysis( + s"Event time must be defined on a window or a timestamp, but " + + s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.simpleString}") + } case f: Filter if f.condition.dataType != BooleanType => failAnalysis( s"filter expression '${f.condition.sql}' " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index e81370c504abb..c054fcbef36f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.{AnalysisException, InternalOutputModes} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.OutputMode @@ -55,9 +56,20 @@ object UnsupportedOperationChecker { // Disallow some output mode outputMode match { case InternalOutputModes.Append if aggregates.nonEmpty => - throwError( - s"$outputMode output mode not supported when there are streaming aggregations on " + - s"streaming DataFrames/DataSets")(plan) + val aggregate = aggregates.head + + // Find any attributes that are associated with an eventTime watermark. 
+ val watermarkAttributes = aggregate.groupingExpressions.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + + // We can append rows to the sink once the group is under the watermark. Without this + // watermark a group is never "finished" so we would never output anything. + if (watermarkAttributes.isEmpty) { + throwError( + s"$outputMode output mode not supported when there are streaming aggregations on " + + s"streaming DataFrames/DataSets")(plan) + } case InternalOutputModes.Complete | InternalOutputModes.Update if aggregates.isEmpty => throwError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 235ae04782455..36ed9ba50372b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, Codege import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, Metadata, StructType} /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -98,6 +98,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def withNullability(newNullability: Boolean): UnresolvedAttribute = this override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) + override def withMetadata(newMetadata: Metadata): Attribute = this override def toString: String = s"'$name" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 306a99d5a37bf..1274757136051 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -22,6 +22,7 @@ import java.util.{Objects, UUID} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types._ @@ -104,6 +105,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn def withNullability(newNullability: Boolean): Attribute def withQualifier(newQualifier: Option[String]): Attribute def withName(newName: String): Attribute + def withMetadata(newMetadata: Metadata): Attribute override def toAttribute: Attribute = this def newInstance(): Attribute @@ -292,11 +294,22 @@ case class AttributeReference( } } + override def withMetadata(newMetadata: Metadata): Attribute = { + AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier, isGenerated) + } + override protected final def otherCopyArgs: Seq[AnyRef] = { exprId :: qualifier :: isGenerated :: Nil } - override def toString: 
String = s"$name#${exprId.id}$typeSuffix" + /** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}) */ + private def delaySuffix = if (metadata.contains(EventTimeWatermark.delayKey)) { + s"-T${metadata.getLong(EventTimeWatermark.delayKey)}ms" + } else { + "" + } + + override def toString: String = s"$name#${exprId.id}$typeSuffix$delaySuffix" // Since the expression id is not in the first constructor it is missing from the default // tree string. @@ -332,6 +345,8 @@ case class PrettyAttribute( override def withQualifier(newQualifier: Option[String]): Attribute = throw new UnsupportedOperationException override def withName(newName: String): Attribute = throw new UnsupportedOperationException + override def withMetadata(newMetadata: Metadata): Attribute = + throw new UnsupportedOperationException override def qualifier: Option[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala new file mode 100644 index 0000000000000..4224a7997c410 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.types.MetadataBuilder +import org.apache.spark.unsafe.types.CalendarInterval + +object EventTimeWatermark { + /** The [[org.apache.spark.sql.types.Metadata]] key used to hold the eventTime watermark delay. */ + val delayKey = "spark.watermarkDelayMs" +} + +/** + * Used to mark a user specified column as holding the event time for a row. + */ +case class EventTimeWatermark( + eventTime: Attribute, + delay: CalendarInterval, + child: LogicalPlan) extends LogicalPlan { + + // Update the metadata on the eventTime column to include the desired delay. 
+ override val output: Seq[Attribute] = child.output.map { a => + if (a semanticEquals eventTime) { + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .build() + a.withMetadata(updatedMetadata) + } else { + a + } + } + + override val children: Seq[LogicalPlan] = child :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index eb2b20afc37cf..af30683cc01c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils private[sql] object Dataset { @@ -476,7 +477,7 @@ class Dataset[T] private[sql]( * `collect()`, will throw an [[AnalysisException]] when there is a streaming * source present. * - * @group basic + * @group streaming * @since 2.0.0 */ @Experimental @@ -496,8 +497,6 @@ class Dataset[T] private[sql]( /** * Returns a checkpointed version of this Dataset. * - * @param eager When true, materializes the underlying checkpointed RDD eagerly. - * * @group basic * @since 2.1.0 */ @@ -535,6 +534,41 @@ class Dataset[T] private[sql]( )(sparkSession)).as[T] } + /** + * :: Experimental :: + * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time + * before which we assume no more late data is going to arrive. + * + * Spark will use this watermark for several purposes: + * - To know when a given time window aggregation can be finalized and thus can be emitted when + * using output modes that do not allow updates. + * - To minimize the amount of state that we need to keep for on-going aggregations. + * + * The current watermark is computed by looking at the `MAX(eventTime)` seen across + * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost + * of coordinating this value across partitions, the actual watermark used is only guaranteed + * to be at least `delayThreshold` behind the actual event time. In some cases we may still + * process records that arrive more than `delayThreshold` late. + * + * @param eventTime the name of the column that contains the event time of the row. + * @param delayThreshold the minimum delay to wait to data to arrive late, relative to the latest + * record that has been processed in the form of an interval + * (e.g. "1 minute" or "5 hours"). + * + * @group streaming + * @since 2.1.0 + */ + @Experimental + @InterfaceStability.Evolving + // We only accept an existing column name, not a derived column here as a watermark that is + // defined on a derived column cannot referenced elsewhere in the plan. + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { + val parsedDelay = + Option(CalendarInterval.fromString("interval " + delayThreshold)) + .getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'")) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan) + } + /** * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, * and all cells will be aligned right. 
For example: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 190fdd84343ee..2308ae8a6c611 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,20 +18,23 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, SaveMode, Strategy} +import org.apache.spark.sql.{SaveMode, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} -import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.StreamingQuery /** * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting @@ -224,6 +227,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object StatefulAggregationStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case EventTimeWatermark(columnName, delay, child) => + EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil + case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 4fbb9d554c9bf..f7ea8970edf90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -313,8 +313,13 @@ object AggUtils { } // Note: stateId and returnAllStates are filled in later with preparation rules // in IncrementalExecution. 
- val saved = StateStoreSaveExec( - groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) + val saved = + StateStoreSaveExec( + groupingAttributes, + stateId = None, + outputMode = None, + eventTimeWatermark = None, + partialMerged2) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index d82e54e57564c..52d8dc22a2d4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -104,7 +104,7 @@ case class ExplainCommand( if (logicalPlan.isStreaming) { // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. - new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0) + new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0, 0) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala new file mode 100644 index 0000000000000..4c8cb069d23a0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.math.max + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.MetadataBuilder +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.AccumulatorV2 + +/** Tracks the maximum positive long seen. 
*/ +class MaxLong(protected var currentValue: Long = 0) + extends AccumulatorV2[Long, Long] { + + override def isZero: Boolean = value == 0 + override def value: Long = currentValue + override def copy(): AccumulatorV2[Long, Long] = new MaxLong(currentValue) + + override def reset(): Unit = { + currentValue = 0 + } + + override def add(v: Long): Unit = { + currentValue = max(v, value) + } + + override def merge(other: AccumulatorV2[Long, Long]): Unit = { + currentValue = max(value, other.value) + } +} + +/** + * Used to mark a column as the containing the event time for a given record. In addition to + * adding appropriate metadata to this column, this operator also tracks the maximum observed event + * time. Based on the maximum observed time and a user specified delay, we can calculate the + * `watermark` after which we assume we will no longer see late records for a particular time + * period. + */ +case class EventTimeWatermarkExec( + eventTime: Attribute, + delay: CalendarInterval, + child: SparkPlan) extends SparkPlan { + + // TODO: Use Spark SQL Metrics? + val maxEventTime = new MaxLong + sparkContext.register(maxEventTime) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val getEventTime = UnsafeProjection.create(eventTime :: Nil, child.output) + iter.map { row => + maxEventTime.add(getEventTime(row).getLong(0)) + row + } + } + } + + // Update the metadata on the eventTime column to include the desired delay. + override val output: Seq[Attribute] = child.output.map { a => + if (a semanticEquals eventTime) { + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .build() + + a.withMetadata(updatedMetadata) + } else { + a + } + } + + override def children: Seq[SparkPlan] = child :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index 24f98b9211f12..f5c550dd6ac3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -60,7 +60,8 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria deserialized, data.queryExecution.asInstanceOf[IncrementalExecution].outputMode, data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation, - data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId) + data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId, + data.queryExecution.asInstanceOf[IncrementalExecution].currentEventTimeWatermark) incrementalExecution.toRdd.mapPartitions { rows => rows.map(_.get(0, objectType)) }.asInstanceOf[RDD[T]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 05294df2673dc..e9d072f8a98b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -32,11 +32,13 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, - val currentBatchId: Long) + val currentBatchId: Long, + val currentEventTimeWatermark: Long) extends QueryExecution(sparkSession, 
logicalPlan) { // TODO: make this always part of planning. - val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy +: + val stateStrategy = + sparkSession.sessionState.planner.StatefulAggregationStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -57,17 +59,17 @@ class IncrementalExecution( val state = new Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, None, + case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) - val returnAllStates = if (outputMode == InternalOutputModes.Complete) true else false operatorId += 1 StateStoreSaveExec( keys, Some(stateId), - Some(returnAllStates), + Some(outputMode), + Some(currentEventTimeWatermark), agg.withNewChildren( StateStoreRestoreExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index ad8238f189c64..7af978a9c4aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -21,12 +21,17 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution +import org.apache.spark.sql.InternalOutputModes._ +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + /** Used to identify the state store for a given operator. 
*/ case class OperatorStateId( @@ -92,8 +97,9 @@ case class StateStoreRestoreExec( */ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], - returnAllStates: Option[Boolean], + stateId: Option[OperatorStateId] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { @@ -104,9 +110,9 @@ case class StateStoreSaveExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - assert(returnAllStates.nonEmpty, - "Incorrect planning in IncrementalExecution, returnAllStates have not been set") - val saveAndReturnFunc = if (returnAllStates.get) saveAndReturnAll _ else saveAndReturnUpdated _ + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -114,75 +120,95 @@ case class StateStoreSaveExec( keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator) - )(saveAndReturnFunc) + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + outputMode match { + // Update and output all rows in the StateStore. + case Some(Complete) => + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + store.commit() + numTotalStateRows += store.numKeys() + store.iterator().map { case (k, v) => + numOutputRows += 1 + v.asInstanceOf[InternalRow] + } + + // Update and output only rows being evicted from the StateStore + case Some(Append) => + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + + val watermarkAttribute = + keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)).get + // If we are evicting based on a window, use the end of the window. Otherwise just + // use the attribute itself. + val evictionExpression = + if (watermarkAttribute.dataType.isInstanceOf[StructType]) { + LessThanOrEqual( + GetStructField(watermarkAttribute, 1), + Literal(eventTimeWatermark.get * 1000)) + } else { + LessThanOrEqual( + watermarkAttribute, + Literal(eventTimeWatermark.get * 1000)) + } + + logInfo(s"Filtering state store on: $evictionExpression") + val predicate = newPredicate(evictionExpression, keyExpressions) + store.remove(predicate.eval) + + store.commit() + + numTotalStateRows += store.numKeys() + store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed => + numOutputRows += 1 + removed.value.asInstanceOf[InternalRow] + } + + // Update and output modified rows from the StateStore. 
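+        // Rows are emitted as soon as they are written to the store; the store commit is
+        // deferred until the input iterator has been fully drained.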
+ case Some(Update) => + new Iterator[InternalRow] { + private[this] val baseIterator = iter + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + numTotalStateRows += store.numKeys() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numOutputRows += 1 + numUpdatedStateRows += 1 + row + } + } + + case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + } + } } override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning - - /** - * Save all the rows to the state store, and return all the rows in the state store. - * Note that this returns an iterator that pipelines the saving to store with downstream - * processing. - */ - private def saveAndReturnUpdated( - store: StateStore, - iter: Iterator[InternalRow]): Iterator[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - - new Iterator[InternalRow] { - private[this] val baseIterator = iter - private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - store.commit() - numTotalStateRows += store.numKeys() - false - } else { - true - } - } - - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numOutputRows += 1 - numUpdatedStateRows += 1 - row - } - } - } - - /** - * Save all the rows to the state store, and return all the rows in the state store. - * Note that the saving to store is blocking; only after all the rows have been saved - * is the iterator on the update store data is generated. - */ - private def saveAndReturnAll( - store: StateStore, - iter: Iterator[InternalRow]): Iterator[InternalRow] = { - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val numOutputRows = longMetric("numOutputRows") - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 - } - store.commit() - numTotalStateRows += store.numKeys() - store.iterator().map { case (k, v) => - numOutputRows += 1 - v.asInstanceOf[InternalRow] - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 57e89f85361e4..3ca6feac05cef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -92,6 +92,9 @@ class StreamExecution( /** The current batchId or -1 if execution has not yet been initialized. */ private var currentBatchId: Long = -1 + /** The current eventTime watermark, used to bound the lateness of data that will processed. */ + private var currentEventTimeWatermark: Long = 0 + /** All stream sources present in the query plan. 
*/ private val sources = logicalPlan.collect { case s: StreamingExecutionRelation => s.source } @@ -427,7 +430,8 @@ class StreamExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), - currentBatchId) + currentBatchId, + currentEventTimeWatermark) lastExecution.executedPlan // Force the lazy generation of execution plan } @@ -436,6 +440,25 @@ class StreamExecution( sink.addBatch(currentBatchId, nextBatch) reportNumRows(executedPlan, triggerLogicalPlan, newData) + // Update the eventTime watermark if we find one in the plan. + // TODO: Does this need to be an AttributeMap? + lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec => + logTrace(s"Maximum observed eventTime: ${e.maxEventTime.value}") + (e.maxEventTime.value / 1000) - e.delay.milliseconds() + }.headOption.foreach { newWatermark => + if (newWatermark > currentEventTimeWatermark) { + logInfo(s"Updating eventTime watermark to: $newWatermark ms") + currentEventTimeWatermark = newWatermark + } else { + logTrace(s"Event time didn't move: $newWatermark < $currentEventTimeWatermark") + } + + if (newWatermark != 0) { + streamMetrics.reportTriggerDetail(EVENT_TIME_WATERMARK, newWatermark) + } + } + awaitBatchLock.lock() try { // Wake up any threads that are waiting for the stream to progress. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala index e98d1883e4596..5645554a58f6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -221,6 +221,7 @@ object StreamMetrics extends Logging { val IS_TRIGGER_ACTIVE = "isTriggerActive" val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" val STATUS_MESSAGE = "statusMessage" + val EVENT_TIME_WATERMARK = "eventTimeWatermark" val START_TIMESTAMP = "timestamp.triggerStart" val GET_OFFSET_TIMESTAMP = "timestamp.afterGetOffset" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index f07feaad5dc71..493fdaaec5069 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -109,7 +109,7 @@ private[state] class HDFSBackedStateStoreProvider( case Some(ValueAdded(_, _)) => // Value did not exist in previous version and was added already, keep it marked as added allUpdates.put(key, ValueAdded(key, value)) - case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => + case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) => // Value existed in previous version and updated/removed, mark it as updated allUpdates.put(key, ValueUpdated(key, value)) case None => @@ -124,24 +124,25 @@ private[state] class HDFSBackedStateStoreProvider( /** Remove keys that match the following condition */ override def remove(condition: UnsafeRow => Boolean): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") - - val keyIter = mapToUpdate.keySet().iterator() - while (keyIter.hasNext) { - val key = keyIter.next - if (condition(key)) { - keyIter.remove() + val entryIter = mapToUpdate.entrySet().iterator() + while (entryIter.hasNext) { + val entry = 
entryIter.next + if (condition(entry.getKey)) { + val value = entry.getValue + val key = entry.getKey + entryIter.remove() Option(allUpdates.get(key)) match { case Some(ValueUpdated(_, _)) | None => // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, KeyRemoved(key)) + allUpdates.put(key, ValueRemoved(key, value)) case Some(ValueAdded(_, _)) => // Value did not exist in previous version and was added, should not appear in updates allUpdates.remove(key) - case Some(KeyRemoved(_)) => + case Some(ValueRemoved(_, _)) => // Remove already in update map, no need to change } - writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key)) + writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) } } } @@ -334,7 +335,7 @@ private[state] class HDFSBackedStateStoreProvider( writeUpdate(key, value) case ValueUpdated(key, value) => writeUpdate(key, value) - case KeyRemoved(key) => + case ValueRemoved(key, value) => writeRemove(key) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7132e284c28f4..9bc6c0e2b9334 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -99,13 +99,16 @@ trait StateStoreProvider { /** Trait representing updates made to a [[StateStore]]. */ -sealed trait StoreUpdate +sealed trait StoreUpdate { + def key: UnsafeRow + def value: UnsafeRow +} case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate -case class KeyRemoved(key: UnsafeRow) extends StoreUpdate +case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 533cd0cd2a2ea..05fc7345a7daf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -668,11 +668,11 @@ private[state] object StateStoreSuite { } def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { - iterator.map { _ match { + iterator.map { case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) - case KeyRemoved(key) => Removed(rowToString(key)) - }}.toSet + case ValueRemoved(key, _) => Removed(rowToString(key)) + }.toSet } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala new file mode 100644 index 0000000000000..3617ec0f564c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions.{count, window} + +class WatermarkSuite extends StreamTest with BeforeAndAfter with Logging { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("error on bad column") { + val inputData = MemoryStream[Int].toDF() + val e = intercept[AnalysisException] { + inputData.withWatermark("badColumn", "1 minute") + } + assert(e.getMessage contains "badColumn") + } + + test("error on wrong type") { + val inputData = MemoryStream[Int].toDF() + val e = intercept[AnalysisException] { + inputData.withWatermark("value", "1 minute") + } + assert(e.getMessage contains "value") + assert(e.getMessage contains "int") + } + + + test("watermark metric") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 15), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "5000" + }, + AddData(inputData, 15), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "5000" + }, + AddData(inputData, 25), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "15000" + } + ) + } + + test("append-mode watermark aggregation") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. + CheckAnswer((10, 5)) + ) + } + + ignore("recovery") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + StopStream, + StartStream(), + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. 
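+      // The finalized [10, 15) window holds five rows (10 through 14), which is what the
+      // CheckAnswer((10, 5)) after the restart below expects.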
+ StopStream, + StartStream(), + CheckAnswer((10, 5)) + ) + } + + test("dropping old data") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 12), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. + CheckAnswer((10, 3)), + AddData(inputData, 10), // 10 is later than 15 second watermark + CheckAnswer((10, 3)), + AddData(inputData, 25), + CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + ) + } + + test("complete mode") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + // No eviction when asked to compute complete results. + testStream(windowedAggregation, OutputMode.Complete)( + AddData(inputData, 10, 11, 12), + CheckAnswer((10, 3)), + AddData(inputData, 25), + CheckAnswer((10, 3), (25, 1)), + AddData(inputData, 25), + CheckAnswer((10, 3), (25, 2)), + AddData(inputData, 10), + CheckAnswer((10, 4), (25, 2)), + AddData(inputData, 25), + CheckAnswer((10, 4), (25, 3)) + ) + } + + test("group by on raw timestamp") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy($"eventTime") + .agg(count("*") as 'count) + .select($"eventTime".cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. + CheckAnswer((10, 1)) + ) + } +} From 649c15fae423a415cb6165aa0ef6d97ab4949afb Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 14 Nov 2016 21:15:39 -0800 Subject: [PATCH 107/534] [SPARK-18428][DOC] Update docs for GraphX ## What changes were proposed in this pull request? 1, Add link of `VertexRDD` and `EdgeRDD` 2, Notify in `Vertex and Edge RDDs` that not all methods are listed 3, `VertexID` -> `VertexId` ## How was this patch tested? No tests, only docs is modified Author: Zheng RuiFeng Closes #15875 from zhengruifeng/update_graphop_doc. 
(cherry picked from commit c31def1ddcbed340bfc071d54fb3dc7945cb525a) Signed-off-by: Reynold Xin --- docs/graphx-programming-guide.md | 68 ++++++++++++++++---------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 58671e6f146d8..1097cf1211c1f 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -11,6 +11,7 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +[VertexRDD]: api/scala/index.html#org.apache.spark.graphx.VertexRDD [Edge]: api/scala/index.html#org.apache.spark.graphx.Edge [EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet [Graph]: api/scala/index.html#org.apache.spark.graphx.Graph @@ -89,7 +90,7 @@ with user defined objects attached to each vertex and edge. A directed multigra graph with potentially multiple parallel edges sharing the same source and destination vertex. The ability to support parallel edges simplifies modeling scenarios where there can be multiple relationships (e.g., co-worker and friend) between the same vertices. Each vertex is keyed by a -*unique* 64-bit long identifier (`VertexID`). GraphX does not impose any ordering constraints on +*unique* 64-bit long identifier (`VertexId`). GraphX does not impose any ordering constraints on the vertex identifiers. Similarly, edges have corresponding source and destination vertex identifiers. @@ -130,12 +131,12 @@ class Graph[VD, ED] { } {% endhighlight %} -The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID, +The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexId, VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional functionality built around graph computation and leverage internal optimizations. We discuss the -`VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge +`VertexRDD`[VertexRDD] and `EdgeRDD`[EdgeRDD] API in greater detail in the section on [vertex and edge RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form: -`RDD[(VertexID, VD)]` and `RDD[Edge[ED]]`. +`RDD[(VertexId, VD)]` and `RDD[Edge[ED]]`. ### Example Property Graph @@ -197,7 +198,7 @@ graph.edges.filter(e => e.srcId > e.dstId).count {% endhighlight %} > Note that `graph.vertices` returns an `VertexRDD[(String, String)]` which extends -> `RDD[(VertexID, (String, String))]` and so we use the scala `case` expression to deconstruct the +> `RDD[(VertexId, (String, String))]` and so we use the scala `case` expression to deconstruct the > tuple. On the other hand, `graph.edges` returns an `EdgeRDD` containing `Edge[String]` objects. 
> We could have also used the case class type constructor as in the following: > {% highlight scala %} @@ -287,7 +288,7 @@ class Graph[VD, ED] { // Change the partitioning heuristic ============================================================ def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] // Transform vertex and edge attributes ========================================================== - def mapVertices[VD2](map: (VertexID, VD) => VD2): Graph[VD2, ED] + def mapVertices[VD2](map: (VertexId, VD) => VD2): Graph[VD2, ED] def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2] def mapEdges[ED2](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] @@ -297,18 +298,18 @@ class Graph[VD, ED] { def reverse: Graph[VD, ED] def subgraph( epred: EdgeTriplet[VD,ED] => Boolean = (x => true), - vpred: (VertexID, VD) => Boolean = ((v, d) => true)) + vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] def mask[VD2, ED2](other: Graph[VD2, ED2]): Graph[VD, ED] def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] // Join RDDs with the graph ====================================================================== - def joinVertices[U](table: RDD[(VertexID, U)])(mapFunc: (VertexID, VD, U) => VD): Graph[VD, ED] - def outerJoinVertices[U, VD2](other: RDD[(VertexID, U)]) - (mapFunc: (VertexID, VD, Option[U]) => VD2) + def joinVertices[U](table: RDD[(VertexId, U)])(mapFunc: (VertexId, VD, U) => VD): Graph[VD, ED] + def outerJoinVertices[U, VD2](other: RDD[(VertexId, U)]) + (mapFunc: (VertexId, VD, Option[U]) => VD2) : Graph[VD2, ED] // Aggregate information about adjacent triplets ================================================= - def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] - def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] + def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] + def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] def aggregateMessages[Msg: ClassTag]( sendMsg: EdgeContext[VD, ED, Msg] => Unit, mergeMsg: (Msg, Msg) => Msg, @@ -316,15 +317,15 @@ class Graph[VD, ED] { : VertexRDD[A] // Iterative graph-parallel computation ========================================================== def pregel[A](initialMsg: A, maxIterations: Int, activeDirection: EdgeDirection)( - vprog: (VertexID, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID,A)], + vprog: (VertexId, VD, A) => VD, + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], mergeMsg: (A, A) => A) : Graph[VD, ED] // Basic graph algorithms ======================================================================== def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] - def connectedComponents(): Graph[VertexID, ED] + def connectedComponents(): Graph[VertexId, ED] def triangleCount(): Graph[Int, ED] - def stronglyConnectedComponents(numIter: Int): Graph[VertexID, ED] + def stronglyConnectedComponents(numIter: Int): Graph[VertexId, ED] } {% endhighlight %} @@ -481,7 +482,7 @@ original value. > is therefore recommended that the input RDD be made unique using the following which will > also *pre-index* the resulting values to substantially accelerate the subsequent join. 
> {% highlight scala %} -val nonUniqueCosts: RDD[(VertexID, Double)] +val nonUniqueCosts: RDD[(VertexId, Double)] val uniqueCosts: VertexRDD[Double] = graph.vertices.aggregateUsingIndex(nonUnique, (a,b) => a + b) val joinedGraph = graph.joinVertices(uniqueCosts)( @@ -511,7 +512,7 @@ val degreeGraph = graph.outerJoinVertices(outDegrees) { (id, oldAttr, outDegOpt) > provide type annotation for the user defined function: > {% highlight scala %} val joinedGraph = graph.joinVertices(uniqueCosts, - (id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost) + (id: VertexId, oldCost: Double, extraCost: Double) => oldCost + extraCost) {% endhighlight %} > @@ -558,7 +559,7 @@ The user defined `mergeMsg` function takes two messages destined to the same ver yields a single message. Think of `mergeMsg` as the reduce function in map-reduce. The [`aggregateMessages`][Graph.aggregateMessages] operator returns a `VertexRDD[Msg]` containing the aggregate message (of type `Msg`) destined to each vertex. Vertices that did not -receive a message are not included in the returned `VertexRDD`. +receive a message are not included in the returned `VertexRDD`[VertexRDD]. + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.LinearRegression). + {% include_example python/ml/linear_regression_with_elastic_net.py %}
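As a quick reference for readers of this patch, the linked elastic-net `LinearRegression` example boils down to roughly the following sketch (Scala; `training` is an assumed DataFrame with `label` and `features` columns, e.g. loaded from a libsvm file, and is not part of this diff):

```scala
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)  // 0.0 is pure L2 (ridge), 1.0 is pure L1 (lasso)

val lrModel = lr.fit(training)
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

// Summary statistics over the training data.
val trainingSummary = lrModel.summary
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")
```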
@@ -519,18 +546,21 @@ function and extracting model summary statistics.
+ Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GeneralizedLinearRegression) for more details. {% include_example scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala %}
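A rough sketch of what the linked Scala example exercises, including the summary statistics mentioned in the surrounding section (`dataset` is an assumed DataFrame with `label` and `features` columns; the family, link and regularization values are illustrative and not taken from the example file):

```scala
import org.apache.spark.ml.regression.GeneralizedLinearRegression

val glr = new GeneralizedLinearRegression()
  .setFamily("gaussian")
  .setLink("identity")
  .setMaxIter(10)
  .setRegParam(0.3)

val model = glr.fit(dataset)

// The training summary exposes the model statistics referred to above.
val summary = model.summary
println(s"Coefficient standard errors: ${summary.coefficientStandardErrors.mkString(",")}")
println(s"P values: ${summary.pValues.mkString(",")}")
summary.residuals().show()
```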
+ Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GeneralizedLinearRegression.html) for more details. {% include_example java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java %}
+ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GeneralizedLinearRegression) for more details. {% include_example python/ml/generalized_linear_regression_example.py %} @@ -705,14 +735,23 @@ The implementation matches the result from R's survival function
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.AFTSurvivalRegression) for more details. + {% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %}
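For orientation, the AFT survival regression API used by the examples above looks roughly like this (the toy rows and quantile settings are illustrative, not taken from the example file; `censor == 1.0` marks an observed event, `0.0` a censored one):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression

val training = spark.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151))
)).toDF("label", "censor", "features")

val aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(Array(0.3, 0.6))
  .setQuantilesCol("quantiles")

val model = aft.fit(training)
println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept} Scale: ${model.scale}")
model.transform(training).show(false)
```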
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/AFTSurvivalRegression.html) for more details. + {% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %}
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.AFTSurvivalRegression) for more details. + {% include_example python/ml/aft_survival_regression.py %}
diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index adb057ba7e250..b4d6be94f5eb0 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -207,14 +207,29 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.
+ +Refer to the [`Estimator` Scala docs](api/scala/index.html#org.apache.spark.ml.Estimator), +the [`Transformer` Scala docs](api/scala/index.html#org.apache.spark.ml.Transformer) and +the [`Params` Scala docs](api/scala/index.html#org.apache.spark.ml.param.Params) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala %}
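A condensed sketch of the `Estimator`/`Transformer`/`Param` interplay the example covers (`training` and `test` are assumed DataFrames of labeled feature vectors, not part of this diff):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val lr = new LogisticRegression()      // LogisticRegression is an Estimator
lr.setMaxIter(10).setRegParam(0.01)    // Params set directly on the Estimator

val model1 = lr.fit(training)          // fit() produces a Transformer (a Model)

// Params can also be supplied at fit time as a ParamMap, overriding the setters.
val paramMap = ParamMap(lr.maxIter -> 20, lr.regParam -> 0.1)
val model2 = lr.fit(training, paramMap)

// A Transformer maps one DataFrame to another.
model2.transform(test).select("features", "probability", "prediction").show()
```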
+ +Refer to the [`Estimator` Java docs](api/java/org/apache/spark/ml/Estimator.html), +the [`Transformer` Java docs](api/java/org/apache/spark/ml/Transformer.html) and +the [`Params` Java docs](api/java/org/apache/spark/ml/param/Params.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java %}
+ +Refer to the [`Estimator` Python docs](api/python/pyspark.ml.html#pyspark.ml.Estimator), +the [`Transformer` Python docs](api/python/pyspark.ml.html#pyspark.ml.Transformer) and +the [`Params` Python docs](api/python/pyspark.ml.html#pyspark.ml.param.Params) for more details on the API. + {% include_example python/ml/estimator_transformer_param_example.py %}
@@ -227,14 +242,24 @@ This example follows the simple text document `Pipeline` illustrated in the figu
+ +Refer to the [`Pipeline` Scala docs](api/scala/index.html#org.apache.spark.ml.Pipeline) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %}
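A compact sketch of the three-stage text-document pipeline the example builds (the column names and the `training`/`test` DataFrames are assumed for illustration):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.001)

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// Fitting the Pipeline runs the stages in order and returns a PipelineModel.
val model = pipeline.fit(training)

// The PipelineModel applies the same stages, now as Transformers, to new documents.
model.transform(test).select("id", "text", "probability", "prediction").show()
```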
+ + +Refer to the [`Pipeline` Java docs](api/java/org/apache/spark/ml/Pipeline.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %}
+ +Refer to the [`Pipeline` Python docs](api/python/pyspark.ml.html#pyspark.ml.Pipeline) for more details on the API. + {% include_example python/ml/pipeline_example.py %}
diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 2ca90c7092fd3..15748720b7ae2 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -75,15 +75,23 @@ However, it is also a well-established method for choosing parameters which is m
+ +Refer to the [`CrossValidator` Scala docs](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
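In outline, the linked example wires a small pipeline into `CrossValidator` roughly as follows (a sketch rather than the example source; the grid values, column names and `training`/`test` DataFrames are illustrative assumptions):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// 3 x 2 = 6 parameter settings to search over.
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// cvModel wraps the pipeline refit with the best parameter combination.
val cvModel = cv.fit(training)
cvModel.transform(test).select("id", "text", "probability", "prediction").show()
```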
+ +Refer to the [`CrossValidator` Java docs](api/java/org/apache/spark/ml/tuning/CrossValidator.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %}
+Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.CrossValidator) for more details on the API. + {% include_example python/ml/cross_validator.py %}
@@ -107,14 +115,23 @@ Like `CrossValidator`, `TrainValidationSplit` finally fits the `Estimator` using
+ +Refer to the [`TrainValidationSplit` Scala docs](api/scala/index.html#org.apache.spark.ml.tuning.TrainValidationSplit) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
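A minimal sketch of the `TrainValidationSplit` usage the example demonstrates (`training` and `test` are assumed DataFrames with `label` and `features` columns; the grid values are illustrative):

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()

val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept)
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// Unlike CrossValidator, only a single (train, validation) split is evaluated.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)  // 80% for training, 20% for validation

val model = trainValidationSplit.fit(training)
model.transform(test).select("features", "label", "prediction").show()
```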
+ +Refer to the [`TrainValidationSplit` Java docs](api/java/org/apache/spark/ml/tuning/TrainValidationSplit.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
+ +Refer to the [`TrainValidationSplit` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.TrainValidationSplit) for more details on the API. + {% include_example python/ml/train_validation_split.py %}
From b0ae8712358fc8c07aa5efe4d0bd337e7e452078 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Wed, 16 Nov 2016 11:59:00 +0000 Subject: [PATCH 126/534] [SPARK-18420][BUILD] Fix the errors caused by lint check in Java Small fix, fix the errors caused by lint check in Java - Clear unused objects and `UnusedImports`. - Add comments around the method `finalize` of `NioBufferedFileInputStream`to turn off checkstyle. - Cut the line which is longer than 100 characters into two lines. Travis CI. ``` $ build/mvn -T 4 -q -DskipTests -Pyarn -Phadoop-2.3 -Pkinesis-asl -Phive -Phive-thriftserver install $ dev/lint-java ``` Before: ``` Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/network/util/TransportConf.java:[21,8] (imports) UnusedImports: Unused import - org.apache.commons.crypto.cipher.CryptoCipherFactory. [ERROR] src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java:[516,5] (modifier) RedundantModifier: Redundant 'public' modifier. [ERROR] src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java:[133] (coding) NoFinalizer: Avoid using finalizer method. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java:[71] (sizes) LineLength: Line is longer than 100 characters (found 113). [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java:[112] (sizes) LineLength: Line is longer than 100 characters (found 110). [ERROR] src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java:[31,17] (modifier) ModifierOrder: 'static' modifier out of order with the JLS suggestions. [ERROR]src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java:[64] (sizes) LineLength: Line is longer than 100 characters (found 103). [ERROR] src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java:[22,8] (imports) UnusedImports: Unused import - org.apache.spark.ml.linalg.Vectors. [ERROR] src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java:[51] (regexp) RegexpSingleline: No trailing whitespace allowed. ``` After: ``` $ build/mvn -T 4 -q -DskipTests -Pyarn -Phadoop-2.3 -Pkinesis-asl -Phive -Phive-thriftserver install $ dev/lint-java Using `mvn` from path: /home/travis/build/ConeyLiu/spark/build/apache-maven-3.3.9/bin/mvn Checkstyle checks passed. ``` Author: Xianyang Liu Closes #15865 from ConeyLiu/master. 
(cherry picked from commit 7569cf6cb85bda7d0e76d3e75e286d4796e77e08) Signed-off-by: Sean Owen --- .../spark/io/NioBufferedFileInputStream.java | 2 ++ dev/checkstyle.xml | 15 +++++++++++++++ .../spark/examples/ml/JavaInteractionExample.java | 3 +-- ...vaLogisticRegressionWithElasticNetExample.java | 4 ++-- .../sql/catalyst/expressions/UnsafeArrayData.java | 3 ++- .../sql/catalyst/expressions/UnsafeMapData.java | 3 ++- .../sql/catalyst/expressions/HiveHasherSuite.java | 1 - 7 files changed, 24 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263d..ea5f1a9abf69b 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,8 +130,10 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } + //checkstyle.on: NoFinalizer } diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 3de6aa91dcd51..92c5251c85037 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -52,6 +52,20 @@ + + + + + + + @@ -168,5 +182,6 @@ + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java index 4213c05703cc6..3684a87e22e7b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java @@ -19,7 +19,6 @@ import org.apache.spark.ml.feature.Interaction; import org.apache.spark.ml.feature.VectorAssembler; -import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; @@ -48,7 +47,7 @@ public static void main(String[] args) { RowFactory.create(5, 9, 2, 7, 10, 7, 3), RowFactory.create(6, 1, 1, 4, 2, 8, 4) ); - + StructType schema = new StructType(new StructField[]{ new StructField("id1", DataTypes.IntegerType, false, Metadata.empty()), new StructField("id2", DataTypes.IntegerType, false, Metadata.empty()), diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java index b8fb5972ea418..4cdec21d23023 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -60,8 +60,8 @@ public static void main(String[] args) { LogisticRegressionModel mlrModel = mlr.fit(training); // Print the coefficients and intercepts for logistic regression with multinomial family - System.out.println("Multinomial coefficients: " - + lrModel.coefficientMatrix() + "\nMultinomial intercepts: " + mlrModel.interceptVector()); + System.out.println("Multinomial coefficients: " + lrModel.coefficientMatrix() + + "\nMultinomial intercepts: " + mlrModel.interceptVector()); // $example off$ spark.stop(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 
86523c1474015..e8c33871f97bc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -109,7 +109,8 @@ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { // Read the number of elements from the first 8 bytes. final long numElements = Platform.getLong(baseObject, baseOffset); assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; - assert numElements <= Integer.MAX_VALUE : "numElements (" + numElements + ") should <= Integer.MAX_VALUE"; + assert numElements <= Integer.MAX_VALUE : + "numElements (" + numElements + ") should <= Integer.MAX_VALUE"; this.numElements = (int)numElements; this.baseObject = baseObject; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 35029f5a50e3e..f17441dfccb6d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -68,7 +68,8 @@ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { // Read the numBytes of key array from the first 8 bytes. final long keyArraySize = Platform.getLong(baseObject, baseOffset); assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0"; - assert keyArraySize <= Integer.MAX_VALUE : "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE"; + assert keyArraySize <= Integer.MAX_VALUE : + "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE"; final int valueArraySize = sizeInBytes - (int)keyArraySize - 8; assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0"; diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 67a5eb0c7fe8f..b67c6f3e6e85e 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -28,7 +28,6 @@ import java.util.Set; public class HiveHasherSuite { - private final static HiveHasher hasher = new HiveHasher(); @Test public void testKnownIntegerInputs() { From c0dbe08d604dea543eb17ccb802a8a20d6c21a69 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 16 Nov 2016 08:25:15 -0800 Subject: [PATCH 127/534] [SPARK-18415][SQL] Weird Plan Output when CTE used in RunnableCommand ### What changes were proposed in this pull request? Currently, when CTE is used in RunnableCommand, the Analyzer does not replace the logical node `With`. The child plan of RunnableCommand is not resolved. Thus, the output of the `With` plan node looks very confusing. 
For example, ``` sql( """ |CREATE VIEW cte_view AS |WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) |SELECT n FROM w """.stripMargin).explain() ``` The output is like ``` ExecutedCommand +- CreateViewCommand `cte_view`, WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) SELECT n FROM w, false, false, PersistedView +- 'With [(w,SubqueryAlias w +- Project [1 AS n#16] +- OneRowRelation$ ), (cte1,'SubqueryAlias cte1 +- 'Project [unresolvedalias(2, None)] +- OneRowRelation$ ), (cte2,'SubqueryAlias cte2 +- 'Project [unresolvedalias(3, None)] +- OneRowRelation$ )] +- 'Project ['n] +- 'UnresolvedRelation `w` ``` After the fix, the output is as shown below. ``` ExecutedCommand +- CreateViewCommand `cte_view`, WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) SELECT n FROM w, false, false, PersistedView +- CTE [w, cte1, cte2] : :- SubqueryAlias w : : +- Project [1 AS n#16] : : +- OneRowRelation$ : :- 'SubqueryAlias cte1 : : +- 'Project [unresolvedalias(2, None)] : : +- OneRowRelation$ : +- 'SubqueryAlias cte2 : +- 'Project [unresolvedalias(3, None)] : +- OneRowRelation$ +- 'Project ['n] +- 'UnresolvedRelation `w` ``` BTW, this PR also fixes the output of the view type. ### How was this patch tested? Manual Author: gatorsmile Closes #15854 from gatorsmile/cteName. (cherry picked from commit 608ecc512b759514c75a1b475582f237ed569f10) Signed-off-by: Herman van Hovell --- .../catalyst/plans/logical/basicLogicalOperators.scala | 8 ++++++++ .../org/apache/spark/sql/execution/command/views.scala | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 574caf039d3d2..dd6c8fd1dcf3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * When planning take() or collect() operations, this special node that is inserted at the top of @@ -405,6 +406,13 @@ case class InsertIntoTable( */ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode { override def output: Seq[Attribute] = child.output + + override def simpleString: String = { + val cteAliases = Utils.truncatedString(cteRelations.map(_._1), "[", ", ", "]") + s"CTE $cteAliases" + } + + override def innerChildren: Seq[QueryPlan[_]] = cteRelations.map(_._2) } case class WithWindowDefinition( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 30472ec45ce44..154141bf83c7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -33,7 +33,9 @@ import org.apache.spark.sql.types.MetadataBuilder * ViewType is used to specify the expected view type when we want to create or replace a view in * [[CreateViewCommand]]. 
*/ -sealed trait ViewType +sealed trait ViewType { + override def toString: String = getClass.getSimpleName.stripSuffix("$") +} /** * LocalTempView means session-scoped local temporary views. Its lifetime is the lifetime of the From b86e962c90c4322cd98b5bf3b19e251da2d32442 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Nov 2016 10:00:59 -0800 Subject: [PATCH 128/534] [SPARK-18459][SPARK-18460][STRUCTUREDSTREAMING] Rename triggerId to batchId and add triggerDetails to json in StreamingQueryStatus ## What changes were proposed in this pull request? SPARK-18459: triggerId seems like a number that should be increasing with each trigger, whether or not there is data in it. However, actually, triggerId increases only where there is a batch of data in a trigger. So its better to rename it to batchId. SPARK-18460: triggerDetails was missing from json representation. Fixed it. ## How was this patch tested? Updated existing unit tests. Author: Tathagata Das Closes #15895 from tdas/SPARK-18459. (cherry picked from commit 0048ce7ce64b02cbb6a1c4a2963a0b1b9541047e) Signed-off-by: Shixiong Zhu --- python/pyspark/sql/streaming.py | 6 ++--- .../execution/streaming/StreamMetrics.scala | 8 +++---- .../sql/streaming/StreamingQueryStatus.scala | 4 ++-- .../streaming/StreamMetricsSuite.scala | 8 +++---- .../StreamingQueryListenerSuite.scala | 4 ++-- .../streaming/StreamingQueryStatusSuite.scala | 22 +++++++++++++++++-- 6 files changed, 35 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index f326f16232690..0e4589be976ea 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -212,12 +212,12 @@ def __str__(self): Processing rate 23.5 rows/sec Latency: 345.0 ms Trigger details: + batchId: 5 isDataPresentInTrigger: true isTriggerActive: true latency.getBatch.total: 20 latency.getOffset.total: 10 numRows.input.total: 100 - triggerId: 5 Source statuses [1 source]: Source 1 - MySource1 Available offset: 0 @@ -341,8 +341,8 @@ def triggerDetails(self): If no trigger is currently active, then it will have details of the last completed trigger. 
>>> sqs.triggerDetails - {u'triggerId': u'5', u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', - u'isTriggerActive': u'true', u'latency.getOffset.total': u'10', + {u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', + u'isTriggerActive': u'true', u'batchId': u'5', u'latency.getOffset.total': u'10', u'isDataPresentInTrigger': u'true'} """ return self._jsqs.triggerDetails() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala index 5645554a58f6e..942e6ed8944be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -78,13 +78,13 @@ class StreamMetrics(sources: Set[Source], triggerClock: Clock, codahaleSourceNam // =========== Setter methods =========== - def reportTriggerStarted(triggerId: Long): Unit = synchronized { + def reportTriggerStarted(batchId: Long): Unit = synchronized { numInputRows.clear() triggerDetails.clear() sourceTriggerDetails.values.foreach(_.clear()) - reportTriggerDetail(TRIGGER_ID, triggerId) - sources.foreach(s => reportSourceTriggerDetail(s, TRIGGER_ID, triggerId)) + reportTriggerDetail(BATCH_ID, batchId) + sources.foreach(s => reportSourceTriggerDetail(s, BATCH_ID, batchId)) reportTriggerDetail(IS_TRIGGER_ACTIVE, true) currentTriggerStartTimestamp = triggerClock.getTimeMillis() reportTriggerDetail(START_TIMESTAMP, currentTriggerStartTimestamp) @@ -217,7 +217,7 @@ object StreamMetrics extends Logging { } - val TRIGGER_ID = "triggerId" + val BATCH_ID = "batchId" val IS_TRIGGER_ACTIVE = "isTriggerActive" val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" val STATUS_MESSAGE = "statusMessage" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 99c7729d02351..ba732ff7fc2ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -102,7 +102,7 @@ class StreamingQueryStatus private( ("inputRate" -> JDouble(inputRate)) ~ ("processingRate" -> JDouble(processingRate)) ~ ("latency" -> latency.map(JDouble).getOrElse(JNothing)) ~ - ("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) + ("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) ~ ("sourceStatuses" -> JArray(sourceStatuses.map(_.jsonValue).toList)) ~ ("sinkStatus" -> sinkStatus.jsonValue) } @@ -151,7 +151,7 @@ private[sql] object StreamingQueryStatus { desc = "MySink", offsetDesc = OffsetSeq(Some(LongOffset(1)) :: None :: Nil).toString), triggerDetails = Map( - TRIGGER_ID -> "5", + BATCH_ID -> "5", IS_TRIGGER_ACTIVE -> "true", IS_DATA_PRESENT_IN_TRIGGER -> "true", GET_OFFSET_LATENCY -> "10", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala index 938423db64745..38c4ece439770 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala @@ -50,10 +50,10 @@ class StreamMetricsSuite extends SparkFunSuite { assert(sm.currentSourceProcessingRate(source) === 
0.0) assert(sm.currentLatency() === None) assert(sm.currentTriggerDetails() === - Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "true", + Map(BATCH_ID -> "1", IS_TRIGGER_ACTIVE -> "true", START_TIMESTAMP -> "0", "key" -> "value")) assert(sm.currentSourceTriggerDetails(source) === - Map(TRIGGER_ID -> "1", "key2" -> "value2")) + Map(BATCH_ID -> "1", "key2" -> "value2")) // Finishing the trigger should calculate the rates, except input rate which needs // to have another trigger interval @@ -66,11 +66,11 @@ class StreamMetricsSuite extends SparkFunSuite { assert(sm.currentSourceProcessingRate(source) === 100.0) assert(sm.currentLatency() === None) assert(sm.currentTriggerDetails() === - Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "false", + Map(BATCH_ID -> "1", IS_TRIGGER_ACTIVE -> "false", START_TIMESTAMP -> "0", FINISH_TIMESTAMP -> "1000", NUM_INPUT_ROWS -> "100", "key" -> "value")) assert(sm.currentSourceTriggerDetails(source) === - Map(TRIGGER_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) + Map(BATCH_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) // After another trigger starts, the rates and latencies should not change until // new rows are reported diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index cebb32a0a56cc..98f3bec7080af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -84,7 +84,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { AssertOnLastQueryStatus { status: StreamingQueryStatus => // Check the correctness of the trigger info of the last completed batch reported by // onQueryProgress - assert(status.triggerDetails.containsKey("triggerId")) + assert(status.triggerDetails.containsKey("batchId")) assert(status.triggerDetails.get("isTriggerActive") === "false") assert(status.triggerDetails.get("isDataPresentInTrigger") === "true") @@ -104,7 +104,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { assert(status.triggerDetails.get("numRows.state.aggregation1.updated") === "1") assert(status.sourceStatuses.length === 1) - assert(status.sourceStatuses(0).triggerDetails.containsKey("triggerId")) + assert(status.sourceStatuses(0).triggerDetails.containsKey("batchId")) assert(status.sourceStatuses(0).triggerDetails.get("latency.getOffset.source") === "100") assert(status.sourceStatuses(0).triggerDetails.get("latency.getBatch.source") === "200") assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala index 6af19fb0c2327..50a7d92ede9a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala @@ -48,12 +48,12 @@ class StreamingQueryStatusSuite extends SparkFunSuite { | Processing rate 23.5 rows/sec | Latency: 345.0 ms | Trigger details: + | batchId: 5 | isDataPresentInTrigger: true | isTriggerActive: true | latency.getBatch.total: 20 | latency.getOffset.total: 10 | numRows.input.total: 100 - | triggerId: 5 | Source statuses [1 source]: | Source 1 - MySource1 | 
Available offset: 0 @@ -72,7 +72,11 @@ class StreamingQueryStatusSuite extends SparkFunSuite { test("json") { assert(StreamingQueryStatus.testStatus.json === """ - |{"sourceStatuses":[{"description":"MySource1","offsetDesc":"0","inputRate":15.5, + |{"name":"query","id":1,"timestamp":123,"inputRate":15.5,"processingRate":23.5, + |"latency":345.0,"triggerDetails":{"latency.getBatch.total":"20", + |"numRows.input.total":"100","isTriggerActive":"true","batchId":"5", + |"latency.getOffset.total":"10","isDataPresentInTrigger":"true"}, + |"sourceStatuses":[{"description":"MySource1","offsetDesc":"0","inputRate":15.5, |"processingRate":23.5,"triggerDetails":{"numRows.input.source":"100", |"latency.getOffset.source":"10","latency.getBatch.source":"20"}}], |"sinkStatus":{"description":"MySink","offsetDesc":"[1, -]"}} @@ -84,6 +88,20 @@ class StreamingQueryStatusSuite extends SparkFunSuite { StreamingQueryStatus.testStatus.prettyJson === """ |{ + | "name" : "query", + | "id" : 1, + | "timestamp" : 123, + | "inputRate" : 15.5, + | "processingRate" : 23.5, + | "latency" : 345.0, + | "triggerDetails" : { + | "latency.getBatch.total" : "20", + | "numRows.input.total" : "100", + | "isTriggerActive" : "true", + | "batchId" : "5", + | "latency.getOffset.total" : "10", + | "isDataPresentInTrigger" : "true" + | }, | "sourceStatuses" : [ { | "description" : "MySource1", | "offsetDesc" : "0", From 3d4756d56b852dcf4e1bebe621d4a30570873c3c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Nov 2016 11:03:10 -0800 Subject: [PATCH 129/534] [SPARK-18461][DOCS][STRUCTUREDSTREAMING] Added more information about monitoring streaming queries ## What changes were proposed in this pull request? screen shot 2016-11-15 at 6 27 32 pm screen shot 2016-11-15 at 6 27 45 pm Author: Tathagata Das Closes #15897 from tdas/SPARK-18461. (cherry picked from commit bb6cdfd9a6a6b6c91aada7c3174436146045ed1e) Signed-off-by: Michael Armbrust --- .../structured-streaming-programming-guide.md | 182 +++++++++++++++++- 1 file changed, 179 insertions(+), 3 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index d2545584ae3b0..77b66b3b3a497 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1087,9 +1087,185 @@ spark.streams().awaitAnyTermination() # block until any one of them terminates
-Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener` -([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs), -which will give you regular callback-based updates when queries are started and terminated. + +## Monitoring Streaming Queries +There are two ways you can monitor queries. You can directly get the current status +of an active query using `streamingQuery.status`, which will return a `StreamingQueryStatus` object +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryStatus)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryStatus.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryStatus) docs) +that has all the details like current ingestion rates, processing rates, average latency, +details of the currently active trigger, etc. + +
+
+ +{% highlight scala %} +val query: StreamingQuery = ... + +println(query.status) + +/* Will print the current status of the query + +Status of query 'queryName' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + batchId: 5 + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: 0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [1, -] +*/ +{% endhighlight %} + +
+
+ +{% highlight java %} +StreamingQuery query = ... + +System.out.println(query.status); + +/* Will print the current status of the query + +Status of query 'queryName' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + batchId: 5 + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: 0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [1, -] +*/ +{% endhighlight %} + +
+
+ +{% highlight python %} +query = ... // a StreamingQuery + +print(query.status) + +''' +Will print the current status of the query + +Status of query 'queryName' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + batchId: 5 + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: 0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [1, -] +''' +{% endhighlight %} + +
+
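
The individual metrics shown in the examples above can also be read programmatically from
`triggerDetails`, a string-to-string map keyed by the names printed in the status (for example
`batchId` and `numRows.input.total`). A minimal sketch in Python, assuming `query` is the
active `StreamingQuery` from the examples above:

{% highlight python %}
# Sketch only: pull selected metrics out of the last trigger's details map.
details = query.status.triggerDetails
print(details['batchId'], details['numRows.input.total'])
{% endhighlight %}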
+
+
+You can also asynchronously monitor all queries associated with a
+`SparkSession` by attaching a `StreamingQueryListener`
+([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs).
+Once you attach your custom `StreamingQueryListener` object with
+`sparkSession.streams.addListener()`, you will get callbacks when a query is started and
+stopped and when there is progress made in an active query. Here is an example.
+
+
+
+{% highlight scala %}
+val spark: SparkSession = ...
+
+spark.streams.addListener(new StreamingQueryListener() {
+
+  override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = {
+    println("Query started: " + queryStarted.queryStatus.name)
+  }
+  override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = {
+    println("Query terminated: " + queryTerminated.queryStatus.name)
+  }
+  override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = {
+    println("Query made progress: " + queryProgress.queryStatus)
+  }
+})
+{% endhighlight %}
+
+
+
+{% highlight java %}
+SparkSession spark = ...
+
+spark.streams().addListener(new StreamingQueryListener() {
+
+  @Override
+  public void onQueryStarted(QueryStartedEvent queryStarted) {
+    System.out.println("Query started: " + queryStarted.queryStatus().name());
+  }
+  @Override
+  public void onQueryTerminated(QueryTerminatedEvent queryTerminated) {
+    System.out.println("Query terminated: " + queryTerminated.queryStatus().name());
+  }
+  @Override
+  public void onQueryProgress(QueryProgressEvent queryProgress) {
+    System.out.println("Query made progress: " + queryProgress.queryStatus());
+  }
+});
+{% endhighlight %}
+
+
+{% highlight bash %} +Not available in Python. +{% endhighlight %} + +
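
Until a listener API is exposed in Python, a rough substitute is to poll `query.status` from the
driver. A sketch only, assuming `query` is an active `StreamingQuery` returned by
`writeStream.start()`:

{% highlight python %}
import time

# Sketch only: periodically print the current status while the query is running.
while query.isActive:
    print(query.status)
    time.sleep(10)
{% endhighlight %}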
+
## Recovering from Failures with Checkpointing In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. As of Spark 2.0, this checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). From 523abfe19caa11747133877b0c8319c68ac66e56 Mon Sep 17 00:00:00 2001 From: Artur Sukhenko Date: Wed, 16 Nov 2016 15:08:01 -0800 Subject: [PATCH 130/534] [YARN][DOC] Increasing NodeManager's heap size with External Shuffle Service ## What changes were proposed in this pull request? Suggest users to increase `NodeManager's` heap size if `External Shuffle Service` is enabled as `NM` can spend a lot of time doing GC resulting in shuffle operations being a bottleneck due to `Shuffle Read blocked time` bumped up. Also because of GC `NodeManager` can use an enormous amount of CPU and cluster performance will suffer. I have seen NodeManager using 5-13G RAM and up to 2700% CPU with `spark_shuffle` service on. ## How was this patch tested? #### Added step 5: ![shuffle_service](https://cloud.githubusercontent.com/assets/15244468/20355499/2fec0fde-ac2a-11e6-8f8b-1c80daf71be1.png) Author: Artur Sukhenko Closes #15906 from Devian-ua/nmHeapSize. (cherry picked from commit 55589987be89ff78dadf44498352fbbd811a206e) Signed-off-by: Reynold Xin --- docs/running-on-yarn.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cd18808681ece..fe0221ce7c5b6 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -559,6 +559,8 @@ pre-packaged distribution. 1. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to `org.apache.spark.network.yarn.YarnShuffleService`. +1. Increase `NodeManager's` heap size by setting `YARN_HEAPSIZE` (1000 by default) in `etc/hadoop/yarn-env.sh` +to avoid garbage collection issues during shuffle. 1. Restart all `NodeManager`s in your cluster. The following extra configuration options are available when the shuffle service is running on YARN: From 9515793820c7954d82116238a67e632ea3e783b5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 17 Nov 2016 11:21:08 +0800 Subject: [PATCH 131/534] [SPARK-18442][SQL] Fix nullability of WrapOption. ## What changes were proposed in this pull request? The nullability of `WrapOption` should be `false`. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15887 from ueshin/issues/SPARK-18442. 
(cherry picked from commit 170eeb345f951de89a39fe565697b3e913011768) Signed-off-by: Wenchen Fan --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e2ac3c36d93..0e3d99127ed56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -341,7 +341,7 @@ case class WrapOption(child: Expression, optType: DataType) override def dataType: DataType = ObjectType(classOf[Option[_]]) - override def nullable: Boolean = true + override def nullable: Boolean = false override def inputTypes: Seq[AbstractDataType] = optType :: Nil From 6a3cbbc037fe631e1b89c46000373dc2ba86a5eb Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 16 Nov 2016 14:22:15 -0800 Subject: [PATCH 132/534] [SPARK-1267][SPARK-18129] Allow PySpark to be pip installed ## What changes were proposed in this pull request? This PR aims to provide a pip installable PySpark package. This does a bunch of work to copy the jars over and package them with the Python code (to prevent challenges from trying to use different versions of the Python code with different versions of the JAR). It does not currently publish to PyPI but that is the natural follow up (SPARK-18129). Done: - pip installable on conda [manual tested] - setup.py installed on a non-pip managed system (RHEL) with YARN [manual tested] - Automated testing of this (virtualenv) - packaging and signing with release-build* Possible follow up work: - release-build update to publish to PyPI (SPARK-18128) - figure out who owns the pyspark package name on prod PyPI (is it someone with in the project or should we ask PyPI or should we choose a different name to publish with like ApachePySpark?) - Windows support and or testing ( SPARK-18136 ) - investigate details of wheel caching and see if we can avoid cleaning the wheel cache during our test - consider how we want to number our dev/snapshot versions Explicitly out of scope: - Using pip installed PySpark to start a standalone cluster - Using pip installed PySpark for non-Python Spark programs *I've done some work to test release-build locally but as a non-committer I've just done local testing. ## How was this patch tested? Automated testing with virtualenv, manual testing with conda, a system wide install, and YARN integration. release-build changes tested locally as a non-committer (no testing of upload artifacts to Apache staging websites) Author: Holden Karau Author: Juliet Hougland Author: Juliet Hougland Closes #15659 from holdenk/SPARK-1267-pip-install-pyspark. 
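
For context, the end state this patch works toward is that a pip-installed PySpark can locate its
own Spark installation and start a session without SPARK_HOME being set. A minimal sketch of the
kind of check the dev/pip-sanity-check.py script added below performs (names are illustrative):

    # Sketch only: verify a pip-installed PySpark can start a session and run a job.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("PipSanityCheck").getOrCreate()
    rdd = spark.sparkContext.parallelize(range(100), 10)
    assert rdd.reduce(lambda x, y: x + y) == 4950
    spark.stop()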
--- .gitignore | 2 + bin/beeline | 2 +- bin/find-spark-home | 41 ++++ bin/load-spark-env.sh | 2 +- bin/pyspark | 6 +- bin/run-example | 2 +- bin/spark-class | 6 +- bin/spark-shell | 4 +- bin/spark-sql | 2 +- bin/spark-submit | 2 +- bin/sparkR | 2 +- dev/create-release/release-build.sh | 26 ++- dev/create-release/release-tag.sh | 11 +- dev/lint-python | 4 +- dev/make-distribution.sh | 16 +- dev/pip-sanity-check.py | 36 +++ dev/run-pip-tests | 115 ++++++++++ dev/run-tests-jenkins.py | 1 + dev/run-tests.py | 7 + dev/sparktestsupport/__init__.py | 1 + docs/building-spark.md | 8 + docs/index.md | 4 +- .../spark/launcher/CommandBuilderUtils.java | 2 +- python/MANIFEST.in | 22 ++ python/README.md | 32 +++ python/pyspark/__init__.py | 1 + python/pyspark/find_spark_home.py | 74 +++++++ python/pyspark/java_gateway.py | 3 +- python/pyspark/version.py | 19 ++ python/setup.cfg | 22 ++ python/setup.py | 209 ++++++++++++++++++ 31 files changed, 660 insertions(+), 24 deletions(-) create mode 100755 bin/find-spark-home create mode 100644 dev/pip-sanity-check.py create mode 100755 dev/run-pip-tests create mode 100644 python/MANIFEST.in create mode 100644 python/README.md create mode 100755 python/pyspark/find_spark_home.py create mode 100644 python/pyspark/version.py create mode 100644 python/setup.cfg create mode 100644 python/setup.py diff --git a/.gitignore b/.gitignore index 39d17e1793f77..5634a434db0c0 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,8 @@ project/plugins/project/build.properties project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip +python/deps +python/pyspark/python reports/ scalastyle-on-compile.generated.xml scalastyle-output.xml diff --git a/bin/beeline b/bin/beeline index 1627626941a73..058534699e44b 100755 --- a/bin/beeline +++ b/bin/beeline @@ -25,7 +25,7 @@ set -o posix # Figure out if SPARK_HOME is set if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi CLASS="org.apache.hive.beeline.BeeLine" diff --git a/bin/find-spark-home b/bin/find-spark-home new file mode 100755 index 0000000000000..fa78407d4175a --- /dev/null +++ b/bin/find-spark-home @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Attempts to find a proper value for SPARK_HOME. Should be included using "source" directive. + +FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" + +# Short cirtuit if the user already has this set. +if [ ! -z "${SPARK_HOME}" ]; then + exit 0 +elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then + # If we are not in the same directory as find_spark_home.py we are not pip installed so we don't + # need to search the different Python directories for a Spark installation. 
+ # Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or + # spark-submit in another directory we want to use that version of PySpark rather than the + # pip installed version of PySpark. + export SPARK_HOME="$(cd "$(dirname "$0")"/..; pwd)" +else + # We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + # Default to standard python interpreter unless told otherwise + if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"python"}" + fi + export SPARK_HOME=$($PYSPARK_DRIVER_PYTHON "$FIND_SPARK_HOME_PYTHON_SCRIPT") +fi diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index eaea964ed5b3d..8a2f709960a25 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -23,7 +23,7 @@ # Figure out where Spark is installed if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi if [ -z "$SPARK_ENV_LOADED" ]; then diff --git a/bin/pyspark b/bin/pyspark index d6b3ab0a44321..98387c2ec5b8a 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh @@ -46,7 +46,7 @@ WORKS_WITH_IPYTHON=$(python -c 'import sys; print(sys.version_info >= (2, 7, 0)) # Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! WORKS_WITH_IPYTHON ]]; then + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! $WORKS_WITH_IPYTHON ]]; then echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 exit 1 else @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m $1 + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/bin/run-example b/bin/run-example index dd0e3c4120260..4ba5399311d33 100755 --- a/bin/run-example +++ b/bin/run-example @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/run-example [options] example-class [example args]" diff --git a/bin/spark-class b/bin/spark-class index 377c8d1add3f6..77ea40cc37946 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi . "${SPARK_HOME}"/bin/load-spark-env.sh @@ -27,7 +27,7 @@ fi if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" else - if [ `command -v java` ]; then + if [ "$(command -v java)" ]; then RUNNER="java" else echo "JAVA_HOME is not set" >&2 @@ -36,7 +36,7 @@ else fi # Find Spark jars. 
-if [ -f "${SPARK_HOME}/RELEASE" ]; then +if [ -d "${SPARK_HOME}/jars" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" diff --git a/bin/spark-shell b/bin/spark-shell index 6583b5bd880ee..421f36cac3d47 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -21,7 +21,7 @@ # Shell script for starting the Spark Shell REPL cygwin=false -case "`uname`" in +case "$(uname)" in CYGWIN*) cygwin=true;; esac @@ -29,7 +29,7 @@ esac set -o posix if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" diff --git a/bin/spark-sql b/bin/spark-sql index 970d12cbf51dd..b08b944ebd319 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" diff --git a/bin/spark-submit b/bin/spark-submit index 023f9c162f4b8..4e9d3614e6370 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi # disable randomized hash for string in Python 3.3+ diff --git a/bin/sparkR b/bin/sparkR index 2c07a82e2173b..29ab10df8ab6d 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 81f0d63054e29..1dbfa3b6e361b 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -162,14 +162,35 @@ if [[ "$1" == "package" ]]; then export ZINC_PORT=$ZINC_PORT echo "Creating distribution: $NAME ($FLAGS)" + # Write out the NAME and VERSION to PySpark version info we rewrite the - into a . and SNAPSHOT + # to dev0 to be closer to PEP440. We use the NAME as a "local version". + PYSPARK_VERSION=`echo "$SPARK_VERSION+$NAME" | sed -r "s/-/./" | sed -r "s/SNAPSHOT/dev0/"` + echo "__version__='$PYSPARK_VERSION'" > python/pyspark/version.py + # Get maven home set by MVN MVN_HOME=`$MVN -version 2>&1 | grep 'Maven home' | awk '{print $NF}'` - ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ + echo "Creating distribution" + ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --pip $FLAGS \ -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log cd .. - cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . + echo "Copying and signing python distribution" + PYTHON_DIST_NAME=pyspark-$PYSPARK_VERSION.tar.gz + cp spark-$SPARK_VERSION-bin-$NAME/python/dist/$PYTHON_DIST_NAME . + + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output $PYTHON_DIST_NAME.asc \ + --detach-sig $PYTHON_DIST_NAME + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 $PYTHON_DIST_NAME > \ + $PYTHON_DIST_NAME.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 $PYTHON_DIST_NAME > \ + $PYTHON_DIST_NAME.sha + + echo "Copying and signing regular binary distribution" + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . 
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz @@ -208,6 +229,7 @@ if [[ "$1" == "package" ]]; then # Re-upload a second time and leave the files in the timestamped upload directory: LFTP mkdir -p $dest_dir LFTP mput -O $dest_dir 'spark-*' + LFTP mput -O $dest_dir 'pyspark-*' exit 0 fi diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index b7e5100ca7408..370a62ce15bc4 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -65,6 +65,7 @@ sed -i".tmp1" 's/Version.*$/Version: '"$RELEASE_VERSION"'/g' R/pkg/DESCRIPTION # Set the release version in docs sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml +sed -i".tmp3" 's/__version__ = .*$/__version__ = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" @@ -74,12 +75,16 @@ git tag $RELEASE_TAG $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs # Remove -SNAPSHOT before setting the R version as R expects version strings to only have numbers R_NEXT_VERSION=`echo $NEXT_VERSION | sed 's/-SNAPSHOT//g'` -sed -i".tmp2" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION +sed -i".tmp4" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION +# Write out the R_NEXT_VERSION to PySpark version info we use dev0 instead of SNAPSHOT to be closer +# to PEP440. +sed -i".tmp5" 's/__version__ = .*$/__version__ = "'"$R_NEXT_VERSION.dev0"'"/' python/pyspark/version.py + # Update docs with next version -sed -i".tmp3" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml +sed -i".tmp6" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml # Use R version for short version -sed -i".tmp4" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml +sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml git commit -a -m "Preparing development version $NEXT_VERSION" diff --git a/dev/lint-python b/dev/lint-python index 63487043a50b6..3f878c2dad6b1 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -20,7 +20,9 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport" -PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" +# TODO: fix pep8 errors with the rest of the Python scripts under dev +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/*.py ./dev/run-tests-jenkins.py" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/pip-sanity-check.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 9be4fdfa51c93..49b46fbc3fb27 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -33,6 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`/.."; pwd)" DISTDIR="$SPARK_HOME/dist" MAKE_TGZ=false +MAKE_PIP=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -40,7 +41,7 @@ function exit_with_usage { echo "make-distribution.sh - tool for 
making binary distributions of Spark" echo "" echo "usage:" - cl_options="[--name] [--tgz] [--mvn ]" + cl_options="[--name] [--tgz] [--pip] [--mvn ]" echo "make-distribution.sh $cl_options " echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" @@ -67,6 +68,9 @@ while (( "$#" )); do --tgz) MAKE_TGZ=true ;; + --pip) + MAKE_PIP=true + ;; --mvn) MVN="$2" shift @@ -201,6 +205,16 @@ fi # Copy data files cp -r "$SPARK_HOME/data" "$DISTDIR" +# Make pip package +if [ "$MAKE_PIP" == "true" ]; then + echo "Building python distribution package" + cd $SPARK_HOME/python + python setup.py sdist + cd .. +else + echo "Skipping creating pip installable PySpark" +fi + # Copy other things mkdir "$DISTDIR"/conf cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf diff --git a/dev/pip-sanity-check.py b/dev/pip-sanity-check.py new file mode 100644 index 0000000000000..430c2ab52766a --- /dev/null +++ b/dev/pip-sanity-check.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark.sql import SparkSession +import sys + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("PipSanityCheck")\ + .getOrCreate() + sc = spark.sparkContext + rdd = sc.parallelize(range(100), 10) + value = rdd.reduce(lambda x, y: x + y) + if (value != 4950): + print("Value {0} did not match expected value.".format(value), file=sys.stderr) + sys.exit(-1) + print("Successfully ran pip sanity check") + + spark.stop() diff --git a/dev/run-pip-tests b/dev/run-pip-tests new file mode 100755 index 0000000000000..e1da18e60bb3d --- /dev/null +++ b/dev/run-pip-tests @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stop on error +set -e +# Set nullglob for when we are checking existence based on globs +shopt -s nullglob + +FWDIR="$(cd "$(dirname "$0")"/..; pwd)" +cd "$FWDIR" + +echo "Constucting virtual env for testing" +VIRTUALENV_BASE=$(mktemp -d) + +# Clean up the virtual env enviroment used if we created one. 
+function delete_virtualenv() { + echo "Cleaning up temporary directory - $VIRTUALENV_BASE" + rm -rf "$VIRTUALENV_BASE" +} +trap delete_virtualenv EXIT + +# Some systems don't have pip or virtualenv - in those cases our tests won't work. +if ! hash virtualenv 2>/dev/null; then + echo "Missing virtualenv skipping pip installability tests." + exit 0 +fi +if ! hash pip 2>/dev/null; then + echo "Missing pip, skipping pip installability tests." + exit 0 +fi + +# Figure out which Python execs we should test pip installation with +PYTHON_EXECS=() +if hash python2 2>/dev/null; then + # We do this since we are testing with virtualenv and the default virtual env python + # is in /usr/bin/python + PYTHON_EXECS+=('python2') +elif hash python 2>/dev/null; then + # If python2 isn't installed fallback to python if available + PYTHON_EXECS+=('python') +fi +if hash python3 2>/dev/null; then + PYTHON_EXECS+=('python3') +fi + +# Determine which version of PySpark we are building for archive name +PYSPARK_VERSION=$(python -c "exec(open('python/pyspark/version.py').read());print __version__") +PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" +# The pip install options we use for all the pip commands +PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " +# Test both regular user and edit/dev install modes. +PIP_COMMANDS=("pip install $PIP_OPTIONS $PYSPARK_DIST" + "pip install $PIP_OPTIONS -e python/") + +for python in "${PYTHON_EXECS[@]}"; do + for install_command in "${PIP_COMMANDS[@]}"; do + echo "Testing pip installation with python $python" + # Create a temp directory for us to work in and save its name to a file for cleanup + echo "Using $VIRTUALENV_BASE for virtualenv" + VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python + rm -rf "$VIRTUALENV_PATH" + mkdir -p "$VIRTUALENV_PATH" + virtualenv --python=$python "$VIRTUALENV_PATH" + source "$VIRTUALENV_PATH"/bin/activate + # Upgrade pip + pip install --upgrade pip + + echo "Creating pip installable source dist" + cd "$FWDIR"/python + $python setup.py sdist + + + echo "Installing dist into virtual env" + cd dist + # Verify that the dist directory only contains one thing to install + sdists=(*.tar.gz) + if [ ${#sdists[@]} -ne 1 ]; then + echo "Unexpected number of targets found in dist directory - please cleanup existing sdists first." 
+ exit -1 + fi + # Do the actual installation + cd "$FWDIR" + $install_command + + cd / + + echo "Run basic sanity check on pip installed version with spark-submit" + spark-submit "$FWDIR"/dev/pip-sanity-check.py + echo "Run basic sanity check with import based" + python "$FWDIR"/dev/pip-sanity-check.py + echo "Run the tests for context.py" + python "$FWDIR"/python/pyspark/context.py + + cd "$FWDIR" + + done +done + +exit 0 diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index a48d918f9dc1f..1d1e72faccf2a 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -128,6 +128,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', + ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: 'PySpark pip packaging tests', ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests', ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of \`%s\`' % ( tests_timeout) diff --git a/dev/run-tests.py b/dev/run-tests.py index 5d661f5f1a1c5..ab285ac96af7e 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -432,6 +432,12 @@ def run_python_tests(test_modules, parallelism): run_cmd(command) +def run_python_packaging_tests(): + set_title_and_block("Running PySpark packaging tests", "BLOCK_PYSPARK_PIP_TESTS") + command = [os.path.join(SPARK_HOME, "dev", "run-pip-tests")] + run_cmd(command) + + def run_build_tests(): set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) @@ -583,6 +589,7 @@ def main(): modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: run_python_tests(modules_with_python_tests, opts.parallelism) + run_python_packaging_tests() if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 89015f8c4fb9c..38f25da41f775 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -33,5 +33,6 @@ "BLOCK_SPARKR_UNIT_TESTS": 20, "BLOCK_JAVA_STYLE": 21, "BLOCK_BUILD_TESTS": 22, + "BLOCK_PYSPARK_PIP_TESTS": 23, "BLOCK_TIMEOUT": 124 } diff --git a/docs/building-spark.md b/docs/building-spark.md index 2b404bd3e116c..88da0cc9c3bbf 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -265,6 +265,14 @@ or Java 8 tests are automatically enabled when a Java 8 JDK is detected. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. +## PySpark pip installable + +If you are building Spark for use in a Python environment and you wish to pip install it, you will first need to build the Spark JARs as described above. Then you can construct an sdist package suitable for setup.py and pip installable package. + + cd python; python setup.py sdist + +**Note:** Due to packaging requirements you can not directly pip install from the Python directory, rather you must first build the sdist package as described above. + ## PySpark Tests with Maven If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. 
diff --git a/docs/index.md b/docs/index.md index fe51439ae08d7..39de11de854a7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,7 +14,9 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. Users can also download a "Hadoop free" binary and run Spark with any Hadoop version -[by augmenting Spark's classpath](hadoop-provided.html). +[by augmenting Spark's classpath](hadoop-provided.html). +Scala and Java users can include Spark in their projects using its maven cooridnates and in the future Python users can also install Spark from PyPI. + If you'd like to build Spark from source, visit [Building Spark](building-spark.html). diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 62a22008d0d5d..250b2a882feb5 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -357,7 +357,7 @@ static int javaMajorVersion(String javaVersion) { static String findJarsDir(String sparkHome, String scalaVersion, boolean failIfNotFound) { // TODO: change to the correct directory once the assembly build is changed. File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { + if (new File(sparkHome, "jars").isDirectory()) { libdir = new File(sparkHome, "jars"); checkState(!failIfNotFound || libdir.isDirectory(), "Library directory '%s' does not exist.", diff --git a/python/MANIFEST.in b/python/MANIFEST.in new file mode 100644 index 0000000000000..bbcce1baa439d --- /dev/null +++ b/python/MANIFEST.in @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +global-exclude *.py[cod] __pycache__ .DS_Store +recursive-include deps/jars *.jar +graft deps/bin +recursive-include deps/examples *.py +recursive-include lib *.zip +include README.md diff --git a/python/README.md b/python/README.md new file mode 100644 index 0000000000000..0a5c8010b8486 --- /dev/null +++ b/python/README.md @@ -0,0 +1,32 @@ +# Apache Spark + +Spark is a fast and general cluster computing system for Big Data. It provides +high-level APIs in Scala, Java, Python, and R, and an optimized engine that +supports general computation graphs for data analysis. It also supports a +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, +and Spark Streaming for stream processing. 
+ + + +## Online Documentation + +You can find the latest Spark documentation, including a programming +guide, on the [project web page](http://spark.apache.org/documentation.html) + + +## Python Packaging + +This README file only contains basic information related to pip installed PySpark. +This packaging is currently experimental and may change in future versions (although we will do our best to keep compatibility). +Using PySpark requires the Spark JARs, and if you are building this from source please see the builder instructions at +["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). + +The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to setup your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). + + +**NOTE:** If you are using this with a Spark standalone cluster you must ensure that the version (including minor version) matches or you may experience odd errors. + +## Python Requirements + +At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas). \ No newline at end of file diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index ec1687415a7f6..5f93586a48a5a 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -50,6 +50,7 @@ from pyspark.serializers import MarshalSerializer, PickleSerializer from pyspark.status import * from pyspark.profiler import Profiler, BasicProfiler +from pyspark.version import __version__ def since(version): diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py new file mode 100755 index 0000000000000..212a618b767ab --- /dev/null +++ b/python/pyspark/find_spark_home.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script attempt to determine the correct setting for SPARK_HOME given +# that Spark may have been installed on the system with pip. + +from __future__ import print_function +import os +import sys + + +def _find_spark_home(): + """Find the SPARK_HOME.""" + # If the enviroment has SPARK_HOME set trust it. 
+ if "SPARK_HOME" in os.environ: + return os.environ["SPARK_HOME"] + + def is_spark_home(path): + """Takes a path and returns true if the provided path could be a reasonable SPARK_HOME""" + return (os.path.isfile(os.path.join(path, "bin/spark-submit")) and + (os.path.isdir(os.path.join(path, "jars")) or + os.path.isdir(os.path.join(path, "assembly")))) + + paths = ["../", os.path.dirname(os.path.realpath(__file__))] + + # Add the path of the PySpark module if it exists + if sys.version < "3": + import imp + try: + module_home = imp.find_module("pyspark")[1] + paths.append(module_home) + # If we are installed in edit mode also look two dirs up + paths.append(os.path.join(module_home, "../../")) + except ImportError: + # Not pip installed no worries + pass + else: + from importlib.util import find_spec + try: + module_home = os.path.dirname(find_spec("pyspark").origin) + paths.append(module_home) + # If we are installed in edit mode also look two dirs up + paths.append(os.path.join(module_home, "../../")) + except ImportError: + # Not pip installed no worries + pass + + # Normalize the paths + paths = [os.path.abspath(p) for p in paths] + + try: + return next(path for path in paths if is_spark_home(path)) + except StopIteration: + print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr) + exit(-1) + +if __name__ == "__main__": + print(_find_spark_home()) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c1cf843d84388..3c783ae541a1f 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,6 +29,7 @@ xrange = range from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int @@ -41,7 +42,7 @@ def launch_gateway(conf=None): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: - SPARK_HOME = os.environ["SPARK_HOME"] + SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" diff --git a/python/pyspark/version.py b/python/pyspark/version.py new file mode 100644 index 0000000000000..08a301695fda7 --- /dev/null +++ b/python/pyspark/version.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "2.1.0.dev0" diff --git a/python/setup.cfg b/python/setup.cfg new file mode 100644 index 0000000000000..d100b932bbafc --- /dev/null +++ b/python/setup.cfg @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[bdist_wheel] +universal = 1 + +[metadata] +description-file = README.md diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 0000000000000..625aea04073f5 --- /dev/null +++ b/python/setup.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import glob +import os +import sys +from setuptools import setup, find_packages +from shutil import copyfile, copytree, rmtree + +if sys.version_info < (2, 7): + print("Python versions prior to 2.7 are not supported for pip installed PySpark.", + file=sys.stderr) + exit(-1) + +try: + exec(open('pyspark/version.py').read()) +except IOError: + print("Failed to load PySpark version file for packaging. You must be in Spark's python dir.", + file=sys.stderr) + sys.exit(-1) +VERSION = __version__ +# A temporary path so we can access above the Python project root and fetch scripts and jars we need +TEMP_PATH = "deps" +SPARK_HOME = os.path.abspath("../") + +# Provide guidance about how to use setup.py +incorrect_invocation_message = """ +If you are installing pyspark from spark source, you must first build Spark and +run sdist. + + To build Spark with maven you can run: + ./build/mvn -DskipTests clean package + Building the source dist is done in the Python directory: + cd python + python setup.py sdist + pip install dist/*.tar.gz""" + +# Figure out where the jars are we need to package with PySpark. 
+JARS_PATH = glob.glob(os.path.join(SPARK_HOME, "assembly/target/scala-*/jars/")) + +if len(JARS_PATH) == 1: + JARS_PATH = JARS_PATH[0] +elif (os.path.isfile("../RELEASE") and len(glob.glob("../jars/spark*core*.jar")) == 1): + # Release mode puts the jars in a jars directory + JARS_PATH = os.path.join(SPARK_HOME, "jars") +elif len(JARS_PATH) > 1: + print("Assembly jars exist for multiple scalas ({0}), please cleanup assembly/target".format( + JARS_PATH), file=sys.stderr) + sys.exit(-1) +elif len(JARS_PATH) == 0 and not os.path.exists(TEMP_PATH): + print(incorrect_invocation_message, file=sys.stderr) + sys.exit(-1) + +EXAMPLES_PATH = os.path.join(SPARK_HOME, "examples/src/main/python") +SCRIPTS_PATH = os.path.join(SPARK_HOME, "bin") +SCRIPTS_TARGET = os.path.join(TEMP_PATH, "bin") +JARS_TARGET = os.path.join(TEMP_PATH, "jars") +EXAMPLES_TARGET = os.path.join(TEMP_PATH, "examples") + + +# Check and see if we are under the spark path in which case we need to build the symlink farm. +# This is important because we only want to build the symlink farm while under Spark otherwise we +# want to use the symlink farm. And if the symlink farm exists under while under Spark (e.g. a +# partially built sdist) we should error and have the user sort it out. +in_spark = (os.path.isfile("../core/src/main/scala/org/apache/spark/SparkContext.scala") or + (os.path.isfile("../RELEASE") and len(glob.glob("../jars/spark*core*.jar")) == 1)) + + +def _supports_symlinks(): + """Check if the system supports symlinks (e.g. *nix) or not.""" + return getattr(os, "symlink", None) is not None + + +if (in_spark): + # Construct links for setup + try: + os.mkdir(TEMP_PATH) + except: + print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH), + file=sys.stderr) + exit(-1) + +try: + # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts + # find it where expected. The rest of the files aren't copied because they are accessed + # using Python imports instead which will be resolved correctly. + try: + os.makedirs("pyspark/python/pyspark") + except OSError: + # Don't worry if the directory already exists. + pass + copyfile("pyspark/shell.py", "pyspark/python/pyspark/shell.py") + + if (in_spark): + # Construct the symlink farm - this is necessary since we can't refer to the path above the + # package root and we need to copy the jars and scripts which are up above the python root. + if _supports_symlinks(): + os.symlink(JARS_PATH, JARS_TARGET) + os.symlink(SCRIPTS_PATH, SCRIPTS_TARGET) + os.symlink(EXAMPLES_PATH, EXAMPLES_TARGET) + else: + # For windows fall back to the slower copytree + copytree(JARS_PATH, JARS_TARGET) + copytree(SCRIPTS_PATH, SCRIPTS_TARGET) + copytree(EXAMPLES_PATH, EXAMPLES_TARGET) + else: + # If we are not inside of SPARK_HOME verify we have the required symlink farm + if not os.path.exists(JARS_TARGET): + print("To build packaging must be in the python directory under the SPARK_HOME.", + file=sys.stderr) + + if not os.path.isdir(SCRIPTS_TARGET): + print(incorrect_invocation_message, file=sys.stderr) + exit(-1) + + # Scripts directive requires a list of each script path and does not take wild cards. + script_names = os.listdir(SCRIPTS_TARGET) + scripts = list(map(lambda script: os.path.join(SCRIPTS_TARGET, script), script_names)) + # We add find_spark_home.py to the bin directory we install so that pip installed PySpark + # will search for SPARK_HOME with Python. 
+ scripts.append("pyspark/find_spark_home.py") + + # Parse the README markdown file into rst for PyPI + long_description = "!!!!! missing pandoc do not upload to PyPI !!!!" + try: + import pypandoc + long_description = pypandoc.convert('README.md', 'rst') + except ImportError: + print("Could not import pypandoc - required to package PySpark", file=sys.stderr) + + setup( + name='pyspark', + version=VERSION, + description='Apache Spark Python API', + long_description=long_description, + author='Spark Developers', + author_email='dev@spark.apache.org', + url='https://github.com/apache/spark/tree/master/python', + packages=['pyspark', + 'pyspark.mllib', + 'pyspark.ml', + 'pyspark.sql', + 'pyspark.streaming', + 'pyspark.bin', + 'pyspark.jars', + 'pyspark.python.pyspark', + 'pyspark.python.lib', + 'pyspark.examples.src.main.python'], + include_package_data=True, + package_dir={ + 'pyspark.jars': 'deps/jars', + 'pyspark.bin': 'deps/bin', + 'pyspark.python.lib': 'lib', + 'pyspark.examples.src.main.python': 'deps/examples', + }, + package_data={ + 'pyspark.jars': ['*.jar'], + 'pyspark.bin': ['*'], + 'pyspark.python.lib': ['*.zip'], + 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, + scripts=scripts, + license='http://www.apache.org/licenses/LICENSE-2.0', + install_requires=['py4j==0.10.4'], + setup_requires=['pypandoc'], + extras_require={ + 'ml': ['numpy>=1.7'], + 'mllib': ['numpy>=1.7'], + 'sql': ['pandas'] + }, + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: Implementation :: CPython', + 'Programming Language :: Python :: Implementation :: PyPy'] + ) +finally: + # We only cleanup the symlink farm if we were in Spark, otherwise we are installing rather than + # packaging. + if (in_spark): + # Depending on cleaning up the symlink farm or copied version + if _supports_symlinks(): + os.remove(os.path.join(TEMP_PATH, "jars")) + os.remove(os.path.join(TEMP_PATH, "bin")) + os.remove(os.path.join(TEMP_PATH, "examples")) + else: + rmtree(os.path.join(TEMP_PATH, "jars")) + rmtree(os.path.join(TEMP_PATH, "bin")) + rmtree(os.path.join(TEMP_PATH, "examples")) + os.rmdir(TEMP_PATH) From 014fceee04c69d7944c74b3794e821e4d1003dd0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 Nov 2016 00:00:38 -0800 Subject: [PATCH 133/534] [SPARK-18464][SQL] support old table which doesn't store schema in metastore ## What changes were proposed in this pull request? Before Spark 2.1, users can create an external data source table without schema, and we will infer the table schema at runtime. In Spark 2.1, we decided to infer the schema when the table was created, so that we don't need to infer it again and again at runtime. This is a good improvement, but we should still respect and support old tables which doesn't store table schema in metastore. ## How was this patch tested? regression test. Author: Wenchen Fan Closes #15900 from cloud-fan/hive-catalog. 
(cherry picked from commit 07b3f045cd6f79b92bc86b3b1b51d3d5e6bd37ce) Signed-off-by: Reynold Xin --- .../spark/sql/execution/command/tables.scala | 8 ++++++- .../spark/sql/hive/HiveExternalCatalog.scala | 5 +++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +++- .../sql/hive/MetastoreDataSourcesSuite.scala | 22 +++++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 119e732d0202c..7049e53a78684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -431,7 +431,13 @@ case class DescribeTableCommand( describeSchema(catalog.lookupRelation(table).schema, result) } else { val metadata = catalog.getTableMetadata(table) - describeSchema(metadata.schema, result) + if (metadata.schema.isEmpty) { + // In older version(prior to 2.1) of Spark, the table schema can be empty and should be + // inferred at runtime. We should still support it. + describeSchema(catalog.lookupRelation(metadata.identifier).schema, result) + } else { + describeSchema(metadata.schema, result) + } describePartitionInfo(metadata, result) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index cbd00da81cfcd..843305883abc8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1023,6 +1023,11 @@ object HiveExternalCatalog { // After SPARK-6024, we removed this flag. // Although we are not using `spark.sql.sources.schema` any more, we need to still support. DataType.fromJson(schema.get).asInstanceOf[StructType] + } else if (props.filterKeys(_.startsWith(DATASOURCE_SCHEMA_PREFIX)).isEmpty) { + // If there is no schema information in table properties, it means the schema of this table + // was empty when saving into metastore, which is possible in older version(prior to 2.1) of + // Spark. We should respect it. + new StructType() } else { val numSchemaParts = props.get(DATASOURCE_SCHEMA_NUMPARTS) if (numSchemaParts.isDefined) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8e5fc88aad448..edbde5d10b47c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -64,7 +64,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val dataSource = DataSource( sparkSession, - userSpecifiedSchema = Some(table.schema), + // In older version(prior to 2.1) of Spark, the table schema can be empty and should be + // inferred at runtime. We should still support it. 
+ userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c50f92e783c88..4ab1a54edc46d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1371,4 +1371,26 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } } + + test("SPARK-18464: support old table which doesn't store schema in table properties") { + withTable("old") { + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) + val tableDesc = CatalogTable( + identifier = TableIdentifier("old", Some("default")), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat.empty.copy( + properties = Map("path" -> path.getAbsolutePath) + ), + schema = new StructType(), + properties = Map( + HiveExternalCatalog.DATASOURCE_PROVIDER -> "parquet")) + hiveClient.createTable(tableDesc, ignoreIfExists = false) + + checkAnswer(spark.table("old"), Row(1, "a")) + + checkAnswer(sql("DESC old"), Row("i", "int", null) :: Row("j", "string", null) :: Nil) + } + } + } } From 2ee4fc8891be53b2fae43faa5cd09ade32173bba Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Thu, 17 Nov 2016 11:13:22 +0000 Subject: [PATCH 134/534] [YARN][DOC] Remove non-Yarn specific configurations from running-on-yarn.md ## What changes were proposed in this pull request? Remove `spark.driver.memory`, `spark.executor.memory`, `spark.driver.cores`, and `spark.executor.cores` from `running-on-yarn.md` as they are not Yarn-specific, and they are also defined in`configuration.md`. ## How was this patch tested? Build passed & Manually check. Author: Weiqing Yang Closes #15869 from weiqingy/yarnDoc. (cherry picked from commit a3cac7bd86a6fe8e9b42da1bf580aaeb59378304) Signed-off-by: Sean Owen --- docs/running-on-yarn.md | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index fe0221ce7c5b6..4d1fafc07b8fc 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -117,28 +117,6 @@ To use a custom metrics.properties for the application master and executors, upd Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. - - spark.driver.memory - 1g - - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). - -
Note: In client mode, this config must not be set through the SparkConf - directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-memory command line option - or in your default properties file. - - - - spark.driver.cores - 1 - - Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. - - spark.yarn.am.cores 1 @@ -233,13 +211,6 @@ To use a custom metrics.properties for the application master and executors, upd Comma-separated list of jars to be placed in the working directory of each executor. - - spark.executor.cores - 1 in YARN mode, all the available cores on the worker in standalone mode. - - The number of cores to use on each executor. For YARN and standalone mode only. - - spark.executor.instances 2 @@ -247,13 +218,6 @@ To use a custom metrics.properties for the application master and executors, upd The number of executors for static allocation. With spark.dynamicAllocation.enabled, the initial set of executors will be at least this large. - - spark.executor.memory - 1g - - Amount of memory to use per executor process (e.g. 2g, 8g). - - spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 From 4fcecb4cf081fba0345f1939420ca1d9f6de720c Mon Sep 17 00:00:00 2001 From: anabranch Date: Thu, 17 Nov 2016 11:34:55 +0000 Subject: [PATCH 135/534] [SPARK-18365][DOCS] Improve Sample Method Documentation ## What changes were proposed in this pull request? I found the documentation for the sample method to be confusing, this adds more clarification across all languages. - [x] Scala - [x] Python - [x] R - [x] RDD Scala - [ ] RDD Python with SEED - [X] RDD Java - [x] RDD Java with SEED - [x] RDD Python ## How was this patch tested? NA Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. Author: anabranch Author: Bill Chambers Closes #15815 from anabranch/SPARK-18365. (cherry picked from commit 49b6f456aca350e9e2c170782aa5cc75e7822680) Signed-off-by: Sean Owen --- R/pkg/R/DataFrame.R | 4 +++- .../main/scala/org/apache/spark/api/java/JavaRDD.scala | 8 ++++++-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 +++ python/pyspark/rdd.py | 5 +++++ python/pyspark/sql/dataframe.py | 5 +++++ .../src/main/scala/org/apache/spark/sql/Dataset.scala | 10 ++++++++-- 6 files changed, 30 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1cf9b38ea6483..4e3d97bb3ad07 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -936,7 +936,9 @@ setMethod("unique", #' Sample #' -#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Note: this is not guaranteed to provide exactly the fraction specified +#' of the total count of of the given SparkDataFrame. 
#' #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 20d6c9341bf7a..d67cff64e6e46 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -98,7 +98,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD with a random seed. + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size @@ -109,7 +111,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) sample(withReplacement, fraction, Utils.random.nextLong) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD, with a user-supplied seed. + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e018af35cb18d..cded899db1f5c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -466,6 +466,9 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. * + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2de2c2fd1a60b..a163ceafe9d3b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -386,6 +386,11 @@ def sample(self, withReplacement, fraction, seed=None): with replacement: expected number of times each element is chosen; fraction must be >= 0 :param seed: seed for the random number generator + .. note:: + + This is not guaranteed to provide exactly the fraction specified of the total count + of the given :class:`DataFrame`. + >>> rdd = sc.parallelize(range(100), 4) >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 True diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 29710acf54c4f..38998900837cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -549,6 +549,11 @@ def distinct(self): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. + .. note:: + + This is not guaranteed to provide exactly the fraction specified of the total count + of the given :class:`DataFrame`. 
+ >>> df.sample(False, 0.5, 42).count() 2 """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index af30683cc01c4..3761773698df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1646,7 +1646,10 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by sampling a fraction of rows. + * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. + * + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[Dataset]]. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. @@ -1665,7 +1668,10 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by sampling a fraction of rows, using a random seed. + * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. + * + * Note: this is NOT guaranteed to provide exactly the fraction of the total count + * of the given [[Dataset]]. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. From 42777b1b3c10d3945494e27f1dedd43f2f836361 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Thu, 17 Nov 2016 13:37:42 +0000 Subject: [PATCH 136/534] [SPARK-17462][MLLIB]use VersionUtils to parse Spark version strings ## What changes were proposed in this pull request? Several places in MLlib use custom regexes or other approaches to parse Spark versions. Those should be fixed to use the VersionUtils. This PR replaces custom regexes with VersionUtils to get Spark version numbers. ## How was this patch tested? Existing tests. Signed-off-by: VinceShieh vincent.xieintel.com Author: VinceShieh Closes #15055 from VinceShieh/SPARK-17462. 
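For context, the sketch below contrasts the hand-rolled parsing being removed with the shared utility being adopted; the literal version string is just an example value:

    // What each ML loader previously did by hand (plain Scala, runnable on its own):
    val versionRegex = "([0-9]+)\\.(.+)".r
    val versionRegex(major, _) = "2.0.2"          // e.g. the value of metadata.sparkVersion
    val savedBySpark2OrLater = major.toInt >= 2   // true for "2.0.2"

    // What the patch switches to (Spark-internal utility, quoted from the diff below):
    //   import org.apache.spark.util.VersionUtils.majorVersion
    //   majorVersion(metadata.sparkVersion) >= 2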
(cherry picked from commit de77c67750dc868d75d6af173c3820b75a9fe4b7)
Signed-off-by: Sean Owen
---
 .../main/scala/org/apache/spark/ml/clustering/KMeans.scala  | 6 ++----
 mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala  | 6 ++----
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a0d481b294ac7..26505b4cc1501 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion

 /**
  * Common params for KMeans and KMeansModel
@@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString

-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
-      val clusterCenters = if (major.toInt >= 2) {
+      val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
         val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 444006fe1edb6..1e49352b8517e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion

 /**
  * Params for [[PCA]] and [[PCAModel]].
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {

     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
       val dataPath = new Path(path, "data").toString
-      val model = if (major.toInt >= 2) {
+      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")

From 536a2159393c82d414cc46797c8bfd958f453d33 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Thu, 17 Nov 2016 13:40:16 +0000
Subject: [PATCH 137/534] [SPARK-18480][DOCS] Fix wrong links for ML guide docs

## What changes were proposed in this pull request?
1. There are two `[Graph.partitionBy]` entries in `graphx-programming-guide.md`; the first one had no effect.
2. `DataFrame`, `Transformer`, `Pipeline` and `Parameter` in `ml-pipeline.md` were linked to `ml-guide.html` by mistake.
3. `PythonMLLibAPI` in `mllib-linear-methods.md` was not accessible, because the `PythonMLLibAPI` class is private.
4. Other link updates.

## How was this patch tested?
Manual tests.

Author: Zheng RuiFeng

Closes #15912 from zhengruifeng/md_fix.
(cherry picked from commit cdaf4ce9fe58c4606be8aa2a5c3756d30545c850) Signed-off-by: Sean Owen --- docs/graphx-programming-guide.md | 1 - docs/ml-classification-regression.md | 4 ++-- docs/ml-features.md | 2 +- docs/ml-pipeline.md | 12 ++++++------ docs/mllib-linear-methods.md | 4 +--- .../main/scala/org/apache/spark/ml/feature/LSH.scala | 2 +- .../spark/ml/tree/impl/GradientBoostedTrees.scala | 8 ++++---- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 8 ++++---- 8 files changed, 19 insertions(+), 22 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 1097cf1211c1f..e271b28fb4f28 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -36,7 +36,6 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] [Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] [PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED] [PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ [ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ [TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index cb2ccbf4fe157..c72c01fcff830 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -984,7 +984,7 @@ Random forests combine many decision trees in order to reduce the risk of overfi The `spark.ml` implementation supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html#random-forests). ### Inputs and Outputs @@ -1065,7 +1065,7 @@ GBTs iteratively train decision trees in order to minimize a loss function. The `spark.ml` implementation supports GBTs for binary classification and for regression, using both continuous and categorical features. -For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html#gradient-boosted-trees-gbts). ### Inputs and Outputs diff --git a/docs/ml-features.md b/docs/ml-features.md index 903177210d820..45724a3716e74 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -694,7 +694,7 @@ for more details on the API. `VectorIndexer` helps index categorical features in datasets of `Vector`s. It can both automatically decide which features are categorical and convert original values to category indices. Specifically, it does the following: -1. Take an input column of type [Vector](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) and a parameter `maxCategories`. +1. 
Take an input column of type [Vector](api/scala/index.html#org.apache.spark.ml.linalg.Vector) and a parameter `maxCategories`. 2. Decide which features should be categorical based on the number of distinct values, where features with at most `maxCategories` are declared categorical. 3. Compute 0-based category indices for each categorical feature. 4. Index categorical features and transform original feature values to indices. diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index b4d6be94f5eb0..0384513ab7014 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -38,26 +38,26 @@ algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Pipelines API, where the pipeline concept is mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. -* **[`DataFrame`](ml-guide.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML +* **[`DataFrame`](ml-pipeline.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML dataset, which can hold a variety of data types. E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. -* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. +* **[`Transformer`](ml-pipeline.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. E.g., an ML model is a `Transformer` which transforms a `DataFrame` with features into a `DataFrame` with predictions. -* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. +* **[`Estimator`](ml-pipeline.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. -* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. +* **[`Pipeline`](ml-pipeline.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. -* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. +* **[`Parameter`](ml-pipeline.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. ## DataFrame Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. This API adopts the `DataFrame` from Spark SQL in order to support a variety of data types. -`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#data-types) for a list of supported types. In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. 
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 816bdf1317000..3085539b40e61 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -139,7 +139,7 @@ and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. For both methods, `spark.mllib` supports L1 and L2 regularized variants. -The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, +The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html#labeled-point) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. ### Linear Support Vector Machines (SVMs) @@ -491,5 +491,3 @@ Algorithms are all implemented in Scala: * [RidgeRegressionWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) * [LassoWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) -Python calls the Scala implementation via -[PythonMLLibAPI](api/scala/index.html#org.apache.spark.mllib.api.python.PythonMLLibAPI). diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 333a8c364a884..eb117c40eea3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -40,7 +40,7 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol { * @group param */ final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" + - "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + + " increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + " improves the running performance", ParamValidators.gt(0)) /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 7bef899a633d9..ede0a060eef95 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -34,7 +34,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Training dataset: RDD of [[LabeledPoint]]. * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) @@ -59,7 +59,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to validate a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Training dataset: RDD of [[LabeledPoint]]. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. @@ -162,7 +162,7 @@ private[spark] object GradientBoostedTrees extends Logging { * Method to calculate error of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
+   * @param data Training dataset: RDD of [[LabeledPoint]].
    * @param trees Boosted Decision Tree models
    * @param treeWeights Learning rates at each boosting iteration.
    * @param loss evaluation metric.
@@ -184,7 +184,7 @@ private[spark] object GradientBoostedTrees extends Logging {
   /**
    * Method to compute error or loss for every iteration of gradient boosting.
    *
-   * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param data RDD of [[LabeledPoint]]
    * @param trees Boosted Decision Tree models
    * @param treeWeights Learning rates at each boosting iteration.
    * @param loss evaluation metric.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index b504f411d256d..8ae5ca3c84b0e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -82,7 +82,7 @@ private[spark] object RandomForest extends Logging {
   /**
    * Train a random forest.
    *
-   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param input Training data: RDD of [[LabeledPoint]]
    * @return an unweighted set of trees
    */
  def run(
@@ -343,7 +343,7 @@ private[spark] object RandomForest extends Logging {
   /**
    * Given a group of nodes, this finds the best split for each node.
    *
-   * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
+   * @param input Training data: RDD of [[TreePoint]]
    * @param metadata Learning and dataset metadata
    * @param topNodesForGroup For each tree in group, tree index -> root node.
    *                         Used for matching instances with nodes.
@@ -854,10 +854,10 @@ private[spark] object RandomForest extends Logging {
    * and for multiclass classification with a high-arity feature,
    * there is one bin per category.
    *
-   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param input Training data: RDD of [[LabeledPoint]]
    * @param metadata Learning and dataset metadata
    * @param seed random seed
-   * @return Splits, an Array of [[org.apache.spark.mllib.tree.model.Split]]
+   * @return Splits, an Array of [[Split]]
    *         of size (numFeatures, numSplits)
    */
  protected[tree] def findSplits(

From 978798880c0b1e6a15e8a342847e1ff4d83a5ac0 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 17 Nov 2016 17:04:19 +0000
Subject: [PATCH 138/534] [SPARK-18490][SQL] duplication nodename extrainfo for ShuffleExchange

## What changes were proposed in this pull request?

In ShuffleExchange, the nodeName's extraInfo is the same whether exchangeCoordinator.isEstimated is true or false, so this PR merges the two cases into one.

Author: root

Closes #15920 from windpiger/DupNodeNameShuffleExchange.
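A minimal sketch of the simplification, using a stand-in type rather than the real `ShuffleExchange`/`ExchangeCoordinator` classes:

    // Stand-in type, not the real ExchangeCoordinator, used only to show the shape of the fix.
    case class Coordinator(isEstimated: Boolean)

    def extraInfo(coordinator: Option[Coordinator]): String = coordinator match {
      // One case now covers isEstimated == true and isEstimated == false alike,
      // because the string produced never depended on that flag.
      case Some(_) => s"(coordinator id: ${System.identityHashCode(coordinator)})"
      case None => ""
    }

Both of the original guarded cases produced exactly the same string, so collapsing them changes no output, as the diff below shows.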
(cherry picked from commit b0aa1aa1af6c513a6a881eaea96abdd2b480ef98) Signed-off-by: Sean Owen --- .../apache/spark/sql/execution/exchange/ShuffleExchange.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 7a4a251370706..125a4930c6528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -45,9 +45,7 @@ case class ShuffleExchange( override def nodeName: String = { val extraInfo = coordinator match { - case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => - s"(coordinator id: ${System.identityHashCode(coordinator)})" - case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => + case Some(exchangeCoordinator) => s"(coordinator id: ${System.identityHashCode(coordinator)})" case None => "" } From fc466be4fd8def06880f59d50e5567c22cc53d6a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 Nov 2016 17:31:12 -0800 Subject: [PATCH 139/534] [SPARK-18360][SQL] default table path of tables in default database should depend on the location of default database ## What changes were proposed in this pull request? The current semantic of the warehouse config: 1. it's a static config, which means you can't change it once your spark application is launched. 2. Once a database is created, its location won't change even the warehouse path config is changed. 3. default database is a special case, although its location is fixed, but the locations of tables created in it are not. If a Spark app starts with warehouse path B(while the location of default database is A), then users create a table `tbl` in default database, its location will be `B/tbl` instead of `A/tbl`. If uses change the warehouse path config to C, and create another table `tbl2`, its location will still be `B/tbl2` instead of `C/tbl2`. rule 3 doesn't make sense and I think we made it by mistake, not intentionally. Data source tables don't follow rule 3 and treat default database like normal ones. This PR fixes hive serde tables to make it consistent with data source tables. ## How was this patch tested? HiveSparkSubmitSuite Author: Wenchen Fan Closes #15812 from cloud-fan/default-db. (cherry picked from commit ce13c2672318242748f7520ed4ce6bcfad4fb428) Signed-off-by: Yin Huai --- .../spark/sql/hive/HiveExternalCatalog.scala | 237 ++++++++++-------- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 76 +++++- 2 files changed, 190 insertions(+), 123 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 843305883abc8..cacffcf33c263 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -197,136 +197,151 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (tableDefinition.tableType == VIEW) { client.createTable(tableDefinition, ignoreIfExists) - } else if (tableDefinition.provider.get == DDLUtils.HIVE_PROVIDER) { - // Here we follow data source tables and put table metadata like provider, schema, etc. 
in - // table properties, so that we can work around the Hive metastore issue about not case - // preserving and make Hive serde table support mixed-case column names. - val tableWithDataSourceProps = tableDefinition.copy( - properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) - client.createTable(tableWithDataSourceProps, ignoreIfExists) } else { - // To work around some hive metastore issues, e.g. not case-preserving, bad decimal type - // support, no column nullability, etc., we should do some extra works before saving table - // metadata into Hive metastore: - // 1. Put table metadata like provider, schema, etc. in table properties. - // 2. Check if this table is hive compatible. - // 2.1 If it's not hive compatible, set location URI, schema, partition columns and bucket - // spec to empty and save table metadata to Hive. - // 2.2 If it's hive compatible, set serde information in table metadata and try to save - // it to Hive. If it fails, treat it as not hive compatible and go back to 2.1 - val tableProperties = tableMetaToTableProps(tableDefinition) - // Ideally we should not create a managed table with location, but Hive serde table can // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have // to create the table directory and write out data before we create this table, to avoid // exposing a partial written table. val needDefaultTableLocation = tableDefinition.tableType == MANAGED && tableDefinition.storage.locationUri.isEmpty + val tableLocation = if (needDefaultTableLocation) { Some(defaultTablePath(tableDefinition.identifier)) } else { tableDefinition.storage.locationUri } - // Ideally we should also put `locationUri` in table properties like provider, schema, etc. - // However, in older version of Spark we already store table location in storage properties - // with key "path". Here we keep this behaviour for backward compatibility. - val storagePropsWithLocation = tableDefinition.storage.properties ++ - tableLocation.map("path" -> _) - - // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and - // bucket specification to empty. Note that partition columns are retained, so that we can - // call partition-related Hive API later. - def newSparkSQLSpecificMetastoreTable(): CatalogTable = { - tableDefinition.copy( - // Hive only allows directory paths as location URIs while Spark SQL data source tables - // also allow file paths. For non-hive-compatible format, we should not set location URI - // to avoid hive metastore to throw exception. - storage = tableDefinition.storage.copy( - locationUri = None, - properties = storagePropsWithLocation), - schema = tableDefinition.partitionSchema, - bucketSpec = None, - properties = tableDefinition.properties ++ tableProperties) + + if (tableDefinition.provider.get == DDLUtils.HIVE_PROVIDER) { + val tableWithDataSourceProps = tableDefinition.copy( + // We can't leave `locationUri` empty and count on Hive metastore to set a default table + // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default + // table location for tables in default database, while we expect to use the location of + // default database. + storage = tableDefinition.storage.copy(locationUri = tableLocation), + // Here we follow data source tables and put table metadata like provider, schema, etc. 
in + // table properties, so that we can work around the Hive metastore issue about not case + // preserving and make Hive serde table support mixed-case column names. + properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) + client.createTable(tableWithDataSourceProps, ignoreIfExists) + } else { + createDataSourceTable( + tableDefinition.withNewStorage(locationUri = tableLocation), + ignoreIfExists) } + } + } - // converts the table metadata to Hive compatible format, i.e. set the serde information. - def newHiveCompatibleMetastoreTable(serde: HiveSerDe): CatalogTable = { - val location = if (tableDefinition.tableType == EXTERNAL) { - // When we hit this branch, we are saving an external data source table with hive - // compatible format, which means the data source is file-based and must have a `path`. - require(tableDefinition.storage.locationUri.isDefined, - "External file-based data source table must have a `path` entry in storage properties.") - Some(new Path(tableDefinition.location).toUri.toString) - } else { - None - } + private def createDataSourceTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = { + // To work around some hive metastore issues, e.g. not case-preserving, bad decimal type + // support, no column nullability, etc., we should do some extra works before saving table + // metadata into Hive metastore: + // 1. Put table metadata like provider, schema, etc. in table properties. + // 2. Check if this table is hive compatible. + // 2.1 If it's not hive compatible, set location URI, schema, partition columns and bucket + // spec to empty and save table metadata to Hive. + // 2.2 If it's hive compatible, set serde information in table metadata and try to save + // it to Hive. If it fails, treat it as not hive compatible and go back to 2.1 + val tableProperties = tableMetaToTableProps(table) + + // Ideally we should also put `locationUri` in table properties like provider, schema, etc. + // However, in older version of Spark we already store table location in storage properties + // with key "path". Here we keep this behaviour for backward compatibility. + val storagePropsWithLocation = table.storage.properties ++ + table.storage.locationUri.map("path" -> _) + + // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and + // bucket specification to empty. Note that partition columns are retained, so that we can + // call partition-related Hive API later. + def newSparkSQLSpecificMetastoreTable(): CatalogTable = { + table.copy( + // Hive only allows directory paths as location URIs while Spark SQL data source tables + // also allow file paths. For non-hive-compatible format, we should not set location URI + // to avoid hive metastore to throw exception. + storage = table.storage.copy( + locationUri = None, + properties = storagePropsWithLocation), + schema = table.partitionSchema, + bucketSpec = None, + properties = table.properties ++ tableProperties) + } - tableDefinition.copy( - storage = tableDefinition.storage.copy( - locationUri = location, - inputFormat = serde.inputFormat, - outputFormat = serde.outputFormat, - serde = serde.serde, - properties = storagePropsWithLocation - ), - properties = tableDefinition.properties ++ tableProperties) + // converts the table metadata to Hive compatible format, i.e. set the serde information. 
+ def newHiveCompatibleMetastoreTable(serde: HiveSerDe): CatalogTable = { + val location = if (table.tableType == EXTERNAL) { + // When we hit this branch, we are saving an external data source table with hive + // compatible format, which means the data source is file-based and must have a `path`. + require(table.storage.locationUri.isDefined, + "External file-based data source table must have a `path` entry in storage properties.") + Some(new Path(table.location).toUri.toString) + } else { + None } - val qualifiedTableName = tableDefinition.identifier.quotedString - val maybeSerde = HiveSerDe.sourceToSerDe(tableDefinition.provider.get) - val skipHiveMetadata = tableDefinition.storage.properties - .getOrElse("skipHiveMetadata", "false").toBoolean - - val (hiveCompatibleTable, logMessage) = maybeSerde match { - case _ if skipHiveMetadata => - val message = - s"Persisting data source table $qualifiedTableName into Hive metastore in" + - "Spark SQL specific format, which is NOT compatible with Hive." - (None, message) - - // our bucketing is un-compatible with hive(different hash function) - case _ if tableDefinition.bucketSpec.nonEmpty => - val message = - s"Persisting bucketed data source table $qualifiedTableName into " + - "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " - (None, message) - - case Some(serde) => - val message = - s"Persisting file based data source table $qualifiedTableName into " + - s"Hive metastore in Hive compatible format." - (Some(newHiveCompatibleMetastoreTable(serde)), message) - - case _ => - val provider = tableDefinition.provider.get - val message = - s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + - s"Persisting data source table $qualifiedTableName into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive." - (None, message) - } + table.copy( + storage = table.storage.copy( + locationUri = location, + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde, + properties = storagePropsWithLocation + ), + properties = table.properties ++ tableProperties) + } - (hiveCompatibleTable, logMessage) match { - case (Some(table), message) => - // We first try to save the metadata of the table in a Hive compatible way. - // If Hive throws an error, we fall back to save its metadata in the Spark SQL - // specific way. - try { - logInfo(message) - saveTableIntoHive(table, ignoreIfExists) - } catch { - case NonFatal(e) => - val warningMessage = - s"Could not persist ${tableDefinition.identifier.quotedString} in a Hive " + - "compatible way. Persisting it into Hive metastore in Spark SQL specific format." - logWarning(warningMessage, e) - saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) - } + val qualifiedTableName = table.identifier.quotedString + val maybeSerde = HiveSerDe.sourceToSerDe(table.provider.get) + val skipHiveMetadata = table.storage.properties + .getOrElse("skipHiveMetadata", "false").toBoolean + + val (hiveCompatibleTable, logMessage) = maybeSerde match { + case _ if skipHiveMetadata => + val message = + s"Persisting data source table $qualifiedTableName into Hive metastore in" + + "Spark SQL specific format, which is NOT compatible with Hive." 
+ (None, message) + + // our bucketing is un-compatible with hive(different hash function) + case _ if table.bucketSpec.nonEmpty => + val message = + s"Persisting bucketed data source table $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + (None, message) + + case Some(serde) => + val message = + s"Persisting file based data source table $qualifiedTableName into " + + s"Hive metastore in Hive compatible format." + (Some(newHiveCompatibleMetastoreTable(serde)), message) + + case _ => + val provider = table.provider.get + val message = + s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + + s"Persisting data source table $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + } - case (None, message) => - logWarning(message) - saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) - } + (hiveCompatibleTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatible way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + saveTableIntoHive(table, ignoreIfExists) + } catch { + case NonFatal(e) => + val warningMessage = + s"Could not persist ${table.identifier.quotedString} in a Hive " + + "compatible way. Persisting it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, e) + saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) + } + + case (None, message) => + logWarning(message) + saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index fbd705172cae6..a670560c5969d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -24,6 +24,7 @@ import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.tools.nsc.Properties +import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException @@ -33,11 +34,12 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.util.{ResetSystemProperties, Utils} /** @@ -295,6 +297,20 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-18360: default table path of tables in default database should depend on the " + + "location of default database") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", 
SPARK_18360.getClass.getName.stripSuffix("$"), + "--name", "SPARK-18360", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -397,11 +413,7 @@ object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") - val sparkConf = new SparkConf(loadDefaults = true) - val builder = SparkSession.builder() - .config(sparkConf) - .config("spark.ui.enabled", "false") - .enableHiveSupport() + val sparkConf = new SparkConf(loadDefaults = true).set("spark.ui.enabled", "false") val providedExpectedWarehouseLocation = sparkConf.getOption("spark.sql.test.expectedWarehouseDir") @@ -410,7 +422,7 @@ object SetWarehouseLocationTest extends Logging { // If spark.sql.test.expectedWarehouseDir is set, the warehouse dir is set // through spark-summit. So, neither spark.sql.warehouse.dir nor // hive.metastore.warehouse.dir is set at here. - (builder.getOrCreate(), warehouseDir) + (new TestHiveContext(new SparkContext(sparkConf)).sparkSession, warehouseDir) case None => val warehouseLocation = Utils.createTempDir() warehouseLocation.delete() @@ -420,10 +432,10 @@ object SetWarehouseLocationTest extends Logging { // spark.sql.warehouse.dir and hive.metastore.warehouse.dir. // We are expecting that the value of spark.sql.warehouse.dir will override the // value of hive.metastore.warehouse.dir. - val session = builder - .config("spark.sql.warehouse.dir", warehouseLocation.toString) - .config("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString) - .getOrCreate() + val session = new TestHiveContext(new SparkContext(sparkConf + .set("spark.sql.warehouse.dir", warehouseLocation.toString) + .set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString))) + .sparkSession (session, warehouseLocation.toString) } @@ -801,3 +813,43 @@ object SPARK_14244 extends QueryTest { } } } + +object SPARK_18360 { + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder() + .config("spark.ui.enabled", "false") + .enableHiveSupport().getOrCreate() + + val defaultDbLocation = spark.catalog.getDatabase("default").locationUri + assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) + + val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + + try { + val tableMeta = CatalogTable( + identifier = TableIdentifier("test_tbl", Some("default")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("i", "int"), + provider = Some(DDLUtils.HIVE_PROVIDER)) + + val newWarehousePath = Utils.createTempDir().getAbsolutePath + hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$newWarehousePath") + hiveClient.createTable(tableMeta, ignoreIfExists = false) + val rawTable = hiveClient.getTable("default", "test_tbl") + // Hive will use the value of `hive.metastore.warehouse.dir` to generate default table + // location for tables in default database. 
+ assert(rawTable.storage.locationUri.get.contains(newWarehousePath)) + hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = false, purge = false) + + spark.sharedState.externalCatalog.createTable(tableMeta, ignoreIfExists = false) + val readBack = spark.sharedState.externalCatalog.getTable("default", "test_tbl") + // Spark SQL will use the location of default database to generate default table + // location for tables in default database. + assert(readBack.storage.locationUri.get.contains(defaultDbLocation)) + } finally { + hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = true, purge = false) + hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$defaultDbLocation") + } + } +} From e8b1955e20a966da9a95f75320680cbab1096540 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 17 Nov 2016 18:45:15 -0800 Subject: [PATCH 140/534] [SPARK-18462] Fix ClassCastException in SparkListenerDriverAccumUpdates event ## What changes were proposed in this pull request? This patch fixes a `ClassCastException: java.lang.Integer cannot be cast to java.lang.Long` error which could occur in the HistoryServer while trying to process a deserialized `SparkListenerDriverAccumUpdates` event. The problem stems from how `jackson-module-scala` handles primitive type parameters (see https://github.com/FasterXML/jackson-module-scala/wiki/FAQ#deserializing-optionint-and-other-primitive-challenges for more details). This was causing a problem where our code expected a field to be deserialized as a `(Long, Long)` tuple but we got an `(Int, Int)` tuple instead. This patch hacks around this issue by registering a custom `Converter` with Jackson in order to deserialize the tuples as `(Object, Object)` and perform the appropriate casting. ## How was this patch tested? New regression tests in `SQLListenerSuite`. Author: Josh Rosen Closes #15922 from JoshRosen/SPARK-18462. (cherry picked from commit d9dd979d170f44383a9a87f892f2486ddb3cca7d) Signed-off-by: Reynold Xin --- .../spark/sql/execution/ui/SQLListener.scala | 39 +++++++++++++++- .../sql/execution/ui/SQLListenerSuite.scala | 44 ++++++++++++++++++- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 60f13432d78d2..5daf21595d8a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,6 +19,11 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable +import com.fasterxml.jackson.databind.JavaType +import com.fasterxml.jackson.databind.`type`.TypeFactory +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.fasterxml.jackson.databind.util.Converter + import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging @@ -43,9 +48,41 @@ case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent @DeveloperApi -case class SparkListenerDriverAccumUpdates(executionId: Long, accumUpdates: Seq[(Long, Long)]) +case class SparkListenerDriverAccumUpdates( + executionId: Long, + @JsonDeserialize(contentConverter = classOf[LongLongTupleConverter]) + accumUpdates: Seq[(Long, Long)]) extends SparkListenerEvent +/** + * Jackson [[Converter]] for converting an (Int, Int) tuple into a (Long, Long) tuple. 
+ * + * This is necessary due to limitations in how Jackson's scala module deserializes primitives; + * see the "Deserializing Option[Int] and other primitive challenges" section in + * https://github.com/FasterXML/jackson-module-scala/wiki/FAQ for a discussion of this issue and + * SPARK-18462 for the specific problem that motivated this conversion. + */ +private class LongLongTupleConverter extends Converter[(Object, Object), (Long, Long)] { + + override def convert(in: (Object, Object)): (Long, Long) = { + def toLong(a: Object): Long = a match { + case i: java.lang.Integer => i.intValue() + case l: java.lang.Long => l.longValue() + } + (toLong(in._1), toLong(in._2)) + } + + override def getInputType(typeFactory: TypeFactory): JavaType = { + val objectType = typeFactory.uncheckedSimpleType(classOf[Object]) + typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(objectType, objectType)) + } + + override def getOutputType(typeFactory: TypeFactory): JavaType = { + val longType = typeFactory.uncheckedSimpleType(classOf[Long]) + typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(longType, longType)) + } +} + class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 19b6d2603129c..7b4ff675fba72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.ui import java.util.Properties +import org.json4s.jackson.JsonMethods._ import org.mockito.Mockito.mock import org.apache.spark._ @@ -35,10 +36,10 @@ import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanIn import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{AccumulatorMetadata, LongAccumulator} +import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} -class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils { import testImplicits._ import org.apache.spark.AccumulatorSuite.makeInfo @@ -416,6 +417,45 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(driverUpdates(physicalPlan.longMetric("dummy").id) == expectedAccumValue) } + test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { + val event = SparkListenerDriverAccumUpdates(1L, Seq((2L, 3L))) + val json = JsonProtocol.sparkEventToJson(event) + assertValidDataInJson(json, + parse(""" + |{ + | "Event": "org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates", + | "executionId": 1, + | "accumUpdates": [[2,3]] + |} + """.stripMargin)) + JsonProtocol.sparkEventFromJson(json) match { + case SparkListenerDriverAccumUpdates(executionId, accums) => + assert(executionId == 1L) + accums.foreach { case (a, b) => + assert(a == 2L) + assert(b == 3L) + } + } + + // Test a case where the numbers in the JSON can only fit in longs: + val longJson = parse( + """ + |{ + | "Event": "org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates", + | "executionId": 
4294967294, + | "accumUpdates": [[4294967294,3]] + |} + """.stripMargin) + JsonProtocol.sparkEventFromJson(longJson) match { + case SparkListenerDriverAccumUpdates(executionId, accums) => + assert(executionId == 4294967294L) + accums.foreach { case (a, b) => + assert(a == 4294967294L) + assert(b == 3L) + } + } + } + } From 5912c19e76719a1c388a7a151af03ebf71b8f0db Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Fri, 18 Nov 2016 11:11:24 -0800 Subject: [PATCH 141/534] [SPARK-18187][SQL] CompactibleFileStreamLog should not use "compactInterval" direcly with user setting. ## What changes were proposed in this pull request? CompactibleFileStreamLog relys on "compactInterval" to detect a compaction batch. If the "compactInterval" is reset by user, CompactibleFileStreamLog will return wrong answer, resulting data loss. This PR procides a way to check the validity of 'compactInterval', and calculate an appropriate value. ## How was this patch tested? When restart a stream, we change the 'spark.sql.streaming.fileSource.log.compactInterval' different with the former one. The primary solution to this issue was given by uncleGen Added extensions include an additional metadata field in OffsetSeq and CompactibleFileStreamLog APIs. zsxwing Author: Tyson Condie Author: genmao.ygm Closes #15852 from tcondie/spark-18187. (cherry picked from commit 51baca2219fda8692b88fc8552548544aec73a1e) Signed-off-by: Shixiong Zhu --- .../streaming/CompactibleFileStreamLog.scala | 61 ++++++++++++++++++- .../streaming/FileStreamSinkLog.scala | 8 ++- .../streaming/FileStreamSourceLog.scala | 9 +-- .../execution/streaming/HDFSMetadataLog.scala | 2 +- .../sql/execution/streaming/OffsetSeq.scala | 12 +++- .../execution/streaming/OffsetSeqLog.scala | 31 +++++++--- .../CompactibleFileStreamLogSuite.scala | 33 ++++++++++ .../sql/streaming/FileStreamSourceSuite.scala | 41 ++++++++----- .../spark/sql/streaming/StreamTest.scala | 20 +++++- 9 files changed, 178 insertions(+), 39 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 8af3db1968882..8529ceac30f1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -63,7 +63,46 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( protected def isDeletingExpiredLog: Boolean - protected def compactInterval: Int + protected def defaultCompactInterval: Int + + protected final lazy val compactInterval: Int = { + // SPARK-18187: "compactInterval" can be set by user via defaultCompactInterval. + // If there are existing log entries, then we should ensure a compatible compactInterval + // is used, irrespective of the defaultCompactInterval. There are three cases: + // + // 1. If there is no '.compact' file, we can use the default setting directly. + // 2. If there are two or more '.compact' files, we use the interval of patch id suffix with + // '.compact' as compactInterval. This case could arise if isDeletingExpiredLog == false. + // 3. 
If there is only one '.compact' file, then we must find a compact interval + // that is compatible with (i.e., a divisor of) the previous compact file, and that + // faithfully tries to represent the revised default compact interval i.e., is at least + // is large if possible. + // e.g., if defaultCompactInterval is 5 (and previous compact interval could have + // been any 2,3,4,6,12), then a log could be: 11.compact, 12, 13, in which case + // will ensure that the new compactInterval = 6 > 5 and (11 + 1) % 6 == 0 + val compactibleBatchIds = fileManager.list(metadataPath, batchFilesFilter) + .filter(f => f.getPath.toString.endsWith(CompactibleFileStreamLog.COMPACT_FILE_SUFFIX)) + .map(f => pathToBatchId(f.getPath)) + .sorted + .reverse + + // Case 1 + var interval = defaultCompactInterval + if (compactibleBatchIds.length >= 2) { + // Case 2 + val latestCompactBatchId = compactibleBatchIds(0) + val previousCompactBatchId = compactibleBatchIds(1) + interval = (latestCompactBatchId - previousCompactBatchId).toInt + } else if (compactibleBatchIds.length == 1) { + // Case 3 + interval = CompactibleFileStreamLog.deriveCompactInterval( + defaultCompactInterval, compactibleBatchIds(0).toInt) + } + assert(interval > 0, s"intervalValue = $interval not positive value.") + logInfo(s"Set the compact interval to $interval " + + s"[defaultCompactInterval: $defaultCompactInterval]") + interval + } /** * Filter out the obsolete logs. @@ -245,4 +284,24 @@ object CompactibleFileStreamLog { def nextCompactionBatchId(batchId: Long, compactInterval: Long): Long = { (batchId + compactInterval + 1) / compactInterval * compactInterval - 1 } + + /** + * Derives a compact interval from the latest compact batch id and + * a default compact interval. + */ + def deriveCompactInterval(defaultInterval: Int, latestCompactBatchId: Int) : Int = { + if (latestCompactBatchId + 1 <= defaultInterval) { + latestCompactBatchId + 1 + } else if (defaultInterval < (latestCompactBatchId + 1) / 2) { + // Find the first divisor >= default compact interval + def properDivisors(min: Int, n: Int) = + (min to n/2).view.filter(i => n % i == 0) :+ n + + properDivisors(defaultInterval, latestCompactBatchId + 1).head + } else { + // default compact interval > than any divisor other than latest compact id + latestCompactBatchId + 1 + } + } } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index b4f14151f1ef2..eb6eed87eca7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -88,9 +88,11 @@ class FileStreamSinkLog( protected override val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion - protected override val compactInterval = sparkSession.sessionState.conf.fileSinkLogCompactInterval - require(compactInterval > 0, - s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " + + protected override val defaultCompactInterval = + sparkSession.sessionState.conf.fileSinkLogCompactInterval + + require(defaultCompactInterval > 0, + s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $defaultCompactInterval) " + "to a positive value.") override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index fe81b15607068..327b3ac267766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -38,11 +38,12 @@ class FileStreamSourceLog( import CompactibleFileStreamLog._ // Configurations about metadata compaction - protected override val compactInterval = + protected override val defaultCompactInterval: Int = sparkSession.sessionState.conf.fileSourceLogCompactInterval - require(compactInterval > 0, - s"Please set ${SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key} (was $compactInterval) to a " + - s"positive value.") + + require(defaultCompactInterval > 0, + s"Please set ${SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key} " + + s"(was $defaultCompactInterval) to a positive value.") protected override val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSourceLogCleanupDelay diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index db7057d7da70c..080729b2ca8d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -70,7 +70,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** * A `PathFilter` to filter only batch files */ - private val batchFilesFilter = new PathFilter { + protected val batchFilesFilter = new PathFilter { override def accept(path: Path): Boolean = isBatchFile(path) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index a4e1fe6797097..7469caeee3be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -23,7 +23,7 @@ package org.apache.spark.sql.execution.streaming * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance * vector clock that must progress linearly forward. */ -case class OffsetSeq(offsets: Seq[Option[Offset]]) { +case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[String] = None) { /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of @@ -47,7 +47,13 @@ object OffsetSeq { * Returns a [[OffsetSeq]] with a variable sequence of offsets. * `nulls` in the sequence are converted to `None`s. */ - def fill(offsets: Offset*): OffsetSeq = { - OffsetSeq(offsets.map(Option(_))) + def fill(offsets: Offset*): OffsetSeq = OffsetSeq.fill(None, offsets: _*) + + /** + * Returns a [[OffsetSeq]] with metadata and a variable sequence of offsets. + * `nulls` in the sequence are converted to `None`s. 
+ */ + def fill(metadata: Option[String], offsets: Offset*): OffsetSeq = { + OffsetSeq(offsets.map(Option(_)), metadata) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index d1c9d95be9fdb..cc25b4474ba2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -33,12 +33,13 @@ import org.apache.spark.sql.SparkSession * by a newline character. If a source offset is missing, then * that line will contain a string value defined in the * SERIALIZED_VOID_OFFSET variable in [[OffsetSeqLog]] companion object. - * For instance, when dealine wiht [[LongOffset]] types: - * v1 // version 1 - * {0} // LongOffset 0 - * {3} // LongOffset 3 - * - // No offset for this source i.e., an invalid JSON string - * {2} // LongOffset 2 + * For instance, when dealing with [[LongOffset]] types: + * v1 // version 1 + * metadata + * {0} // LongOffset 0 + * {3} // LongOffset 3 + * - // No offset for this source i.e., an invalid JSON string + * {2} // LongOffset 2 * ... */ class OffsetSeqLog(sparkSession: SparkSession, path: String) @@ -58,13 +59,25 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) if (version != OffsetSeqLog.VERSION) { throw new IllegalStateException(s"Unknown log version: ${version}") } - OffsetSeq.fill(lines.map(parseOffset).toArray: _*) + + // read metadata + val metadata = lines.next().trim match { + case "" => None + case md => Some(md) + } + OffsetSeq.fill(metadata, lines.map(parseOffset).toArray: _*) } - override protected def serialize(metadata: OffsetSeq, out: OutputStream): Unit = { + override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller out.write(OffsetSeqLog.VERSION.getBytes(UTF_8)) - metadata.offsets.map(_.map(_.json)).foreach { offset => + + // write metadata + out.write('\n') + out.write(offsetSeq.metadata.getOrElse("").getBytes(UTF_8)) + + // write offsets, one per line + offsetSeq.offsets.map(_.map(_.json)).foreach { offset => out.write('\n') offset match { case Some(json: String) => out.write(json.getBytes(UTF_8)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala new file mode 100644 index 0000000000000..2cd2157b293cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.SparkFunSuite + +class CompactibleFileStreamLogSuite extends SparkFunSuite { + + import CompactibleFileStreamLog._ + + test("deriveCompactInterval") { + // latestCompactBatchId(4) + 1 <= default(5) + // then use latestestCompactBatchId + 1 === 5 + assert(5 === deriveCompactInterval(5, 4)) + // First divisor of 10 greater than 4 === 5 + assert(5 === deriveCompactInterval(4, 9)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index b365af76c3795..a099153d2e58e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.streaming import java.io.File +import scala.collection.mutable + import org.scalatest.PrivateMethodTester import org.scalatest.time.SpanSugar._ @@ -896,32 +898,38 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } - test("compacat metadata log") { + test("compact interval metadata log") { val _sources = PrivateMethod[Seq[Source]]('sources) val _metadataLog = PrivateMethod[FileStreamSourceLog]('metadataLog) - def verify(execution: StreamExecution) - (batchId: Long, expectedBatches: Int): Boolean = { + def verify( + execution: StreamExecution, + batchId: Long, + expectedBatches: Int, + expectedCompactInterval: Int): Boolean = { import CompactibleFileStreamLog._ val fileSource = (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] val metadataLog = fileSource invokePrivate _metadataLog() - if (isCompactionBatch(batchId, 2)) { + if (isCompactionBatch(batchId, expectedCompactInterval)) { val path = metadataLog.batchIdToPath(batchId) // Assert path name should be ended with compact suffix. - assert(path.getName.endsWith(COMPACT_FILE_SUFFIX)) + assert(path.getName.endsWith(COMPACT_FILE_SUFFIX), + "path does not end with compact file suffix") // Compacted batch should include all entries from start. 
val entries = metadataLog.get(batchId) - assert(entries.isDefined) - assert(entries.get.length === metadataLog.allFiles().length) - assert(metadataLog.get(None, Some(batchId)).flatMap(_._2).length === entries.get.length) + assert(entries.isDefined, "Entries not defined") + assert(entries.get.length === metadataLog.allFiles().length, "clean up check") + assert(metadataLog.get(None, Some(batchId)).flatMap(_._2).length === + entries.get.length, "Length check") } assert(metadataLog.allFiles().sortBy(_.batchId) === - metadataLog.get(None, Some(batchId)).flatMap(_._2).sortBy(_.batchId)) + metadataLog.get(None, Some(batchId)).flatMap(_._2).sortBy(_.batchId), + "Batch id mismatch") metadataLog.get(None, Some(batchId)).flatMap(_._2).length === expectedBatches } @@ -932,26 +940,27 @@ class FileStreamSourceSuite extends FileStreamSourceTest { ) { val fileStream = createFileStream("text", src.getCanonicalPath) val filtered = fileStream.filter($"value" contains "keep") + val updateConf = Map(SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key -> "5") testStream(filtered)( AddTextFileData("drop1\nkeep2\nkeep3", src, tmp), CheckAnswer("keep2", "keep3"), - AssertOnQuery(verify(_)(0L, 1)), + AssertOnQuery(verify(_, 0L, 1, 2)), AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6"), - AssertOnQuery(verify(_)(1L, 2)), + AssertOnQuery(verify(_, 1L, 2, 2)), AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9"), - AssertOnQuery(verify(_)(2L, 3)), + AssertOnQuery(verify(_, 2L, 3, 2)), StopStream, - StartStream(), - AssertOnQuery(verify(_)(2L, 3)), + StartStream(additionalConfs = updateConf), + AssertOnQuery(verify(_, 2L, 3, 2)), AddTextFileData("drop10\nkeep11", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11"), - AssertOnQuery(verify(_)(3L, 4)), + AssertOnQuery(verify(_, 3L, 4, 2)), AddTextFileData("drop12\nkeep13", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11", "keep13"), - AssertOnQuery(verify(_)(4L, 5)) + AssertOnQuery(verify(_, 4L, 5, 2)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 742833065144d..a6b2d4b9ab4c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -161,7 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( trigger: Trigger = ProcessingTime(0), - triggerClock: Clock = new SystemClock) + triggerClock: Clock = new SystemClock, + additionalConfs: Map[String, String] = Map.empty) extends StreamAction /** Advance the trigger clock's time manually. 
*/ @@ -240,6 +241,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for val sink = new MemorySink(stream.schema, outputMode) + val resetConfValues = mutable.Map[String, Option[String]]() @volatile var streamDeathCause: Throwable = null @@ -330,7 +332,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { startedTest.foreach { action => logInfo(s"Processing test stream action: $action") action match { - case StartStream(trigger, triggerClock) => + case StartStream(trigger, triggerClock, additionalConfs) => verify(currentStream == null, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], @@ -338,6 +340,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { if (triggerClock.isInstanceOf[StreamManualClock]) { manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() } + + additionalConfs.foreach(pair => { + val value = + if (spark.conf.contains(pair._1)) Some(spark.conf.get(pair._1)) else None + resetConfValues(pair._1) = value + spark.conf.set(pair._1, pair._2) + }) + lastStream = currentStream currentStream = spark @@ -519,6 +529,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { currentStream.stop() } spark.streams.removeListener(statusCollector) + + // Rollback prev configuration values + resetConfValues.foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } } } From ec622eb7e1ffd0775c9ca4683d1032ca8d41654a Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 18 Nov 2016 11:19:49 -0800 Subject: [PATCH 142/534] [SPARK-18457][SQL] ORC and other columnar formats using HiveShim read all columns when doing a simple count ## What changes were proposed in this pull request? When reading zero columns (e.g., count(*)) from ORC or any other format that uses HiveShim, actually set the read column list to empty for Hive to use. ## How was this patch tested? Query correctness is handled by existing unit tests. I'm happy to add more if anyone can point out some case that is not covered. Reduction in data read can be verified in the UI when built with a recent version of Hadoop say: ``` build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.0 -Phive -DskipTests clean package ``` However the default Hadoop 2.2 that is used for unit tests does not report actual bytes read and instead just full file sizes (see FileScanRDD.scala line 80). Therefore I don't think there is a good way to add a unit test for this. I tested with the following setup using above build options ``` case class OrcData(intField: Long, stringField: String) spark.range(1,1000000).map(i => OrcData(i, s"part-$i")).toDF().write.format("orc").save("orc_test") sql( s"""CREATE EXTERNAL TABLE orc_test( | intField LONG, | stringField STRING |) |STORED AS ORC |LOCATION '${System.getProperty("user.dir") + "/orc_test"}' """.stripMargin) ``` ## Results query | Spark 2.0.2 | this PR ---|---|--- `sql("select count(*) from orc_test").collect`|4.4 MB|199.4 KB `sql("select intField from orc_test").collect`|743.4 KB|743.4 KB `sql("select * from orc_test").collect`|4.4 MB|4.4 MB Author: Andrew Ray Closes #15898 from aray/sql-orc-no-col. 
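As a supplement to the UI-based check described above, the reduction in input data can also be observed programmatically. The sketch below is not part of this patch; it assumes the `orc_test` table from the setup above and a Hadoop version recent enough to report granular bytes-read metrics (see the FileScanRDD caveat mentioned earlier).

```
// Illustrative only: sum the bytes read by tasks while running the count(*) query.
import java.util.concurrent.atomic.LongAdder
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

val bytesRead = new LongAdder
spark.sparkContext.addSparkListener(new SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    // taskMetrics can be null for failed tasks, so guard the access
    Option(taskEnd.taskMetrics).foreach(m => bytesRead.add(m.inputMetrics.bytesRead))
  }
})
spark.sql("select count(*) from orc_test").collect()
// Listener events are delivered asynchronously, so allow a moment before reading the counter.
println(s"input bytes read: ${bytesRead.sum()}")
```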
(cherry picked from commit 795e9fc9213cb9941ae131aadcafddb94bde5f74) Signed-off-by: Reynold Xin --- .../org/apache/spark/sql/hive/HiveShim.scala | 6 ++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 25 ++++++++++++++++++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 0d2a765a388aa..9e9894803ce25 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -69,13 +69,13 @@ private[hive] object HiveShim { } /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.nonEmpty) { + if (ids != null) { ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) } - if (names != null && names.nonEmpty) { + if (names != null) { appendReadColumnNames(conf, names) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index ecb5972984523..a628977af2f4e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets import java.sql.Timestamp +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -577,4 +579,25 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { assert(spark.table(tableName).schema == schema.copy(fields = expectedFields)) } } + + test("Empty schema does not read data from ORC file") { + val data = Seq((1, 1), (2, 2)) + withOrcFile(data) { path => + val requestedSchema = StructType(Nil) + val conf = new Configuration() + val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get + OrcRelation.setRequiredColumns(conf, physicalSchema, requestedSchema) + val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) + assert(maybeOrcReader.isDefined) + val orcRecordReader = new SparkOrcNewRecordReader( + maybeOrcReader.get, conf, 0, maybeOrcReader.get.getContentLength) + + val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) + try { + assert(recordsIterator.next().toString == "{null, null}") + } finally { + recordsIterator.close() + } + } + } } From 6717981e4d76f0794a75c60586de4677c49659ad Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 18 Nov 2016 21:45:18 +0000 Subject: [PATCH 143/534] [SPARK-18422][CORE] Fix wholeTextFiles test to pass on Windows in JavaAPISuite ## What changes were proposed in this pull request? This PR fixes the test `wholeTextFiles` in `JavaAPISuite.java`. 
This is failed due to the different path format on Windows. For example, the path in `container` was ``` C:\projects\spark\target\tmp\1478967560189-0/part-00000 ``` whereas `new URI(res._1()).getPath()` was as below: ``` /C:/projects/spark/target/tmp/1478967560189-0/part-00000 ``` ## How was this patch tested? Tests in `JavaAPISuite.java`. Tested via AppVeyor. **Before** Build: https://ci.appveyor.com/project/spark-test/spark/build/63-JavaAPISuite-1 Diff: https://github.com/apache/spark/compare/master...spark-test:JavaAPISuite-1 ``` [info] Test org.apache.spark.JavaAPISuite.wholeTextFiles started [error] Test org.apache.spark.JavaAPISuite.wholeTextFiles failed: java.lang.AssertionError: expected: but was:, took 0.578 sec [error] at org.apache.spark.JavaAPISuite.wholeTextFiles(JavaAPISuite.java:1089) ... ``` **After** Build started: [CORE] `org.apache.spark.JavaAPISuite` [![PR-15866](https://ci.appveyor.com/api/projects/status/github/spark-test/spark?branch=198DDA52-F201-4D2B-BE2F-244E0C1725B2&svg=true)](https://ci.appveyor.com/project/spark-test/spark/branch/198DDA52-F201-4D2B-BE2F-244E0C1725B2) Diff: https://github.com/apache/spark/compare/master...spark-test:198DDA52-F201-4D2B-BE2F-244E0C1725B2 ``` [info] Test org.apache.spark.JavaAPISuite.wholeTextFiles started ... ``` Author: hyukjinkwon Closes #15866 from HyukjinKwon/SPARK-18422. (cherry picked from commit 40d59ff5eaac6df237fe3d50186695c3806b268c) Signed-off-by: Sean Owen --- .../java/org/apache/spark/JavaAPISuite.java | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 533025ba83e72..7bebe0612f9a8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -20,7 +20,6 @@ import java.io.*; import java.nio.channels.FileChannel; import java.nio.ByteBuffer; -import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +45,7 @@ import com.google.common.collect.Lists; import com.google.common.base.Throwables; import com.google.common.io.Files; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.DefaultCodec; @@ -1075,18 +1075,23 @@ public void wholeTextFiles() throws Exception { byte[] content2 = "spark is also easy to use.\n".getBytes(StandardCharsets.UTF_8); String tempDirName = tempDir.getAbsolutePath(); - Files.write(content1, new File(tempDirName + "/part-00000")); - Files.write(content2, new File(tempDirName + "/part-00001")); + String path1 = new Path(tempDirName, "part-00000").toUri().getPath(); + String path2 = new Path(tempDirName, "part-00001").toUri().getPath(); + + Files.write(content1, new File(path1)); + Files.write(content2, new File(path2)); Map container = new HashMap<>(); - container.put(tempDirName+"/part-00000", new Text(content1).toString()); - container.put(tempDirName+"/part-00001", new Text(content2).toString()); + container.put(path1, new Text(content1).toString()); + container.put(path2, new Text(content2).toString()); JavaPairRDD readRDD = sc.wholeTextFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { - assertEquals(res._2(), container.get(new URI(res._1()).getPath())); + // Note that the paths from `wholeTextFiles` are in URI format on Windows, + // for example, file:/C:/a/b/c. 
+ assertEquals(res._2(), container.get(new Path(res._1()).toUri().getPath())); } } From 136f687c6282c328c2ae121fc3d45207550d184b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 18 Nov 2016 16:13:02 -0800 Subject: [PATCH 144/534] [SPARK-18477][SS] Enable interrupts for HDFS in HDFSMetadataLog ## What changes were proposed in this pull request? HDFS `write` may just hang until timeout if some network error happens. It's better to enable interrupts to allow stopping the query fast on HDFS. This PR just changes the logic to only disable interrupts for local file system, as HADOOP-10622 only happens for local file system. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15911 from zsxwing/interrupt-on-dfs. (cherry picked from commit e5f5c29e021d504284fe5ad1a77dcd5a992ac10a) Signed-off-by: Tathagata Das --- .../execution/streaming/HDFSMetadataLog.scala | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 080729b2ca8d6..d95ec7f67feb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -105,25 +105,34 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** * Store the metadata for the specified batchId and return `true` if successful. If the batchId's * metadata has already been stored, this method will return `false`. - * - * Note that this method must be called on a [[org.apache.spark.util.UninterruptibleThread]] - * so that interrupts can be disabled while writing the batch file. This is because there is a - * potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). If the thread - * running "Shell.runCommand" is interrupted, then the thread can get deadlocked. In our - * case, `writeBatch` creates a file using HDFS API and calls "Shell.runCommand" to set the - * file permissions, and can get deadlocked if the stream execution thread is stopped by - * interrupt. Hence, we make sure that this method is called on [[UninterruptibleThread]] which - * allows us to disable interrupts here. Also see SPARK-14131. */ override def add(batchId: Long, metadata: T): Boolean = { get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - Thread.currentThread match { - case ut: UninterruptibleThread => - ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) } - case _ => - throw new IllegalStateException( - "HDFSMetadataLog.add() must be executed on a o.a.spark.util.UninterruptibleThread") + if (fileManager.isLocalFileSystem) { + Thread.currentThread match { + case ut: UninterruptibleThread => + // When using a local file system, "writeBatch" must be called on a + // [[org.apache.spark.util.UninterruptibleThread]] so that interrupts can be disabled + // while writing the batch file. This is because there is a potential dead-lock in + // Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). If the thread running + // "Shell.runCommand" is interrupted, then the thread can get deadlocked. In our case, + // `writeBatch` creates a file using HDFS API and will call "Shell.runCommand" to set + // the file permission if using the local file system, and can get deadlocked if the + // stream execution thread is stopped by interrupt. 
Hence, we make sure that + // "writeBatch" is called on [[UninterruptibleThread]] which allows us to disable + // interrupts here. Also see SPARK-14131. + ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) } + case _ => + throw new IllegalStateException( + "HDFSMetadataLog.add() on a local file system must be executed on " + + "a o.a.spark.util.UninterruptibleThread") + } + } else { + // For a distributed file system, such as HDFS or S3, if the network is broken, write + // operations may just hang until timeout. We should enable interrupts to allow stopping + // the query fast. + writeBatch(batchId, metadata, serialize) } true } @@ -298,6 +307,9 @@ object HDFSMetadataLog { /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ def delete(path: Path): Unit + + /** Whether the file systme is a local FS. */ + def isLocalFileSystem: Boolean } /** @@ -342,6 +354,13 @@ object HDFSMetadataLog { // ignore if file has already been deleted } } + + override def isLocalFileSystem: Boolean = fc.getDefaultFileSystem match { + case _: local.LocalFs | _: local.RawLocalFs => + // LocalFs = RawLocalFs + ChecksumFs + true + case _ => false + } } /** @@ -398,5 +417,12 @@ object HDFSMetadataLog { // ignore if file has already been deleted } } + + override def isLocalFileSystem: Boolean = fs match { + case _: LocalFileSystem | _: RawLocalFileSystem => + // LocalFileSystem = RawLocalFileSystem + ChecksumFileSystem + true + case _ => false + } } } From 4b1df0e89badd9bb175673aefc96d3f9358e976d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Nov 2016 16:34:11 -0800 Subject: [PATCH 145/534] [SPARK-18505][SQL] Simplify AnalyzeColumnCommand ## What changes were proposed in this pull request? I'm spending more time at the design & code level for cost-based optimizer now, and have found a number of issues related to maintainability and compatibility that I will like to address. This is a small pull request to clean up AnalyzeColumnCommand: 1. Removed warning on duplicated columns. Warnings in log messages are useless since most users that run SQL don't see them. 2. Removed the nested updateStats function, by just inlining the function. 3. Renamed a few functions to better reflect what they do. 4. Removed the factory apply method for ColumnStatStruct. It is a bad pattern to use a apply method that returns an instantiation of a class that is not of the same type (ColumnStatStruct.apply used to return CreateNamedStruct). 5. Renamed ColumnStatStruct to just AnalyzeColumnCommand. 6. Added more documentation explaining some of the non-obvious return types and code blocks. In follow-up pull requests, I'd like to address the following: 1. Get rid of the Map[String, ColumnStat] map, since internally we should be using Attribute to reference columns, rather than strings. 2. Decouple the fields exposed by ColumnStat and internals of Spark SQL's execution path. Currently the two are coupled because ColumnStat takes in an InternalRow. 3. Correctness: Remove code path that stores statistics in the catalog using the base64 encoding of the UnsafeRow format, which is not stable across Spark versions. 4. Clearly document the data representation stored in the catalog for statistics. ## How was this patch tested? Affected test cases have been updated. Author: Reynold Xin Closes #15933 from rxin/SPARK-18505. 
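For context, the user-facing entry point into `AnalyzeColumnCommand` is the `ANALYZE TABLE ... COMPUTE STATISTICS FOR COLUMNS` statement. A minimal sketch follows (not part of this patch; the table and column names are made up for illustration):

```
// Illustrative only: the SQL path that ends up in AnalyzeColumnCommand.
spark.sql("CREATE TABLE stats_demo (id INT, name STRING) USING parquet")
spark.sql("INSERT INTO stats_demo VALUES (1, 'a'), (2, 'b'), (3, 'b')")
// Computes the row count plus per-column statistics and stores them with the table's
// catalog entry; re-running for other columns merges into the existing column-stats map.
spark.sql("ANALYZE TABLE stats_demo COMPUTE STATISTICS FOR COLUMNS id, name")
```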
(cherry picked from commit 6f7ff75091154fed7649ea6d79e887aad9fbde6a) Signed-off-by: Reynold Xin --- .../command/AnalyzeColumnCommand.scala | 115 ++++++++++-------- .../spark/sql/StatisticsColumnSuite.scala | 2 +- .../org/apache/spark/sql/StatisticsTest.scala | 7 +- .../spark/sql/hive/HiveExternalCatalog.scala | 4 +- .../sql/hive/client/HiveClientImpl.scala | 2 +- 5 files changed, 74 insertions(+), 56 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 6141fab4aff0d..7fc57d09e9243 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.command -import scala.collection.mutable - +import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases @@ -44,13 +43,16 @@ case class AnalyzeColumnCommand( val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) - relation match { + // Compute total size + val (catalogTable: CatalogTable, sizeInBytes: Long) = relation match { case catalogRel: CatalogRelation => - updateStats(catalogRel.catalogTable, + // This is a Hive serde format table + (catalogRel.catalogTable, AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateStats(logicalRel.catalogTable.get, + // This is a data source format table + (logicalRel.catalogTable.get, AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) case otherRelation => @@ -58,45 +60,45 @@ case class AnalyzeColumnCommand( s"${otherRelation.nodeName}.") } - def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - val (rowCount, columnStats) = computeColStats(sparkSession, relation) - // We also update table-level stats in order to keep them consistent with column-level stats. - val statistics = Statistics( - sizeInBytes = newTotalSize, - rowCount = Some(rowCount), - // Newly computed column stats should override the existing ones. - colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats) - sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) - } + // Compute stats for each column + val (rowCount, newColStats) = + AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames) + + // We also update table-level stats in order to keep them consistent with column-level stats. + val statistics = Statistics( + sizeInBytes = sizeInBytes, + rowCount = Some(rowCount), + // Newly computed column stats should override the existing ones. + colStats = catalogTable.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) + + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + + // Refresh the cached data source table in the catalog. + sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } +} +object AnalyzeColumnCommand extends Logging { + + /** + * Compute stats for the given columns. 
+ * @return (row count, map from column name to ColumnStats) + * + * This is visible for testing. + */ def computeColStats( sparkSession: SparkSession, - relation: LogicalPlan): (Long, Map[String, ColumnStat]) = { + relation: LogicalPlan, + columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { - // check correctness of column names - val attributesToAnalyze = mutable.MutableList[Attribute]() - val duplicatedColumns = mutable.MutableList[String]() + // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver - columnNames.foreach { col => + val attributesToAnalyze = AttributeSet(columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) - val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) - // do deduplication - if (!attributesToAnalyze.contains(expr)) { - attributesToAnalyze += expr - } else { - duplicatedColumns += col - } - } - if (duplicatedColumns.nonEmpty) { - logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " + - s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " + - s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.") - } + exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + }).toSeq // Collect statistics per column. // The first element in the result will be the overall row count, the following elements @@ -104,22 +106,21 @@ case class AnalyzeColumnCommand( // The layout of each struct follows the layout of the ColumnStats. val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr)) + attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) .queryExecution.toRdd.collect().head // unwrap the result + // TODO: Get rid of numFields by using the public Dataset API. val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - val numFields = ColumnStatStruct.numStatFields(expr.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType) (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) }.toMap (rowCount, columnStats) } -} -object ColumnStatStruct { private val zero = Literal(0, LongType) private val one = Literal(1, LongType) @@ -137,7 +138,11 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { + /** + * Creates a struct that groups the sequence of expressions together. This is used to create + * one top level struct per column. + */ + private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -161,6 +166,7 @@ object ColumnStatStruct { Seq(numNulls(e), numTrues(e), numFalses(e)) } + // TODO(rxin): Get rid of this function. 
def numStatFields(dataType: DataType): Int = { dataType match { case BinaryType | BooleanType => 3 @@ -168,14 +174,25 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { - // Use aggregate functions to compute statistics we need. - case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) - case StringType => getStruct(stringColumnStat(attr, relativeSD)) - case BinaryType => getStruct(binaryColumnStat(attr)) - case BooleanType => getStruct(booleanColumnStat(attr)) - case otherType => - throw new AnalysisException("Analyzing columns is not supported for column " + - s"${attr.name} of data type: ${attr.dataType}.") + /** + * Creates a struct expression that contains the statistics to collect for a column. + * + * @param attr column to collect statistics + * @param relativeSD relative error for approximate number of distinct values. + */ + def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = { + attr.dataType match { + case _: NumericType | TimestampType | DateType => + createStruct(numericColumnStat(attr, relativeSD)) + case StringType => + createStruct(stringColumnStat(attr, relativeSD)) + case BinaryType => + createStruct(binaryColumnStat(attr)) + case BooleanType => + createStruct(booleanColumnStat(attr)) + case otherType => + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${attr.name} of data type: ${attr.dataType}.") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index f1a201abd8da6..e866ac2cb3b34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -79,7 +79,7 @@ class StatisticsColumnSuite extends StatisticsTest { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val (_, columnStats) = - AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation) + AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze) assert(columnStats.contains(colName1)) assert(columnStats.contains(colName2)) // check deduplication diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index 5134ac0e7e5b3..915ee0d31bca2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct} +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ + trait StatisticsTest extends QueryTest with SharedSQLContext { def checkColStats( @@ -36,7 +37,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val (_, columnStats) = - AnalyzeColumnCommand(tableIdent, 
columns.map(_.name)).computeColStats(spark, relation) + AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name)) expectedColStatsSeq.foreach { case (field, expectedColStat) => assert(columnStats.contains(field.name)) val colStat = columnStats(field.name) @@ -48,7 +49,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { // check if we get the same colStat after encoding and decoding val encodedCS = colStat.toString - val numFields = ColumnStatStruct.numStatFields(field.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(field.dataType) val decodedCS = ColumnStat(numFields, encodedCS) StatisticsTest.checkColStat( dataType = field.dataType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index cacffcf33c263..5dbb4024bbee0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils} import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -634,7 +634,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { case f if colStatsProps.contains(f.name) => - val numFields = ColumnStatStruct.numStatFields(f.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(f.dataType) (f.name, ColumnStat(numFields, colStatsProps(f.name))) }.toMap tableWithSchema.copy( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 2bf9a26b0b7fc..daae8523c6366 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -97,7 +97,7 @@ private[hive] class HiveClientImpl( } // Create an internal session state for this HiveClientImpl. - val state = { + val state: SessionState = { val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) From b4bad04c5e20b06992100c1d44ece9d3a5b4f817 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 18 Nov 2016 16:34:38 -0800 Subject: [PATCH 146/534] [SPARK-18497][SS] Make ForeachSink support watermark ## What changes were proposed in this pull request? The issue in ForeachSink is the new created DataSet still uses the old QueryExecution. When `foreachPartition` is called, `QueryExecution.toString` will be called and then fail because it doesn't know how to plan EventTimeWatermark. This PR just replaces the QueryExecution with IncrementalExecution to fix the issue. ## How was this patch tested? `test("foreach with watermark")`. Author: Shixiong Zhu Closes #15934 from zsxwing/SPARK-18497. 
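For reference, the query shape that previously failed and is now supported is a watermarked aggregation written through `foreach`. The sketch below is illustrative only and mirrors the new test added in this patch; `events` is an assumed streaming Dataset with an `eventTime` column, and `spark.implicits._` is assumed to be in scope.

```
import org.apache.spark.sql.{ForeachWriter, Row}
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.streaming.OutputMode

// Watermarked windowed aggregation over the assumed `events` stream.
val counts = events
  .withWatermark("eventTime", "10 seconds")
  .groupBy(window($"eventTime", "5 seconds"))
  .agg(count("*"))

// Writing it through foreach used to fail when planning EventTimeWatermark; it now works.
val query = counts.writeStream
  .outputMode(OutputMode.Complete)
  .foreach(new ForeachWriter[Row] {
    override def open(partitionId: Long, version: Long): Boolean = true
    override def process(row: Row): Unit = println(row) // user-defined side effect
    override def close(errorOrNull: Throwable): Unit = ()
  })
  .start()
```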
(cherry picked from commit 2a40de408b5eb47edba92f9fe92a42ed1e78bf98) Signed-off-by: Tathagata Das --- .../sql/execution/streaming/ForeachSink.scala | 16 ++++----- .../streaming/ForeachSinkSuite.scala | 35 +++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index f5c550dd6ac3a..c93fcfb77cc93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -47,22 +47,22 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria // method supporting incremental planning. But in the long run, we should generally make newly // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to // resolve). - + val incrementalExecution = data.queryExecution.asInstanceOf[IncrementalExecution] val datasetWithIncrementalExecution = - new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) { + new Dataset(data.sparkSession, incrementalExecution, implicitly[Encoder[T]]) { override lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType val deserialized = CatalystSerde.deserialize[T](logicalPlan) // was originally: sparkSession.sessionState.executePlan(deserialized) ... - val incrementalExecution = new IncrementalExecution( + val newIncrementalExecution = new IncrementalExecution( this.sparkSession, deserialized, - data.queryExecution.asInstanceOf[IncrementalExecution].outputMode, - data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation, - data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId, - data.queryExecution.asInstanceOf[IncrementalExecution].currentEventTimeWatermark) - incrementalExecution.toRdd.mapPartitions { rows => + incrementalExecution.outputMode, + incrementalExecution.checkpointLocation, + incrementalExecution.currentBatchId, + incrementalExecution.currentEventTimeWatermark) + newIncrementalExecution.toRdd.mapPartitions { rows => rows.map(_.get(0, objectType)) }.asInstanceOf[RDD[T]] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 9e059216110f2..ee6261036fdd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext @@ -169,6 +170,40 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf assert(errorEvent.error.get.getMessage === "error") } } + + test("foreach with watermark") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"count".as[Long]) + .map(_.toInt) + .repartition(1) + + val query = 
windowedAggregation + .writeStream + .outputMode(OutputMode.Complete) + .foreach(new TestForeachWriter()) + .start() + try { + inputData.addData(10, 11, 12) + query.processAllAvailable() + + val allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 1) + val expectedEvents = Seq( + ForeachSinkSuite.Open(partition = 0, version = 0), + ForeachSinkSuite.Process(value = 3), + ForeachSinkSuite.Close(None) + ) + assert(allEvents === Seq(expectedEvents)) + } finally { + query.stop() + } + } } /** A global object to collect events in the executor */ From 693401be24bfefe5305038b87888cdeb641d7642 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 19 Nov 2016 09:00:11 +0000 Subject: [PATCH 147/534] [SPARK-18448][CORE] SparkSession should implement java.lang.AutoCloseable like JavaSparkContext ## What changes were proposed in this pull request? Just adds `close()` + `Closeable` as a synonym for `stop()`. This makes it usable in Java in try-with-resources, as suggested by ash211 (`Closeable` extends `AutoCloseable` BTW) ## How was this patch tested? Existing tests Author: Sean Owen Closes #15932 from srowen/SPARK-18448. (cherry picked from commit db9fb9baacbf8640dd37a507b7450db727c7e6ea) Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/sql/SparkSession.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 3045eb69f427f..58b2ab3957173 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.beans.Introspector +import java.io.Closeable import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ @@ -72,7 +73,7 @@ import org.apache.spark.util.Utils class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState]) - extends Serializable with Logging { self => + extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { this(sc, None) @@ -647,6 +648,13 @@ class SparkSession private( sparkContext.stop() } + /** + * Synonym for `stop()`. + * + * @since 2.2.0 + */ + override def close(): Unit = stop() + /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. From 4b396a6545ec0f1e31b0e211228f04bdc5660300 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 19 Nov 2016 11:24:15 +0000 Subject: [PATCH 148/534] [SPARK-18445][BUILD][DOCS] Fix the markdown for `Note:`/`NOTE:`/`Note that`/`'''Note:'''` across Scala/Java API documentation It seems in Scala/Java, - `Note:` - `NOTE:` - `Note that` - `'''Note:'''` - `note` This PR proposes to fix those to `note` to be consistent. 
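Concretely, this is the kind of rewrite being applied throughout (an illustrative snippet, not taken from the diff):

```
// Before: a plain "Note:" sentence, rendered as ordinary body text.
/**
 * Sorts the data.
 *
 * Note: this operation triggers a shuffle.
 */
// After: the @note tag, which Scaladoc renders as a distinct note block.
/**
 * Sorts the data.
 *
 * @note This operation triggers a shuffle.
 */
def sortExample(): Unit = ???
```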
**Before** - Scala ![2016-11-17 6 16 39](https://cloud.githubusercontent.com/assets/6477701/20383180/1a7aed8c-acf2-11e6-9611-5eaf6d52c2e0.png) - Java ![2016-11-17 6 14 41](https://cloud.githubusercontent.com/assets/6477701/20383096/c8ffc680-acf1-11e6-914a-33460bf1401d.png) **After** - Scala ![2016-11-17 6 16 44](https://cloud.githubusercontent.com/assets/6477701/20383167/09940490-acf2-11e6-937a-0d5e1dc2cadf.png) - Java ![2016-11-17 6 13 39](https://cloud.githubusercontent.com/assets/6477701/20383132/e7c2a57e-acf1-11e6-9c47-b849674d4d88.png) The notes were found via ```bash grep -r "NOTE: " . | \ # Note:|NOTE:|Note that|'''Note:''' grep -v "// NOTE: " | \ # starting with // does not appear in API documentation. grep -E '.scala|.java' | \ # java/scala files grep -v Suite | \ # exclude tests grep -v Test | \ # exclude tests grep -e 'org.apache.spark.api.java' \ # packages appear in API documenation -e 'org.apache.spark.api.java.function' \ # note that this is a regular expression. So actual matches were mostly `org/apache/spark/api/java/functions ...` -e 'org.apache.spark.api.r' \ ... ``` ```bash grep -r "Note that " . | \ # Note:|NOTE:|Note that|'''Note:''' grep -v "// Note that " | \ # starting with // does not appear in API documentation. grep -E '.scala|.java' | \ # java/scala files grep -v Suite | \ # exclude tests grep -v Test | \ # exclude tests grep -e 'org.apache.spark.api.java' \ # packages appear in API documenation -e 'org.apache.spark.api.java.function' \ -e 'org.apache.spark.api.r' \ ... ``` ```bash grep -r "Note: " . | \ # Note:|NOTE:|Note that|'''Note:''' grep -v "// Note: " | \ # starting with // does not appear in API documentation. grep -E '.scala|.java' | \ # java/scala files grep -v Suite | \ # exclude tests grep -v Test | \ # exclude tests grep -e 'org.apache.spark.api.java' \ # packages appear in API documenation -e 'org.apache.spark.api.java.function' \ -e 'org.apache.spark.api.r' \ ... ``` ```bash grep -r "'''Note:'''" . | \ # Note:|NOTE:|Note that|'''Note:''' grep -v "// '''Note:''' " | \ # starting with // does not appear in API documentation. grep -E '.scala|.java' | \ # java/scala files grep -v Suite | \ # exclude tests grep -v Test | \ # exclude tests grep -e 'org.apache.spark.api.java' \ # packages appear in API documenation -e 'org.apache.spark.api.java.function' \ -e 'org.apache.spark.api.r' \ ... ``` And then fixed one by one comparing with API documentation/access modifiers. After that, manually tested via `jekyll build`. Author: hyukjinkwon Closes #15889 from HyukjinKwon/SPARK-18437. 
(cherry picked from commit d5b1d5fc80153571c308130833d0c0774de62c92) Signed-off-by: Sean Owen --- .../org/apache/spark/ContextCleaner.scala | 2 +- .../scala/org/apache/spark/Partitioner.scala | 2 +- .../scala/org/apache/spark/SparkConf.scala | 6 +- .../scala/org/apache/spark/SparkContext.scala | 47 ++++++++------- .../apache/spark/api/java/JavaDoubleRDD.scala | 4 +- .../apache/spark/api/java/JavaPairRDD.scala | 26 ++++---- .../org/apache/spark/api/java/JavaRDD.scala | 12 ++-- .../apache/spark/api/java/JavaRDDLike.scala | 3 +- .../spark/api/java/JavaSparkContext.scala | 21 +++---- .../api/java/JavaSparkStatusTracker.scala | 2 +- .../apache/spark/io/CompressionCodec.scala | 23 ++++--- .../apache/spark/partial/BoundedDouble.scala | 2 +- .../org/apache/spark/rdd/CoGroupedRDD.scala | 8 +-- .../apache/spark/rdd/DoubleRDDFunctions.scala | 2 +- .../org/apache/spark/rdd/HadoopRDD.scala | 6 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 6 +- .../apache/spark/rdd/PairRDDFunctions.scala | 23 +++---- .../spark/rdd/PartitionPruningRDD.scala | 2 +- .../spark/rdd/PartitionwiseSampledRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 46 +++++++------- .../apache/spark/rdd/RDDCheckpointData.scala | 2 +- .../spark/rdd/ReliableCheckpointRDD.scala | 2 +- .../spark/rdd/SequenceFileRDDFunctions.scala | 5 +- .../apache/spark/rdd/ZippedWithIndexRDD.scala | 2 +- .../spark/scheduler/AccumulableInfo.scala | 10 ++-- .../spark/serializer/JavaSerializer.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 2 +- .../apache/spark/serializer/Serializer.scala | 2 +- .../apache/spark/storage/StorageUtils.scala | 19 +++--- .../org/apache/spark/util/AccumulatorV2.scala | 5 +- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- docs/mllib-isotonic-regression.md | 2 +- docs/streaming-programming-guide.md | 2 +- .../spark/sql/kafka010/KafkaSource.scala | 2 +- .../spark/streaming/kafka/KafkaUtils.scala | 8 +-- .../streaming/kinesis/KinesisUtils.scala | 60 +++++++++---------- .../kinesis/KinesisBackedBlockRDDSuite.scala | 2 +- .../apache/spark/graphx/impl/GraphImpl.scala | 2 +- .../apache/spark/graphx/lib/PageRank.scala | 2 +- .../org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../scala/org/apache/spark/ml/Model.scala | 2 +- .../DecisionTreeClassifier.scala | 6 +- .../ml/classification/GBTClassifier.scala | 6 +- .../classification/LogisticRegression.scala | 36 +++++------ .../spark/ml/clustering/GaussianMixture.scala | 6 +- .../spark/ml/feature/MinMaxScaler.scala | 3 +- .../spark/ml/feature/OneHotEncoder.scala | 3 +- .../org/apache/spark/ml/feature/PCA.scala | 5 +- .../spark/ml/feature/StopWordsRemover.scala | 5 +- .../spark/ml/feature/StringIndexer.scala | 6 +- .../org/apache/spark/ml/param/params.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 6 +- .../GeneralizedLinearRegression.scala | 4 +- .../ml/regression/LinearRegression.scala | 28 +++++---- .../ml/source/libsvm/LibSVMDataSource.scala | 2 +- .../ml/tree/impl/GradientBoostedTrees.scala | 4 +- .../org/apache/spark/ml/util/ReadWrite.scala | 2 +- .../classification/LogisticRegression.scala | 28 +++++---- .../spark/mllib/classification/SVM.scala | 20 ++++--- .../mllib/clustering/GaussianMixture.scala | 8 +-- .../spark/mllib/clustering/KMeans.scala | 8 ++- .../apache/spark/mllib/clustering/LDA.scala | 4 +- .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 6 +- .../mllib/evaluation/AreaUnderCurve.scala | 2 +- .../apache/spark/mllib/linalg/Vectors.scala | 6 +- .../linalg/distributed/BlockMatrix.scala | 2 
+- .../linalg/distributed/IndexedRowMatrix.scala | 5 +- .../mllib/linalg/distributed/RowMatrix.scala | 21 ++++--- .../spark/mllib/optimization/Gradient.scala | 3 +- .../apache/spark/mllib/rdd/RDDFunctions.scala | 2 +- .../MatrixFactorizationModel.scala | 6 +- .../apache/spark/mllib/stat/Statistics.scala | 34 +++++------ .../spark/mllib/tree/DecisionTree.scala | 32 +++++----- .../apache/spark/mllib/tree/loss/Loss.scala | 12 ++-- .../mllib/tree/model/treeEnsembleModels.scala | 4 +- pom.xml | 7 +++ project/SparkBuild.scala | 3 +- python/pyspark/mllib/stat/KernelDensity.py | 2 +- python/pyspark/mllib/util.py | 2 +- python/pyspark/rdd.py | 4 +- python/pyspark/streaming/kafka.py | 4 +- .../scala/org/apache/spark/sql/Encoders.scala | 8 +-- .../sql/types/CalendarIntervalType.scala | 4 +- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../spark/sql/DataFrameStatFunctions.scala | 3 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 56 ++++++++--------- .../org/apache/spark/sql/SQLContext.scala | 7 ++- .../org/apache/spark/sql/SparkSession.scala | 9 +-- .../apache/spark/sql/UDFRegistration.scala | 3 +- .../execution/streaming/state/package.scala | 4 +- .../sql/expressions/UserDefinedFunction.scala | 8 ++- .../org/apache/spark/sql/functions.scala | 22 +++---- .../apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- .../apache/spark/sql/sources/interfaces.scala | 10 ++-- .../sql/util/QueryExecutionListener.scala | 8 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../spark/streaming/StreamingContext.scala | 18 +++--- .../streaming/api/java/JavaPairDStream.scala | 2 +- .../api/java/JavaStreamingContext.scala | 40 +++++++------ .../spark/streaming/dstream/DStream.scala | 4 +- .../dstream/MapWithStateDStream.scala | 2 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 2 +- 104 files changed, 516 insertions(+), 435 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5678d790e9e76..af913454fce69 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -139,7 +139,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { periodicGCService.shutdown() } - /** Register a RDD for cleanup when it is garbage collected. */ + /** Register an RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 93dfbc0e6ed65..f83f5278e8b8f 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -101,7 +101,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. * - * Note that the actual number of partitions created by the RangePartitioner might not be the same + * @note The actual number of partitions created by the RangePartitioner might not be the same * as the `partitions` parameter, in the case where the number of sampled records is less than * the value of `partitions`. 
*/ diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c9c342df82c97..04d657c09afd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -42,10 +42,10 @@ import org.apache.spark.util.Utils * All setter methods in this class support chaining. For example, you can write * `new SparkConf().setMaster("local").setAppName("My app")`. * - * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified - * by the user. Spark does not support modifying the configuration at runtime. - * * @param loadDefaults whether to also load values from Java system properties + * + * @note Once a SparkConf object is passed to Spark, it is cloned and can no longer be modified + * by the user. Spark does not support modifying the configuration at runtime. */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 25a3d609a6b09..1261e3e735761 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -281,7 +281,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration: Configuration = _hadoopConfiguration @@ -700,7 +700,7 @@ class SparkContext(config: SparkConf) extends Logging { * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. * - * Note: Return statements are NOT allowed in the given body. + * @note Return statements are NOT allowed in the given body. */ private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) @@ -927,7 +927,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Load data from a flat binary file, assuming the length of each record is constant. * - * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * @note We ensure that the byte array for each record in the resulting RDD * has the provided record length. * * @param path Directory to the input data files, the path can be comma separated paths as the @@ -970,7 +970,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. 
* If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -995,7 +995,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1034,7 +1034,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minPartitions) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1058,7 +1058,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1084,7 +1084,7 @@ class SparkContext(config: SparkConf) extends Logging { * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1124,7 +1124,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1150,7 +1150,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. 
* - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1169,7 +1169,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1199,7 +1199,7 @@ class SparkContext(config: SparkConf) extends Logging { * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1330,16 +1330,18 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Register the given accumulator. Note that accumulators must be registered before use, or it - * will throw exception. + * Register the given accumulator. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _]): Unit = { acc.register(this) } /** - * Register the given accumulator with given name. Note that accumulators must be registered - * before use, or it will throw exception. + * Register the given accumulator with given name. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { acc.register(this, name = Some(name)) @@ -1550,7 +1552,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1572,7 +1574,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executor. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. 
If the application wishes to replace the executor it kills * through this method with a new one, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1590,7 +1592,7 @@ class SparkContext(config: SparkConf) extends Logging { * this request. This assumes the cluster manager will automatically and eventually * fulfill all missing application resource requests. * - * Note: The replace is by no means guaranteed; another application on the same cluster + * @note The replace is by no means guaranteed; another application on the same cluster * can steal the window of opportunity and acquire this application's resources in the * mean time. * @@ -1639,7 +1641,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap @@ -2298,7 +2301,7 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -2323,7 +2326,7 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(): SparkContext = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 0026fc9dad517..a32a4b28c1731 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -153,7 +153,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) @@ -256,7 +256,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. 
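For reference, the evenly-spaced-bucket behaviour described in the histogram notes above can be exercised as follows. This is a minimal sketch, assuming `sc` is an existing SparkContext:

```scala
// Minimal sketch, assuming `sc` is an existing SparkContext.
val data = sc.parallelize(Seq(1.0, 9.5, 15.0, 25.0, 49.0))

// Five buckets: [0,10), [10,20), [20,30), [30,40), [40,50].
// evenBuckets = true enables the O(1) per-element bucket lookup noted above.
val counts: Array[Long] =
  data.histogram(Array(0.0, 10.0, 20.0, 30.0, 40.0, 50.0), evenBuckets = true)
// counts: Array(2, 1, 1, 0, 1)
```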
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 1c95bc4bfcaaf..bff5a29bb60f1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -206,7 +206,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.intersection(other.rdd)) @@ -223,9 +223,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -234,6 +234,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple * items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -255,9 +258,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -265,6 +268,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -398,7 +404,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. 
* - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ @@ -409,7 +415,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ @@ -539,7 +545,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index d67cff64e6e46..ccd94f876e0b8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -99,27 +99,29 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD with a random seed. - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** * Return a sampled subset of this RDD, with a user-supplied seed. - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -157,7 +159,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return the intersection of this RDD and another one. 
The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index a37c52cbaf210..eda16d957cc58 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -47,7 +47,8 @@ private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This /** * Defines operations common to several Java RDD implementations. - * Note that this trait is not intended to be implemented by user code. + * + * @note This trait is not intended to be implemented by user code. */ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4e50c2686dd53..38d347aeab8c6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -298,7 +298,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -316,7 +316,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -366,7 +366,7 @@ class JavaSparkContext(val sc: SparkContext) * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -396,7 +396,7 @@ class JavaSparkContext(val sc: SparkContext) * @param keyClass Class of the keys * @param valueClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. 
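A minimal sketch of the copy-before-caching guidance in the notes above, assuming `sc` is an existing SparkContext and using a placeholder path and Writable types:

```scala
import org.apache.hadoop.io.{IntWritable, Text}

// The RecordReader re-uses the same Writable instances for every record,
// so copy them to plain Scala values with a map before caching.
val raw = sc.sequenceFile("hdfs:///tmp/example-seqfile", classOf[IntWritable], classOf[Text])
val copied = raw.map { case (k, v) => (k.get, v.toString) }
copied.cache()
```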
@@ -416,7 +416,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -437,7 +437,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -458,7 +458,7 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -487,7 +487,7 @@ class JavaSparkContext(val sc: SparkContext) * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -694,7 +694,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { @@ -811,7 +811,8 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns a Java map of JavaRDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. 
*/ def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { sc.getPersistentRDDs.mapValues(s => JavaRDD.fromRDD(s)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 99ca3c77cced0..6aa290ecd7bb5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} * will provide information for the last `spark.ui.retainedStages` stages and * `spark.ui.retainedJobs` jobs. * - * NOTE: this class's constructor should be considered private and may be subject to change. + * @note This class's constructor should be considered private and may be subject to change. */ class JavaSparkStatusTracker private[spark] (sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ae014becef755..6ba79e506a648 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -32,9 +32,8 @@ import org.apache.spark.util.Utils * CompressionCodec allows the customization of choosing different compression implementations * to be used in block storage. * - * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark. - * This is intended for use as an internal compression utility within a single - * Spark application. + * @note The wire protocol for a codec is not guaranteed compatible across versions of Spark. + * This is intended for use as an internal compression utility within a single Spark application. */ @DeveloperApi trait CompressionCodec { @@ -103,9 +102,9 @@ private[spark] object CompressionCodec { * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.lz4.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -123,9 +122,9 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { * :: DeveloperApi :: * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -143,9 +142,9 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.snappy.blockSize`. 
* - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index ab6aba6fc7d6a..8f579c5a3033c 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,7 +28,7 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false + * @note Consistent with Double, any NaN value will make equality false */ override def equals(that: Any): Boolean = that match { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 2381f54ee3f06..a091f06b4ed7c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -66,14 +66,14 @@ private[spark] class CoGroupPartition( /** * :: DeveloperApi :: - * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a + * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. * - * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of - * instantiating this directly. - * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output + * + * @note This is an internal API. We recommend users use RDD.cogroup(...) instead of + * instantiating this directly. */ @DeveloperApi class CoGroupedRDD[K: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a05a770b40c57..f3ab324d59119 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -158,7 +158,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. 
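A minimal sketch of the CoGroupedRDD recommendation above (use `RDD.cogroup` rather than instantiating the RDD directly), assuming `sc` is an existing SparkContext:

```scala
// Minimal sketch, assuming `sc` is an existing SparkContext.
val left  = sc.parallelize(Seq(("a", 1), ("b", 2)))
val right = sc.parallelize(Seq(("a", "x"), ("c", "y")))

// cogroup handles the shuffle and grouping internally; there is no need to
// construct a CoGroupedRDD by hand.
val grouped = left.cogroup(right) // RDD[(String, (Iterable[Int], Iterable[String]))]
```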
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 36a2f5c87e372..86351b8c575e5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -84,9 +84,6 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.hadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. @@ -97,6 +94,9 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate. + * + * @note Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.hadoopRDD()]] */ @DeveloperApi class HadoopRDD[K, V]( diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 488e777fea371..a5965f597038d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -57,13 +57,13 @@ private[spark] class NewHadoopPartition( * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. + * + * @note Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] */ @DeveloperApi class NewHadoopRDD[K, V]( diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 67baad1c51bca..9ed0f3d8086a5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -59,8 +59,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). 
Users provide three functions: + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -68,6 +68,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). */ @Experimental def combineByKeyWithClassTag[C]( @@ -363,7 +366,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Count the number of elements for each key, collecting the results to a local Map. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -490,11 +493,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * The ordering of elements within each group is not guaranteed, and may even differ * each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { @@ -514,11 +517,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. The ordering of elements within * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { @@ -635,7 +638,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * within each group is not guaranteed, and may even differ each time the resulting RDD is * evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. 
If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -1016,7 +1019,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1070,7 +1073,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 0c6ddda52cee9..ce75a16031a3f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -48,7 +48,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => /** * :: DeveloperApi :: - * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on + * An RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index 3b1acacf409b9..6a89ea8786464 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -32,7 +32,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) } /** - * A RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, + * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain * a random sample of the records in the partition. The random seeds assigned to the samplers * are guaranteed to have different values. 
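A minimal sketch of the groupByKey guidance above, assuming `sc` is an existing SparkContext: for a per-key sum, `reduceByKey` combines values map-side and avoids materializing every value for a key.

```scala
// Minimal sketch, assuming `sc` is an existing SparkContext.
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

// Preferred: reduceByKey combines values map-side before the shuffle.
val sums = pairs.reduceByKey(_ + _) // ("a", 4), ("b", 2)

// Works, but ships every value for each key across the network first.
val sumsViaGroup = pairs.groupByKey().mapValues(_.sum)
```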
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cded899db1f5c..bff2b8f1d06c9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -428,7 +428,7 @@ abstract class RDD[T: ClassTag]( * current upstream partitions will be executed in parallel (per whatever * the current partitioning is). * - * Note: With shuffle = true, you can actually coalesce to a larger number + * @note With shuffle = true, you can actually coalesce to a larger number * of partitions. This is useful if you have a small number of partitions, * say 100, potentially with a few partitions being abnormally large. Calling * coalesce(1000, shuffle = true) will result in 1000 partitions with the @@ -466,14 +466,14 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. * - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. - * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample( withReplacement: Boolean, @@ -537,13 +537,13 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * - * @note this method should only be used if the resulting array is expected to be small, as - * all the data is loaded into the driver's memory. - * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator * @return sample of specified size in an array + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeSample( withReplacement: Boolean, @@ -618,7 +618,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) @@ -630,7 +630,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param partitioner Partitioner to use for the resulting RDD */ @@ -646,7 +646,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. Performs a hash partition across the cluster * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param numPartitions How many partitions to use in the resulting RDD */ @@ -674,7 +674,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. 
The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -687,7 +687,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -702,7 +702,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -921,7 +921,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { @@ -934,7 +934,7 @@ abstract class RDD[T: ClassTag]( * * The iterator will consume as much memory as the largest partition in this RDD. * - * Note: this results in multiple Spark jobs, and if the input RDD is the result + * @note This results in multiple Spark jobs, and if the input RDD is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input RDD should be cached first. */ @@ -1182,7 +1182,7 @@ abstract class RDD[T: ClassTag]( /** * Return the count of each unique value in this RDD as a local map of (value, count) pairs. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -1272,7 +1272,7 @@ abstract class RDD[T: ClassTag]( * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The index assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. 
If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1286,7 +1286,7 @@ abstract class RDD[T: ClassTag]( * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The unique ID assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1305,10 +1305,10 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * - * @note due to complications in the internal implementation, this method will raise + * @note Due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = withScope { @@ -1370,7 +1370,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of top elements to return @@ -1393,7 +1393,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of elements to return @@ -1438,7 +1438,7 @@ abstract class RDD[T: ClassTag]( } /** - * @note due to complications in the internal implementation, this method will raise an + * @note Due to complications in the internal implementation, this method will raise an * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 429514b4f6bee..1070bb96b2524 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -32,7 +32,7 @@ private[spark] object CheckpointState extends Enumeration { /** * This class contains all the information related to RDD checkpointing. Each instance of this - * class is associated with a RDD. It manages process of checkpointing of the associated RDD, + * class is associated with an RDD. 
It manages process of checkpointing of the associated RDD, * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index eac901d10067c..7f399ecf81a08 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -151,7 +151,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { } /** - * Write a RDD partition's data to a checkpoint file. + * Write an RDD partition's data to a checkpoint file. */ def writePartitionToCheckpointFile[T: ClassTag]( path: String, diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 1311b481c7c71..86a332790fb00 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -27,9 +27,10 @@ import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. + * through an implicit conversion. * + * @note This can't be part of PairRDDFunctions because we need more implicit parameters to + * convert our keys and values to Writable. */ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)], diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b0e5ba0865c63..8425b211d6ecf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -29,7 +29,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) } /** - * Represents a RDD zipped with its element indices. The ordering is first based on the partition + * Represents an RDD zipped with its element indices. The ordering is first based on the partition * index and then the ordering of items within each partition. So the first item in the first * partition gets index 0, and the last item in the last partition receives the largest index. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index cedacad44afec..0a5fe5a1d3ee1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. * - * Note: once this is JSON serialized the types of `update` and `value` will be lost and be - * cast to strings. This is because the user can define an accumulator of any type and it will - * be difficult to preserve the type in consumers of the event log. This does not apply to - * internal accumulators that represent task level metrics. 
- * * @param id accumulator ID * @param name accumulator name * @param update partial value from a task, may be None if used on driver to describe a stage @@ -36,6 +31,11 @@ import org.apache.spark.annotation.DeveloperApi * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed * @param metadata internal metadata associated with this accumulator, if any + * + * @note Once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. */ @DeveloperApi case class AccumulableInfo private[spark] ( diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 8b72da2ee01b7..f60dcfddfdc20 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -131,7 +131,7 @@ private[spark] class JavaSerializerInstance( * :: DeveloperApi :: * A Spark serializer that uses Java's built-in serialization. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0d26281fe1076..19e020c968a9a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.collection.CompactBuffer /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index cb95246d5b0ca..afe6cd86059f0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.NextIterator * * 2. Java serialization interface. * - * Note that serializers are not required to be wire-compatible across different versions of Spark. + * @note Serializers are not required to be wire-compatible across different versions of Spark. * They are intended to be used to serialize/de-serialize data within a single Spark application. 
*/ @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index fb9941bbd9e0f..e12f2e6095d5a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -71,7 +71,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * contains, get, and size. */ @@ -80,7 +80,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the RDD blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * getting the memory, disk, and off-heap memory sizes occupied by this RDD. */ @@ -128,7 +128,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return whether the given block is stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. */ def containsBlock(blockId: BlockId): Boolean = { blockId match { @@ -141,7 +142,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the given block stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.get`, which is O(blocks) time. */ def getBlock(blockId: BlockId): Option[BlockStatus] = { blockId match { @@ -154,19 +156,22 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the number of blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.size`, which is O(blocks) time. */ def numBlocks: Int = _nonRddBlocks.size + numRddBlocks /** * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. + * + * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. */ def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum /** * Return the number of blocks that belong to the given RDD in O(1) time. - * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * + * @note This is much faster than `this.rddBlocksById(rddId).size`, which is * O(blocks in this RDD) time. 
*/ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index d3ddd39131326..1326f0977c241 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -59,8 +59,9 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } /** - * Returns true if this accumulator has been registered. Note that all accumulators must be - * registered before use, or it will throw exception. + * Returns true if this accumulator has been registered. + * + * @note All accumulators must be registered before use, or it will throw exception. */ final def isRegistered: Boolean = metadata != null && AccumulatorContext.get(metadata.id).isDefined diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bec95d13d193a..5e8a854e46a0f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2076,7 +2076,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } /** - * Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that + * Checks the DAGScheduler's internal logic for traversing an RDD DAG by making sure that * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s * denotes a shuffle dependency): diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index d90905a86ade9..ca84551506b2b 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -27,7 +27,7 @@ best fitting the original data points. [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). -The training input is a RDD of tuples of three double values that represent +The training input is an RDD of tuples of three double values that represent label, feature and weight in this order. Additionally IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 0b0315b366501..18fc1cd934826 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2191,7 +2191,7 @@ consistent batch processing times. Make sure you set the CMS GC on both the driv - When data is received from a stream source, receiver creates blocks of data. A new block of data is generated every blockInterval milliseconds. N blocks of data are created during the batchInterval where N = batchInterval/blockInterval. These blocks are distributed by the BlockManager of the current executor to the block managers of other executors. After that, the Network Input Tracker running on the driver is informed about the block locations for further processing. -- A RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in spark. 
blockInterval== batchinterval would mean that a single partition is created and probably it is processed locally. +- An RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in spark. blockInterval== batchinterval would mean that a single partition is created and probably it is processed locally. - The map tasks on the blocks are processed in the executors (one that received the block, and another where the block was replicated) that has the blocks irrespective of block interval, unless non-local scheduling kicks in. Having bigger blockinterval means bigger blocks. A high value of `spark.locality.wait` increases the chance of processing a block on the local node. A balance needs to be found out between these two parameters to ensure that the bigger blocks are processed locally. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 5bcc5124b0915..341081a338c0e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -279,7 +279,7 @@ private[kafka010] case class KafkaSource( } }.toArray - // Create a RDD that reads from Kafka and get the (key, value) pair as byte arrays. + // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. val rdd = new KafkaSourceRDD( sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index b17e198077949..56f0cb0b166a2 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -223,7 +223,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. + * Create an RDD from Kafka using offset ranges for each topic and partition. * * @param sc SparkContext object * @param kafkaParams Kafka @@ -255,7 +255,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. * @@ -303,7 +303,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. + * Create an RDD from Kafka using offset ranges for each topic and partition. * * @param jsc JavaSparkContext object * @param kafkaParams Kafka @@ -340,7 +340,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. 
* diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index a0007d33d6257..b2daffa34ccbf 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -33,10 +33,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -57,6 +53,10 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param messageHandler A custom message handler that can generate a generic output from a * Kinesis `Record`, which contains both message data, and metadata. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream[T: ClassTag]( ssc: StreamingContext, @@ -81,10 +81,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -107,6 +103,9 @@ object KinesisUtils { * Kinesis `Record`, which contains both message data, and metadata. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off def createStream[T: ClassTag]( @@ -134,10 +133,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -156,6 +151,10 @@ object KinesisUtils { * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. 
*/ def createStream( ssc: StreamingContext, @@ -178,10 +177,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -202,6 +197,9 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ def createStream( ssc: StreamingContext, @@ -225,10 +223,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -250,6 +244,10 @@ object KinesisUtils { * @param messageHandler A custom message handler that can generate a generic output from a * Kinesis `Record`, which contains both message data, and metadata. * @param recordClass Class of the records in DStream + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream[T]( jssc: JavaStreamingContext, @@ -272,10 +270,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -299,6 +293,9 @@ object KinesisUtils { * @param recordClass Class of the records in DStream * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off def createStream[T]( @@ -326,10 +323,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. 
- * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -348,6 +341,10 @@ object KinesisUtils { * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream( jssc: JavaStreamingContext, @@ -367,10 +364,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -391,6 +384,9 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ def createStream( jssc: JavaStreamingContext, diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 905c33834df16..a4d81a680979e 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -221,7 +221,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) assert(collectedData.toSet === testData.toSet) // Verify that the block fetching is skipped when isBlockValid is set to false. - // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // This is done by using an RDD whose data is only in memory but is set to skip block fetching // Using that RDD will throw exception, as it skips block fetching even if the blocks are in // in BlockManager. if (testIsBlockValid) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index e18831382d4d5..3810110099993 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -42,7 +42,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges - /** Return a RDD that brings edges together with their source and destination vertices. */ + /** Return an RDD that brings edges together with their source and destination vertices. 
*/ @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { replicatedVertexView.upgrade(vertices, true, true) replicatedVertexView.edges.partitionsRDD.mapPartitions(_.flatMap { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index c0c3c73463aab..f926984aa6335 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -58,7 +58,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`. * - * Note that this is not the "normalized" PageRank and as a consequence pages that have no + * @note This is not the "normalized" PageRank and as a consequence pages that have no * inlinks will have a PageRank of alpha. */ object PageRank extends Logging { diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 2e4a58dc6291c..22e4ec693b1f7 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -30,7 +30,7 @@ import org.apache.spark.annotation.Since /** * Represents a numeric vector, whose index type is Int and value type is Double. * - * Note: Users should not implement this interface. + * @note Users should not implement this interface. */ @Since("2.0.0") sealed trait Vector extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 252acc156583f..c581fed177273 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.param.ParamMap abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. - * Note: For ensembles' component Models, this value can be null. + * @note For ensembles' component Models, this value can be null. */ @transient var parent: Estimator[M] = _ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index bb192ab5f25ab..7424031ed4608 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -207,9 +207,9 @@ class DecisionTreeClassificationModel private[ml] ( * where gain is scaled by the number of instances passing through node * - Normalize importances for tree to sum to 1. * - * Note: Feature importance for single decision trees can have high variance due to - * correlated predictor variables. Consider using a [[RandomForestClassifier]] - * to determine feature importance instead. + * @note Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestClassifier]] + * to determine feature importance instead. 
*/ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f8f164e8c14bd..52f93f5a6b345 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -43,7 +43,6 @@ import org.apache.spark.sql.types.DoubleType * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. - * Note: Multiclass labels are not currently supported. * * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. * @@ -54,6 +53,8 @@ import org.apache.spark.sql.types.DoubleType * based on the loss function, whereas the original gradient boosting method does not. * - We expect to implement TreeBoost in the future: * [https://issues.apache.org/jira/browse/SPARK-4240] + * + * @note Multiclass labels are not currently supported. */ @Since("1.4.0") class GBTClassifier @Since("1.4.0") ( @@ -169,10 +170,11 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * model for classification. * It supports binary labels, as well as both continuous and categorical features. - * Note: Multiclass labels are not currently supported. * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. + * + * @note Multiclass labels are not currently supported. */ @Since("1.6.0") class GBTClassificationModel private[ml]( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 18b9b3043db8a..71a7fe53c15f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1191,8 +1191,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") @@ -1200,8 +1200,8 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Computes the area under the receiver operating characteristic (ROC) curve. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. 
*/ @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() @@ -1210,8 +1210,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the precision-recall curve, which is a Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") @@ -1219,8 +1219,8 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { @@ -1232,8 +1232,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { @@ -1245,8 +1245,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { @@ -1401,18 +1401,18 @@ class BinaryLogisticRegressionSummary private[classification] ( * $$ *

* - * @note In order to avoid unnecessary computation during calculation of the gradient updates - * we lay out the coefficients in column major order during training. This allows us to - * perform feature standardization once, while still retaining sequential memory access - * for speed. We convert back to row major order when we create the model, - * since this form is optimal for the matrix operations used for prediction. - * * @param bcCoefficients The broadcast coefficients corresponding to the features. * @param bcFeaturesStd The broadcast standard deviation values of the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. * @param multinomial Whether to use multinomial (softmax) or binary loss + * + * @note In order to avoid unnecessary computation during calculation of the gradient updates + * we lay out the coefficients in column major order during training. This allows us to + * perform feature standardization once, while still retaining sequential memory access + * for speed. We convert back to row major order when we create the model, + * since this form is optimal for the matrix operations used for prediction. */ private class LogisticAggregator( bcCoefficients: Broadcast[Vector], diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index a0bd66e731a1d..c6035cc4c9647 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -268,9 +268,9 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. * - * Note: For high-dimensional data (with many features), this algorithm may perform poorly. - * This is due to high-dimensional data (a) making it difficult to cluster at all (based - * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. + * @note For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. */ @Since("2.0.0") @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 28cbe1cb01e9a..ccfb0ce8f85ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -85,7 +85,8 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H *

* * For the case $E_{max} == E_{min}$, $Rescaled(e_i) = 0.5 * (max + min)$. - * Note that since zero values will probably be transformed to non-zero values, output of the + * + * @note Since zero values will probably be transformed to non-zero values, output of the * transformer will be DenseVector even for sparse input. */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index e8e28ba29c841..ea401216aec7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -36,7 +36,8 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] * because it makes the vector entries sum up to one, and hence linearly dependent. * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * + * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. * The output vectors are sparse. * * @see [[StringIndexer]] for converting categorical values into category indices diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 1e49352b8517e..6e08bf059124c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -142,8 +142,9 @@ class PCAModel private[ml] ( /** * Transform a vector by computed Principal Components. - * NOTE: Vectors to be transformed must be the same length - * as the source vectors given to [[PCA.fit()]]. + * + * @note Vectors to be transformed must be the same length as the source vectors given + * to [[PCA.fit()]]. */ @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 666070037cdd8..0ced21365ff6f 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -28,7 +28,10 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructType} /** * A feature transformer that filters out stop words from input. - * Note: null values from input array are preserved unless adding null to stopWords explicitly. + * + * @note null values from input array are preserved unless adding null to stopWords + * explicitly. + * * @see [[http://en.wikipedia.org/wiki/Stop_words]] */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 80fe46796f807..8b155f00017cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -113,11 +113,11 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { /** * Model fitted by [[StringIndexer]]. * - * NOTE: During transformation, if the input column does not exist, + * @param labels Ordered list of labels, corresponding to indices to be assigned. 
+ * + * @note During transformation, if the input column does not exist, * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. - * - * @param labels Ordered list of labels, corresponding to indices to be assigned. */ @Since("1.4.0") class StringIndexerModel ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9245931b27ca6..96206e0b7ad88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -533,7 +533,7 @@ trait Params extends Identifiable with Serializable { * Returns all params sorted by their names. The default implementation uses Java reflection to * list all public methods that have no arguments and return [[Param]]. * - * Note: Developer should not use this method in constructor because we cannot guarantee that + * @note Developer should not use this method in constructor because we cannot guarantee that * this variable gets initialized before other params. */ lazy val params: Array[Param[_]] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index ebc6c12ddcf92..1419da874709f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -207,9 +207,9 @@ class DecisionTreeRegressionModel private[ml] ( * where gain is scaled by the number of instances passing through node * - Normalize importances for tree to sum to 1. * - * Note: Feature importance for single decision trees can have high variance due to - * correlated predictor variables. Consider using a [[RandomForestRegressor]] - * to determine feature importance instead. + * @note Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestRegressor]] + * to determine feature importance instead. */ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 1d2961e0277f5..736fd3b9e0f64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -879,8 +879,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( * Private copy of model to ensure Params are not modified outside this class. * Coefficients is not a deep copy, but that is acceptable. * - * NOTE: [[predictionCol]] must be set correctly before the value of [[model]] is set, - * and [[model]] must be set before [[predictions]] is set! + * @note [[predictionCol]] must be set correctly before the value of [[model]] is set, + * and [[model]] must be set before [[predictions]] is set! 
*/ protected val model: GeneralizedLinearRegressionModel = origModel.copy(ParamMap.empty).setPredictionCol(predictionCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 71c542adf6f6f..da7ce6b46f2ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -103,11 +103,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Whether to standardize the training features before fitting the model. * The coefficients of models will be always returned on the original scale, - * so it will be transparent for users. Note that with/without standardization, - * the models should be always converged to the same solution when no regularization - * is applied. In R's GLMNET package, the default behavior is true as well. + * so it will be transparent for users. * Default is true. * + * @note With/without standardization, the models should be always converged + * to the same solution when no regularization is applied. In R's GLMNET package, + * the default behavior is true as well. + * * @group setParam */ @Since("1.5.0") @@ -624,8 +626,8 @@ class LinearRegressionSummary private[regression] ( * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -634,8 +636,8 @@ class LinearRegressionSummary private[regression] ( * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -644,8 +646,8 @@ class LinearRegressionSummary private[regression] ( * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -654,8 +656,8 @@ class LinearRegressionSummary private[regression] ( * Returns the root mean squared error, which is defined as the square root of * the mean squared error. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. 
*/ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -664,8 +666,8 @@ class LinearRegressionSummary private[regression] ( * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala index 73d813064decb..e1376927030e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, DataFrameReader} * inconsistent feature dimensions. * - "vectorType": feature vector type, "sparse" (default) or "dense". * - * Note that this class is public for documentation purpose. Please don't use this class directly. + * @note This class is public for documentation purpose. Please don't use this class directly. * Rather, use the data source API as illustrated above. * * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index ede0a060eef95..0a0bc4c006389 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -98,7 +98,7 @@ private[spark] object GradientBoostedTrees extends Logging { * @param initTreeWeight: learning rate assigned to the first tree. * @param initTree: first DecisionTreeModel. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to every sample. */ def computeInitialPredictionAndError( @@ -121,7 +121,7 @@ private[spark] object GradientBoostedTrees extends Logging { * @param treeWeight: Learning rate. * @param tree: Tree using which the prediction and error should be updated. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to each sample. */ def updatePredictionError( diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index bc4f9e6716ee8..e5fa5d53e3fca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -221,7 +221,7 @@ trait MLReadable[T] { /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. * - * Note: Implementing classes should override this to be Java-friendly. + * @note Implementing classes should override this to be Java-friendly. 
*/ @Since("1.6.0") def load(path: String): T = read.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index d851b983349c9..4b650000736e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -202,9 +202,11 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { * Train a classification model for Binary Logistic Regression * using Stochastic Gradient Descent. By default L2 regularization is used, * which can be changed via `LogisticRegressionWithSGD.optimizer`. - * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} - * for k classes multi-label classification problem. + * * Using [[LogisticRegressionWithLBFGS]] is recommended over this. + * + * @note Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ @Since("0.8.0") class LogisticRegressionWithSGD private[mllib] ( @@ -239,7 +241,8 @@ class LogisticRegressionWithSGD private[mllib] ( /** * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("0.8.0") @deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0") @@ -252,7 +255,6 @@ object LogisticRegressionWithSGD { * number of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -260,6 +262,8 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -276,13 +280,13 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -298,13 +302,13 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using the specified step size. We use the entire data * set to update the gradient in each iteration. 
- * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -318,11 +322,12 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using a step size of 1.0. We use the entire data set * to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -335,8 +340,6 @@ object LogisticRegressionWithSGD { /** * Train a classification model for Multinomial/Binary Logistic Regression using * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. - * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} - * for k classes multi-label classification problem. * * Earlier implementations of LogisticRegressionWithLBFGS applies a regularization * penalty to all elements including the intercept. If this is called with one of @@ -344,6 +347,9 @@ object LogisticRegressionWithSGD { * into a call to ml.LogisticRegression, otherwise this will use the existing mllib * GeneralizedLinearAlgorithm trainer, resulting in a regularization penalty to the * intercept. + * + * @note Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ @Since("1.1.0") class LogisticRegressionWithLBFGS diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 7c3ccbb40b812..aec1526b55c49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -125,7 +125,8 @@ object SVMModel extends Loader[SVMModel] { /** * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. - * NOTE: Labels used in SVM should be {0, 1}. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") class SVMWithSGD private ( @@ -158,7 +159,9 @@ class SVMWithSGD private ( } /** - * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}. + * Top-level methods for calling SVM. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") object SVMWithSGD { @@ -169,8 +172,6 @@ object SVMWithSGD { * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. * - * NOTE: Labels used in SVM should be {0, 1}. - * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. 
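A minimal sketch of the {0, 1} label convention referenced in the surrounding notes, assuming a SparkContext named `sc` is already in scope (illustrative only):

    import org.apache.spark.mllib.classification.SVMWithSGD
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Binary labels must be encoded as 0.0 / 1.0 (not -1.0 / +1.0).
    val training = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(1.0, 0.5)),
      LabeledPoint(0.0, Vectors.dense(-1.0, -0.5))))
    val model = SVMWithSGD.train(training, 20)  // 20 iterations of SGD
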
@@ -178,6 +179,8 @@ object SVMWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") def train( @@ -195,7 +198,8 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in SVM should be {0, 1} + * + * @note Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -217,13 +221,14 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. We use the entire data set to * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * + * @note Labels used in SVM should be {0, 1} */ @Since("0.8.0") def train( @@ -238,11 +243,12 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using a step size of 1.0. We use the entire data set to * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * + * @note Labels used in SVM should be {0, 1} */ @Since("0.8.0") def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 43193adf3e184..56cdeea5f7a3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -41,14 +41,14 @@ import org.apache.spark.util.Utils * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. * - * Note: For high-dimensional data (with many features), this algorithm may perform poorly. - * This is due to high-dimensional data (a) making it difficult to cluster at all (based - * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. - * * @param k Number of independent Gaussians in the mixture model. * @param convergenceTol Maximum change in log-likelihood at which convergence * is considered to have occurred. * @param maxIterations Maximum number of iterations allowed. + * + * @note For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. 
*/ @Since("1.3.0") class GaussianMixture private ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index ed9c064879d01..fa72b72e2d921 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -56,14 +56,18 @@ class KMeans private ( def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) /** - * Number of clusters to create (k). Note that it is possible for fewer than k clusters to + * Number of clusters to create (k). + * + * @note It is possible for fewer than k clusters to * be returned, for example, if there are fewer than k distinct points to cluster. */ @Since("1.4.0") def getK: Int = k /** - * Set the number of clusters to create (k). Note that it is possible for fewer than k clusters to + * Set the number of clusters to create (k). + * + * @note It is possible for fewer than k clusters to * be returned, for example, if there are fewer than k distinct points to cluster. Default: 2. */ @Since("0.8.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d999b9be8e8ac..7c52abdeaac22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -175,7 +175,7 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ @Since("1.3.0") @@ -187,7 +187,7 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. * * If set to -1, then topicConcentration is set automatically. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 90d8a558f10d4..b5b0e64a2a6c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -66,7 +66,7 @@ abstract class LDAModel private[clustering] extends Saveable { * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. 
*/ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index ae324f86fe6d1..7365ea1f200da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -93,9 +93,11 @@ final class EMLDAOptimizer extends LDAOptimizer { /** * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up). * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with - * care. Note that checkpoints will be cleaned up via reference counting, regardless. + * care. * * Default: true + * + * @note Checkpoints will be cleaned up via reference counting, regardless. */ @Since("2.0.0") def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = { @@ -348,7 +350,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in * each iteration. * - * Note that this should be adjusted in synch with [[LDA.setMaxIterations()]] + * @note This should be adjusted in synch with [[LDA.setMaxIterations()]] * so the entire corpus is used. Specifically, set both so that * maxIterations * miniBatchFraction >= 1. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index f0779491e6374..003d1411a9cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -39,7 +39,7 @@ private[evaluation] object AreaUnderCurve { /** * Returns the area under the given curve. * - * @param curve a RDD of ordered 2D points stored in pairs representing a curve + * @param curve an RDD of ordered 2D points stored in pairs representing a curve */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fbd217af74ecb..c94d7890cf557 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * - * Note: Users should not implement this interface. + * @note Users should not implement this interface. */ @SQLUserDefinedType(udt = classOf[VectorUDT]) @Since("1.0.0") @@ -132,7 +132,9 @@ sealed trait Vector extends Serializable { /** * Number of active entries. An "active entry" is an element which is explicitly stored, - * regardless of its value. Note that inactive entries have value 0. + * regardless of its value. + * + * @note Inactive entries have value 0. 
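A small sketch of the active-entry semantics noted above; the indices and values are arbitrary:

{{{
import org.apache.spark.mllib.linalg.Vectors

// Three entries are explicitly stored, so numActives is 3 even though one value is 0.0;
// only two entries are nonzero.
val sv = Vectors.sparse(5, Seq((0, 1.0), (2, 0.0), (4, 3.0)))
println(sv.numActives)   // 3
println(sv.numNonzeros)  // 2
}}}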
*/ @Since("1.4.0") def numActives: Int diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 377be6bfb9886..03866753b50ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -451,7 +451,7 @@ class BlockMatrix @Since("1.3.0") ( * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause * some performance issues until support for multiplying two sparse matrices is added. * - * Note: The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when + * @note The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when * there were blocks with duplicate indices. Now, the blocks with duplicate indices will be added * with each other. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index b03b3ecde94f4..809906a158337 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -188,8 +188,9 @@ class IndexedRowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. Note that this cannot be - * computed on matrices with more than 65535 columns. + * Computes the Gramian matrix `A^T A`. + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index ec32e37afb792..4b120332ab8d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -106,8 +106,9 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with - * more than 65535 columns. + * Computes the Gramian matrix `A^T A`. + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -168,9 +169,6 @@ class RowMatrix @Since("1.0.0") ( * ARPACK is set to 300 or k * 3, whichever is larger. The numerical tolerance for ARPACK's * eigen-decomposition is set to 1e-10. * - * @note The conditions that decide which method to use internally and the default parameters are - * subject to change. - * * @param k number of leading singular values to keep (0 < k <= n). * It might return less than k if * there are numerically zero singular values or there are not enough Ritz values @@ -180,6 +178,9 @@ class RowMatrix @Since("1.0.0") ( * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0) * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. + * + * @note The conditions that decide which method to use internally and the default parameters are + * subject to change. 
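For illustration, a minimal `computeSVD` sketch on a small matrix, assuming an existing `SparkContext` named `sc`:

{{{
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(4.0, 5.0, 6.0),
  Vectors.dense(7.0, 8.0, 10.0)))

val mat = new RowMatrix(rows)
// Keep the top 2 singular values; U is only materialized when computeU = true.
val svd = mat.computeSVD(2, computeU = true)
println(svd.s)   // singular values
println(svd.V)   // right singular vectors as a local matrix
}}}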
*/ @Since("1.0.0") def computeSVD( @@ -319,9 +320,11 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. Note that this cannot - * be computed on matrices with more than 65535 columns. + * Computes the covariance matrix, treating each row as an observation. + * * @return a local dense matrix of size n x n + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeCovariance(): Matrix = { @@ -369,12 +372,12 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * - * Note that this cannot be computed on matrices with more than 65535 columns. - * * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components, and * a vector of values which indicate how much variance each principal component * explains + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.6.0") def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 81e64de4e5b5d..c49e72646bf13 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -305,7 +305,8 @@ class LeastSquaresGradient extends Gradient { * :: DeveloperApi :: * Compute gradient and loss for a Hinge loss function, as used in SVM binary classification. * See also the documentation for the precise formulation. - * NOTE: This assumes that the labels are {0,1} + * + * @note This assumes that the labels are {0,1} */ @DeveloperApi class HingeGradient extends Gradient { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 0f7857b8d8627..005119616f063 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** - * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Returns an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. It needs to diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index c642573ccba6d..24e4dcccc843f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -43,14 +43,14 @@ import org.apache.spark.storage.StorageLevel /** * Model representing the result of matrix factorization. 
* - * Note: If you create the model directly using constructor, please be aware that fast prediction - * requires cached user/product features and their associated partitioners. - * * @param rank Rank for the features in this model. * @param userFeatures RDD of tuples where each tuple represents the userId and * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. + * + * @note If you create the model directly using constructor, please be aware that fast prediction + * requires cached user/product features and their associated partitioners. */ @Since("0.8.0") class MatrixFactorizationModel @Since("0.8.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f3159f7e724cc..925fdf4d7e7bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -60,15 +60,15 @@ object Statistics { * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * - * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column - * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], - * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to - * avoid recomputing the common lineage. - * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. + * + * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to + * avoid recomputing the common lineage. */ @Since("1.1.0") def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) @@ -77,12 +77,12 @@ object Statistics { * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * - * Note: the two input RDDs need to have the same number of partitions and the same number of - * elements in each partition. - * * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s + * + * @note The two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. */ @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) @@ -98,15 +98,15 @@ object Statistics { * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * - * Note: the two input RDDs need to have the same number of partitions and the same number of - * elements in each partition. - * * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. 
* Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. + * + * @note The two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. */ @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) @@ -122,15 +122,15 @@ object Statistics { * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. * - * Note: the two input Vectors need to have the same size. - * `observed` cannot contain negative values. - * `expected` cannot contain nonpositive values. - * * @param observed Vector containing the observed categorical counts/relative frequencies. * @param expected Vector containing the expected categorical counts/relative frequencies. * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * + * @note The two input Vectors need to have the same size. + * `observed` cannot contain negative values. + * `expected` cannot contain nonpositive values. */ @Since("1.1.0") def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { @@ -141,11 +141,11 @@ object Statistics { * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform * distribution, with each category having an expected frequency of `1 / observed.size`. * - * Note: `observed` cannot contain negative values. - * * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * + * @note `observed` cannot contain negative values. */ @Since("1.1.0") def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 36feab7859b43..d846c43cf2913 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -75,10 +75,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -86,6 +82,10 @@ object DecisionTree extends Serializable with Logging { * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. 
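Along the lines of the recommendation above, a sketch that calls `trainClassifier` directly instead of the generic `train`; the input path is hypothetical and a `SparkContext` named `sc` is assumed:

{{{
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Labels in {0, 1} for binary classification.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_binary_classification.txt")

val model = DecisionTree.trainClassifier(
  data, 2 /* numClasses */, Map[Int, Int]() /* categoricalFeaturesInfo */,
  "gini", 5 /* maxDepth */, 32 /* maxBins */)
}}}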
*/ @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { @@ -96,10 +96,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -108,6 +104,10 @@ object DecisionTree extends Serializable with Logging { * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means * 1 internal node + 2 leaf nodes). * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.0.0") def train( @@ -123,10 +123,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -136,6 +132,10 @@ object DecisionTree extends Serializable with Logging { * 1 internal node + 2 leaf nodes). * @param numClasses Number of classes for classification. Default value of 2. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.2.0") def train( @@ -152,10 +152,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -170,6 +166,10 @@ object DecisionTree extends Serializable with Logging { * indicates that feature n is categorical with k categories * indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. 
*/ @Since("1.0.0") def train( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index de14ddf024d75..09274a2e1b2ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -42,11 +42,13 @@ trait Loss extends Serializable { /** * Method to calculate error of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. + * * @param model Model of the weak learner. * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return Measure of model error on data + * + * @note This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. */ @Since("1.2.0") def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { @@ -55,11 +57,13 @@ trait Loss extends Serializable { /** * Method to calculate loss when the predictions are already known. - * Note: This method is used in the method evaluateEachIteration to avoid recomputing the - * predicted values from previously fit trees. + * * @param prediction Predicted label. * @param label True label. * @return Measure of model error on datapoint. + * + * @note This method is used in the method evaluateEachIteration to avoid recomputing the + * predicted values from previously fit trees. */ private[spark] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 657ed0a8ecda8..299950785e420 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -187,7 +187,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param initTreeWeight: learning rate assigned to the first tree. * @param initTree: first DecisionTreeModel. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to every sample. */ @Since("1.4.0") @@ -213,7 +213,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param treeWeight: Learning rate. * @param tree: Tree using which the prediction and error should be updated. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to each sample. 
*/ @Since("1.4.0") diff --git a/pom.xml b/pom.xml index 650b4cd965b66..024b2850d0a3d 100644 --- a/pom.xml +++ b/pom.xml @@ -2476,6 +2476,13 @@ maven-javadoc-plugin -Xdoclint:all -Xdoclint:-missing + + + note + a + Note: + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2d3a95b163a76..92b45657210e1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -741,7 +741,8 @@ object Unidoc { javacOptions in (JavaUnidoc, unidoc) := Seq( "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc", "-public", - "-noqualifier", "java.lang" + "-noqualifier", "java.lang", + "-tag", """note:a:Note\:""" ), // Use GitHub repository for Scaladoc source links diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index 3b1c5519bd87e..7250eab6705a7 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -28,7 +28,7 @@ class KernelDensity(object): """ - Estimate probability density at required points given a RDD of samples + Estimate probability density at required points given an RDD of samples from the population. >>> kd = KernelDensity() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index ed6fd4bca4c54..97755807ef262 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -499,7 +499,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ - Generate a RDD of LabeledPoints. + Generate an RDD of LabeledPoints. """ return callMLlibFunc( "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures), diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a163ceafe9d3b..641787ee20e0c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1218,7 +1218,7 @@ def mergeMaps(m1, m2): def top(self, num, key=None): """ - Get the top N elements from a RDD. + Get the top N elements from an RDD. Note that this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory. @@ -1242,7 +1242,7 @@ def merge(a, b): def takeOrdered(self, num, key=None): """ - Get the N elements from a RDD ordered in ascending order or as + Get the N elements from an RDD ordered in ascending order or as specified by the optional key function. Note that this method should only be used if the resulting array is expected diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index bf27d8047a753..134424add3b62 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -144,7 +144,7 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, """ .. note:: Experimental - Create a RDD from Kafka using offset ranges for each topic and partition. + Create an RDD from Kafka using offset ranges for each topic and partition. :param sc: SparkContext object :param kafkaParams: Additional params for Kafka @@ -155,7 +155,7 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, :param valueDecoder: A function used to decode value (default is utf8_decoder) :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess meta using messageHandler (default is None). 
- :return: A RDD object + :return: An RDD object """ if leaders is None: leaders = dict() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index dc90659a676e0..0b95a8821b05a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -165,10 +165,10 @@ object Encoders { * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java * serialization. This encoder maps T into a single byte array (binary) field. * - * Note that this is extremely inefficient and should only be used as the last resort. - * * T must be publicly accessible. * + * @note This is extremely inefficient and should only be used as the last resort. + * * @since 1.6.0 */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -177,10 +177,10 @@ object Encoders { * Creates an encoder that serializes objects of type T using generic Java serialization. * This encoder maps T into a single byte array (binary) field. * - * Note that this is extremely inefficient and should only be used as the last resort. - * * T must be publicly accessible. * + * @note This is extremely inefficient and should only be used as the last resort. + * * @since 1.6.0 */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index e121044288e5a..21f3497ba06fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -23,10 +23,10 @@ import org.apache.spark.annotation.InterfaceStability * The data type representing calendar time intervals. The calendar time interval is stored * internally in two components: number of months the number of microseconds. * - * Note that calendar intervals are not comparable. - * * Please use the singleton [[DataTypes.CalendarIntervalType]]. * + * @note Calendar intervals are not comparable. + * * @since 1.5.0 */ @InterfaceStability.Stable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7a131b30eafd7..fa3b2b9de5d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -118,7 +118,7 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * Note that the internal Catalyst expression can be accessed via "expr", but this method is for + * @note The internal Catalyst expression can be accessed via "expr", but this method is for * debugging purposes only and can change in any future Spark releases. 
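A brief sketch of the distinction drawn above between a `Column` and its internal expression; the column names are made up:

{{{
import org.apache.spark.sql.functions.col

val cond = col("a") === col("b")
println(cond)        // user-facing column representation
println(cond.expr)   // underlying Catalyst expression: debugging only, may change
}}}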
* * @groupname java_expr_ops Java-specific expression operators diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b5bbcee37150f..6335fc4579a28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -51,7 +51,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient * Online Computation of Quantile Summaries]] by Greenwald and Khanna. * - * Note that NaN values will be removed from the numerical column before calculation * @param col the name of the numerical column * @param probabilities a list of quantile probabilities * Each number must belong to [0, 1]. @@ -61,6 +60,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities * + * @note NaN values will be removed from the numerical column before calculation + * * @since 2.0.0 */ def approxQuantile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e0c89811ddbfa..15281f24fa628 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -218,7 +218,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. * - * Note: Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based + * @note Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based * resolution. For example: * * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3761773698df3..3c75a6a45ec86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -377,7 +377,7 @@ class Dataset[T] private[sql]( /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. - * This can be quite convenient in conversion from a RDD of tuples into a [[DataFrame]] with + * This can be quite convenient in conversion from an RDD of tuples into a [[DataFrame]] with * meaningful names. For example: * {{{ * val rdd: RDD[(Int, String)] = ... @@ -703,13 +703,13 @@ class Dataset[T] private[sql]( * df1.join(df2, "user_id") * }}} * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumn Name of the column to join on. This column must exist on both sides. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. 
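One common way to keep a self-join unambiguous, as the note above suggests, is to alias both sides before joining; `people` and its columns are hypothetical:

{{{
import org.apache.spark.sql.functions.col

val left = people.as("emp")
val right = people.as("mgr")

// Join on an explicit condition; the aliases keep both sides addressable afterwards.
val joined = left.join(right, col("emp.manager_id") === col("mgr.id"))
joined.select(col("emp.id"), col("mgr.id"))
}}}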
+ * * @group untypedrel * @since 2.0.0 */ @@ -728,13 +728,13 @@ class Dataset[T] private[sql]( * df1.join(df2, Seq("user_id", "user_name")) * }}} * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * * @group untypedrel * @since 2.0.0 */ @@ -748,14 +748,14 @@ class Dataset[T] private[sql]( * Different from other join functions, the join columns will only appear once in the output, * i.e. similar to SQL's `JOIN USING` syntax. * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * * @group untypedrel * @since 2.0.0 */ @@ -856,10 +856,10 @@ class Dataset[T] private[sql]( /** * Explicit cartesian join with another [[DataFrame]]. * - * Note that cartesian joins are very expensive without an extra filter that can be pushed down. - * * @param right Right side of the join operation. * + * @note Cartesian joins are very expensive without an extra filter that can be pushed down. + * * @group untypedrel * @since 2.1.0 */ @@ -1044,7 +1044,8 @@ class Dataset[T] private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. - * Note that the column name can also reference to a nested column like `a.b`. + * + * @note The column name can also reference to a nested column like `a.b`. * * @group untypedrel * @since 2.0.0 @@ -1053,7 +1054,8 @@ class Dataset[T] private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. - * Note that the column name can also reference to a nested column like `a.b`. + * + * @note The column name can also reference to a nested column like `a.b`. * * @group untypedrel * @since 2.0.0 @@ -1621,7 +1623,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset containing rows only in both this Dataset and another Dataset. * This is equivalent to `INTERSECT` in SQL. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel @@ -1635,7 +1637,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset containing rows in this Dataset but not in another Dataset. 
* This is equivalent to `EXCEPT` in SQL. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel @@ -1648,13 +1650,13 @@ class Dataset[T] private[sql]( /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. * - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[Dataset]]. - * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. * @param seed Seed for sampling. * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[Dataset]]. + * * @group typedrel * @since 1.6.0 */ @@ -1670,12 +1672,12 @@ class Dataset[T] private[sql]( /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. * - * Note: this is NOT guaranteed to provide exactly the fraction of the total count - * of the given [[Dataset]]. - * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. * + * @note This is NOT guaranteed to provide exactly the fraction of the total count + * of the given [[Dataset]]. + * * @group typedrel * @since 1.6.0 */ @@ -2375,7 +2377,7 @@ class Dataset[T] private[sql]( * * The iterator will consume as much memory as the largest partition in this Dataset. * - * Note: this results in multiple Spark jobs, and if the input Dataset is the result + * @note this results in multiple Spark jobs, and if the input Dataset is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input Dataset should be cached first. * @@ -2453,7 +2455,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset that contains only the unique rows from this Dataset. * This is an alias for `dropDuplicates`. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 3c5cf037c578d..2fae93651b344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -181,9 +181,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) /** * A collection of methods for registering user-defined functions (UDF). - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. * * The following example registers a Scala closure as UDF: * {{{ @@ -208,6 +205,10 @@ class SQLContext private[sql](val sparkSession: SparkSession) * DataTypes.StringType); * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. 
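For illustration, a minimal registration sketch, assuming an existing `SparkSession` named `spark` (the same note applies to `SQLContext.udf`):

{{{
// The function must be deterministic: due to optimization, Spark may invoke it
// more or fewer times than it appears in the query.
spark.udf.register("strLen", (s: String) => s.length)
spark.sql("SELECT strLen('Spark')").show()   // single row containing 5
}}}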
+ * * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 58b2ab3957173..e09e3caa3c981 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -155,9 +155,6 @@ class SparkSession private( /** * A collection of methods for registering user-defined functions (UDF). - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. * * The following example registers a Scala closure as UDF: * {{{ @@ -182,6 +179,10 @@ class SparkSession private( * DataTypes.StringType); * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. + * * @since 2.0.0 */ def udf: UDFRegistration = sessionState.udf @@ -201,7 +202,7 @@ class SparkSession private( * Start a new session with isolated SQL configurations, temporary tables, registered * functions are isolated, but sharing the underlying [[SparkContext]] and cached data. * - * Note: Other than the [[SparkContext]], all shared state is initialized lazily. + * @note Other than the [[SparkContext]], all shared state is initialized lazily. * This method will force the initialization of the shared state to ensure that parent * and child sessions are set up with the same shared state. If the underlying catalog * implementation is Hive, this will initialize the metastore, which may take some time. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 0444ad10d34fb..6043c5ee14b54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -39,7 +39,8 @@ import org.apache.spark.util.Utils /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. - * Note that the user-defined functions must be deterministic. + * + * @note The user-defined functions must be deterministic. * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 4914a9d722a83..1b56c08f729c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -28,7 +28,7 @@ package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { - /** Map each partition of a RDD along with data in a [[StateStore]]. */ + /** Map each partition of an RDD along with data in a [[StateStore]]. */ def mapPartitionsWithStateStore[U: ClassTag]( sqlContext: SQLContext, checkpointLocation: String, @@ -49,7 +49,7 @@ package object state { storeUpdateFunction) } - /** Map each partition of a RDD along with data in a [[StateStore]]. */ + /** Map each partition of an RDD along with data in a [[StateStore]]. 
*/ private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 28598af781653..36dd5f78ac137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.types.DataType /** * A user-defined function. To create one, use the `udf` functions in [[functions]]. - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. + * * As an example: * {{{ * // Defined a UDF that returns true or false based on some numeric score. @@ -37,6 +35,10 @@ import org.apache.spark.sql.types.DataType * df.select( predict(df("score")) ) * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. + * * @since 1.3.0 */ @InterfaceStability.Stable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e221c032b82f6..d5940c638acdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -476,7 +476,7 @@ object functions { * * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) * - * Note: the list of columns should match with grouping columns exactly, or empty (means all the + * @note The list of columns should match with grouping columns exactly, or empty (means all the * grouping columns). * * @group agg_funcs @@ -489,7 +489,7 @@ object functions { * * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) * - * Note: the list of columns should match with grouping columns exactly. + * @note The list of columns should match with grouping columns exactly. * * @group agg_funcs * @since 2.0.0 @@ -1120,7 +1120,7 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * - * Note that this is indeterministic when data partitions are not fixed. + * @note This is indeterministic when data partitions are not fixed. * * @group normal_funcs * @since 1.4.0 @@ -1140,7 +1140,7 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * - * Note that this is indeterministic when data partitions are not fixed. + * @note This is indeterministic when data partitions are not fixed. * * @group normal_funcs * @since 1.4.0 @@ -1159,7 +1159,7 @@ object functions { /** * Partition ID. * - * Note that this is indeterministic because it depends on data partitioning and task scheduling. + * @note This is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs * @since 1.6.0 @@ -2207,7 +2207,7 @@ object functions { * Locate the position of the first occurrence of substr column in the given string. * Returns null if either of the arguments are null. * - * NOTE: The position is not zero based, but 1 based index. 
Returns 0 if substr + * @note The position is not zero based, but 1 based index. Returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2242,7 +2242,8 @@ object functions { /** * Locate the position of the first occurrence of substr. - * NOTE: The position is not zero based, but 1 based index. Returns 0 if substr + * + * @note The position is not zero based, but 1 based index. Returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2255,7 +2256,7 @@ object functions { /** * Locate the position of the first occurrence of substr in a string column, after position pos. * - * NOTE: The position is not zero based, but 1 based index. returns 0 if substr + * @note The position is not zero based, but 1 based index. returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2369,7 +2370,8 @@ object functions { /** * Splits str around pattern (pattern is a regular expression). - * NOTE: pattern is a string representation of the regular expression. + * + * @note Pattern is a string representation of the regular expression. * * @group string_funcs * @since 1.5.0 @@ -2468,7 +2470,7 @@ object functions { * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All * pattern letters of [[java.text.SimpleDateFormat]] can be used. * - * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * @note Use when ever possible specialized functions like [[year]]. These benefit from a * specialized implementation. * * @group datetime_funcs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index dec316be7aea1..7c64e28d24724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -140,7 +140,7 @@ abstract class JdbcDialect extends Serializable { * tried in reverse order. A user-added dialect will thus be applied first, * overwriting the defaults. * - * Note that all new dialects are applied to new jdbc DataFrames only. Make + * @note All new dialects are applied to new jdbc DataFrames only. Make * sure to register your dialects first. */ @DeveloperApi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 15a48072525b2..ff6dd8cb0cf92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -69,7 +69,8 @@ trait DataSourceRegister { trait RelationProvider { /** * Returns a new base relation with the given parameters. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced * by the Map that is passed to the function. */ def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation @@ -99,7 +100,8 @@ trait RelationProvider { trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced * by the Map that is passed to the function. 
*/ def createRelation( @@ -205,7 +207,7 @@ abstract class BaseRelation { * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. * - * Note that it is always better to overestimate size than underestimate, because underestimation + * @note It is always better to overestimate size than underestimate, because underestimation * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). * * @since 1.3.0 @@ -219,7 +221,7 @@ abstract class BaseRelation { * * If `needConversion` is `false`, buildScan() should return an [[RDD]] of [[InternalRow]] * - * Note: The internal representation is not stable across releases and thus data sources outside + * @note The internal representation is not stable across releases and thus data sources outside * of Spark SQL should leave this as true. * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 5e93fc469a41f..4504582187b97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.QueryExecution * :: Experimental :: * The interface of query execution listener that can be used to analyze execution metrics. * - * Note that implementations should guarantee thread-safety as they can be invoked by + * @note Implementations should guarantee thread-safety as they can be invoked by * multiple different threads. */ @Experimental @@ -39,24 +39,26 @@ trait QueryExecutionListener { /** * A callback function that will be called when a query executed successfully. - * Note that this can be invoked by multiple different threads. * * @param funcName name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. * @param durationNs the execution time for this query in nanoseconds. + * + * @note This can be invoked by multiple different threads. */ @DeveloperApi def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit /** * A callback function that will be called when a query execution failed. - * Note that this can be invoked by multiple different threads. * * @param funcName the name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. * @param exception the exception that failed this query. + * + * @note This can be invoked by multiple different threads. 
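A minimal listener sketch that keeps its shared state in atomics, since both callbacks may be invoked from multiple threads as noted above; registration is shown as a comment because the `SparkSession` named `spark` is assumed:

{{{
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

class MetricsListener extends QueryExecutionListener {
  val successes = new AtomicLong(0)
  val failures = new AtomicLong(0)

  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    successes.incrementAndGet()
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    failures.incrementAndGet()
  }
}

// spark.listenerManager.register(new MetricsListener())
}}}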
*/ @DeveloperApi def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 0daa29b666f62..b272c8e7d79c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val allColumns = fields.map(_.name).mkString(",") val schema = StructType(fields) - // Create a RDD for the schema + // Create an RDD for the schema val rdd = sparkContext.parallelize((1 to 10000), 10).map { i => Row( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 4808d0fcbc6cc..444261da8de6a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -421,11 +421,11 @@ class StreamingContext private[streaming] ( * by "moving" them from another location within the same file system. File names * starting with . are ignored. * - * '''Note:''' We ensure that the byte array for each record in the - * resulting RDDs of the DStream has the provided record length. - * * @param directory HDFS directory to monitor for new file * @param recordLength length of each record in bytes + * + * @note We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. */ def binaryRecordsStream( directory: String, @@ -447,12 +447,12 @@ class StreamingContext private[streaming] ( * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD + * + * @note Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. */ def queueStream[T: ClassTag]( queue: Queue[RDD[T]], @@ -465,14 +465,14 @@ class StreamingContext private[streaming] ( * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. * Set as null if no RDD should be returned when empty * @tparam T Type of objects in the RDD + * + * @note Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. 
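For illustration, a sketch of a queue-based stream, assuming an existing `StreamingContext` named `ssc`; because queued RDDs cannot be recovered, this pattern is typically used for tests rather than checkpointed production jobs:

{{{
import scala.collection.mutable.Queue

import org.apache.spark.rdd.RDD

val rddQueue = Queue[RDD[Int]]()
val stream = ssc.queueStream(rddQueue)   // oneAtATime defaults to true
stream.map(_ * 2).print()

rddQueue += ssc.sparkContext.makeRDD(1 to 10, 2)
// ssc.start(); ssc.awaitTermination()
}}}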
*/ def queueStream[T: ClassTag]( queue: Queue[RDD[T]], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index da9ff858853cf..aa4003c62e1e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -74,7 +74,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def repartition(numPartitions: Int): JavaPairDStream[K, V] = dstream.repartition(numPartitions) - /** Method that generates a RDD for the given Duration */ + /** Method that generates an RDD for the given Duration */ def compute(validTime: Time): JavaPairRDD[K, V] = { dstream.compute(validTime) match { case Some(rdd) => new JavaPairRDD(rdd) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 4c4376a089f59..b43b9405def97 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -218,11 +218,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * for new files and reads them as flat binary files with fixed record lengths, * yielding byte arrays * - * '''Note:''' We ensure that the byte array for each record in the - * resulting RDDs of the DStream has the provided record length. - * * @param directory HDFS directory to monitor for new files * @param recordLength The length at which to split the records + * + * @note We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. */ def binaryRecordsStream(directory: String, recordLength: Int): JavaDStream[Array[Byte]] = { ssc.binaryRecordsStream(directory, recordLength) @@ -352,13 +352,13 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: + * @param queue Queue of RDDs + * @tparam T Type of objects in the RDD + * + * @note * 1. Changes to the queue after the stream is created will not be recognized. * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. - * - * @param queue Queue of RDDs - * @tparam T Type of objects in the RDD */ def queueStream[T](queue: java.util.Queue[JavaRDD[T]]): JavaDStream[T] = { implicit val cm: ClassTag[T] = @@ -372,14 +372,14 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: - * 1. Changes to the queue after the stream is created will not be recognized. - * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD + * + * @note + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. 
Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. */ def queueStream[T]( queue: java.util.Queue[JavaRDD[T]], @@ -396,7 +396,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: + * @note * 1. Changes to the queue after the stream is created will not be recognized. * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. @@ -454,9 +454,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a new DStream in which each RDD is generated by applying a function on RDDs of * the DStreams. The order of the JavaRDDs in the transform function parameter will be the - * same as the order of corresponding DStreams in the list. Note that for adding a - * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using - * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). + * same as the order of corresponding DStreams in the list. + * + * @note For adding a JavaPairDStream in the list of JavaDStreams, convert it to a + * JavaDStream using [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). * In the transform function, convert the JavaRDD corresponding to that JavaDStream to * a JavaPairRDD using org.apache.spark.api.java.JavaPairRDD.fromJavaRDD(). */ @@ -476,9 +477,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a new DStream in which each RDD is generated by applying a function on RDDs of * the DStreams. The order of the JavaRDDs in the transform function parameter will be the - * same as the order of corresponding DStreams in the list. Note that for adding a - * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using - * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). + * same as the order of corresponding DStreams in the list. + * + * @note For adding a JavaPairDStream in the list of JavaDStreams, convert it to + * a JavaDStream using [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). * In the transform function, convert the JavaRDD corresponding to that JavaDStream to * a JavaPairRDD using org.apache.spark.api.java.JavaPairRDD.fromJavaRDD(). 
*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index fa15a0bf65ab9..938a7fac1af41 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -68,13 +68,13 @@ abstract class DStream[T: ClassTag] ( // Methods that should be implemented by subclasses of DStream // ======================================================================= - /** Time interval after which the DStream generates a RDD */ + /** Time interval after which the DStream generates an RDD */ def slideDuration: Duration /** List of parent DStreams on which this DStream depends on */ def dependencies: List[DStream[_]] - /** Method that generates a RDD for the given time */ + /** Method that generates an RDD for the given time */ def compute(validTime: Time): Option[RDD[T]] // ======================================================================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala index ed08191f41cc8..9512db7d7d757 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -128,7 +128,7 @@ class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: Clas super.initialize(time) } - /** Method that generates a RDD for the given time */ + /** Method that generates an RDD for the given time */ override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD val prevStateRDD = getOrCompute(validTime - slideDuration) match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index ce5a6e00fb2fe..a37fac87300b7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -186,7 +186,7 @@ class WriteAheadLogBackedBlockRDDSuite assert(rdd.collect() === data.flatten) // Verify that the block fetching is skipped when isBlockValid is set to false. - // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // This is done by using an RDD whose data is only in memory but is set to skip block fetching // Using that RDD will throw exception, as it skips block fetching even if the blocks are in // in BlockManager. if (testIsBlockValid) { From 30a6fbbb0fb47f5b74ceba3384f28a61bf4e4740 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 19 Nov 2016 11:28:25 +0000 Subject: [PATCH 149/534] [SPARK-18353][CORE] spark.rpc.askTimeout default value is not 120s ## What changes were proposed in this pull request? Avoid hard-coding spark.rpc.askTimeout to non-default in Client; fix doc about spark.rpc.askTimeout default ## How was this patch tested? Existing tests Author: Sean Owen Closes #15833 from srowen/SPARK-18353.
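For readers skimming the patch, a minimal sketch of the intended behavior (an illustration only, using a plain `SparkConf`; the real change lives in `org.apache.spark.deploy.Client`, shown in the diff below): the short 10s ask timeout is applied only when the user has not set `spark.rpc.askTimeout` explicitly, and the documented default becomes the value of `spark.network.timeout` rather than a hard-coded 120s.

```scala
import org.apache.spark.SparkConf

// Sketch of the guarded default: respect a user-supplied spark.rpc.askTimeout,
// and only fall back to the short client-side value when none was configured.
val conf = new SparkConf()
if (!conf.contains("spark.rpc.askTimeout")) {
  conf.set("spark.rpc.askTimeout", "10s")
}
```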
(cherry picked from commit 8b1e1088eb274fb15260cd5d6d9508d42837a4d6) Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/deploy/Client.scala | 4 +++- docs/configuration.md | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index ee276e1b71138..a4de3d7eaf458 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -221,7 +221,9 @@ object Client { val conf = new SparkConf() val driverArgs = new ClientArguments(args) - conf.set("spark.rpc.askTimeout", "10") + if (!conf.contains("spark.rpc.askTimeout")) { + conf.set("spark.rpc.askTimeout", "10s") + } Logger.getRootLogger.setLevel(driverArgs.logLevel) val rpcEnv = diff --git a/docs/configuration.md b/docs/configuration.md index e0c661349caab..c2329b411fc69 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1175,7 +1175,7 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.askTimeout - 120s + spark.network.timeout Duration for an RPC ask operation to wait before timing out. @@ -1531,7 +1531,7 @@ Apart from these, the following properties are also available, and may be useful spark.core.connection.ack.wait.timeout - 60s + spark.network.timeout How long for the connection to wait for ack to occur before timing out and giving up. To avoid unwilling timeout caused by long pause like GC, From 15ad3a319b91a8b495da9a0e6f5386417991d30d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 19 Nov 2016 13:48:56 +0000 Subject: [PATCH 150/534] [SPARK-18448][CORE] Fix @since 2.1.0 on new SparkSession.close() method ## What changes were proposed in this pull request? Fix since 2.1.0 on new SparkSession.close() method. I goofed in https://github.com/apache/spark/pull/15932 because it was back-ported to 2.1 instead of just master as originally planned. Author: Sean Owen Closes #15938 from srowen/SPARK-18448.2. (cherry picked from commit ded5fefb6f5c0a97bf3d7fa1c0494dc434b6ee40) Signed-off-by: Sean Owen --- sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e09e3caa3c981..71b1880dc0715 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -652,7 +652,7 @@ class SparkSession private( /** * Synonym for `stop()`. * - * @since 2.2.0 + * @since 2.1.0 */ override def close(): Unit = stop() From 15eb86c29c02178f4413df63c39b8df3cda30ca8 Mon Sep 17 00:00:00 2001 From: sethah Date: Sun, 20 Nov 2016 01:42:37 +0000 Subject: [PATCH 151/534] [SPARK-18456][ML][FOLLOWUP] Use matrix abstraction for coefficients in LogisticRegression training ## What changes were proposed in this pull request? This is a follow up to some of the discussion [here](https://github.com/apache/spark/pull/15593). During LogisticRegression training, we store the coefficients combined with intercepts as a flat vector, but a more natural abstraction is a matrix. Here, we refactor the code to use matrix where possible, which makes the code more readable and greatly simplifies the indexing. Note: We do not use a Breeze matrix for the cost function as was mentioned in the linked PR. 
This is because LBFGS/OWLQN require an implicit `MutableInnerProductModule[DenseMatrix[Double], Double]` which is not natively defined in Breeze. We would need to extend Breeze in Spark to define it ourselves. Also, we do not modify the `regParamL1Fun` because OWLQN in Breeze requires a `MutableEnumeratedCoordinateField[(Int, Int), DenseVector[Double]]` (since we still use a dense vector for coefficients). Here again we would have to extend Breeze inside Spark. ## How was this patch tested? This is internal code refactoring - the current unit tests passing show us that the change did not break anything. No added functionality in this patch. Author: sethah Closes #15893 from sethah/logreg_refactor. (cherry picked from commit 856e0042007c789dda4539fb19a5d4580999fbf4) Signed-off-by: DB Tsai --- .../classification/LogisticRegression.scala | 115 ++++++++---------- 1 file changed, 53 insertions(+), 62 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71a7fe53c15f8..f58efd36a1c66 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -463,16 +463,11 @@ class LogisticRegression @Since("1.2.0") ( } /* - The coefficients are laid out in column major order during training. e.g. for - `numClasses = 3` and `numFeatures = 2` and `fitIntercept = true` the layout is: - - Array(beta_11, beta_21, beta_31, beta_12, beta_22, beta_32, intercept_1, intercept_2, - intercept_3) - - where beta_jk corresponds to the coefficient for class `j` and feature `k`. + The coefficients are laid out in column major order during training. Here we initialize + a column major matrix of initial coefficients. 
*/ - val initialCoefficientsWithIntercept = - Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept) + val initialCoefWithInterceptMatrix = + Matrices.zeros(numCoefficientSets, numFeaturesPlusIntercept) val initialModelIsValid = optInitialModel match { case Some(_initialModel) => @@ -491,18 +486,15 @@ class LogisticRegression @Since("1.2.0") ( } if (initialModelIsValid) { - val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray val providedCoef = optInitialModel.get.coefficientMatrix - providedCoef.foreachActive { (row, col, value) => - // convert matrix to column major for training - val flatIndex = col * numCoefficientSets + row + providedCoef.foreachActive { (classIndex, featureIndex, value) => // We need to scale the coefficients since they will be trained in the scaled space - initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col) + initialCoefWithInterceptMatrix.update(classIndex, featureIndex, + value * featuresStd(featureIndex)) } if ($(fitIntercept)) { - optInitialModel.get.interceptVector.foreachActive { (index, value) => - val coefIndex = numCoefficientSets * numFeatures + index - initialCoefWithInterceptArray(coefIndex) = value + optInitialModel.get.interceptVector.foreachActive { (classIndex, value) => + initialCoefWithInterceptMatrix.update(classIndex, numFeatures, value) } } } else if ($(fitIntercept) && isMultinomial) { @@ -532,8 +524,7 @@ class LogisticRegression @Since("1.2.0") ( val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing val rawMean = rawIntercepts.sum / rawIntercepts.length rawIntercepts.indices.foreach { i => - initialCoefficientsWithIntercept.toArray(numClasses * numFeatures + i) = - rawIntercepts(i) - rawMean + initialCoefWithInterceptMatrix.update(i, numFeatures, rawIntercepts(i) - rawMean) } } else if ($(fitIntercept)) { /* @@ -549,12 +540,12 @@ class LogisticRegression @Since("1.2.0") ( b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialCoefficientsWithIntercept.toArray(numFeatures) = math.log( - histogram(1) / histogram(0)) + initialCoefWithInterceptMatrix.update(0, numFeatures, + math.log(histogram(1) / histogram(0))) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficientsWithIntercept.asBreeze.toDenseVector) + new BDV[Double](initialCoefWithInterceptMatrix.toArray)) /* Note that in Logistic Regression, the objective history (loss + regularization) @@ -586,15 +577,24 @@ class LogisticRegression @Since("1.2.0") ( Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. 
*/ - val rawCoefficients = state.x.toArray.clone() - val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i => - val colMajorIndex = (i % numFeatures) * numCoefficientSets + i / numFeatures - val featureIndex = i % numFeatures - if (featuresStd(featureIndex) != 0.0) { - rawCoefficients(colMajorIndex) / featuresStd(featureIndex) - } else { - 0.0 + val allCoefficients = state.x.toArray.clone() + val allCoefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, + allCoefficients) + val denseCoefficientMatrix = new DenseMatrix(numCoefficientSets, numFeatures, + new Array[Double](numCoefficientSets * numFeatures), isTransposed = true) + val interceptVec = if ($(fitIntercept) || !isMultinomial) { + Vectors.zeros(numCoefficientSets) + } else { + Vectors.sparse(numCoefficientSets, Seq()) + } + // separate intercepts and coefficients from the combined matrix + allCoefMatrix.foreachActive { (classIndex, featureIndex, value) => + val isIntercept = $(fitIntercept) && (featureIndex == numFeatures) + if (!isIntercept && featuresStd(featureIndex) != 0.0) { + denseCoefficientMatrix.update(classIndex, featureIndex, + value / featuresStd(featureIndex)) } + if (isIntercept) interceptVec.toArray(classIndex) = value } if ($(regParam) == 0.0 && isMultinomial) { @@ -607,17 +607,16 @@ class LogisticRegression @Since("1.2.0") ( Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf */ - val coefficientMean = coefficientArray.sum / coefficientArray.length - coefficientArray.indices.foreach { i => coefficientArray(i) -= coefficientMean} + val denseValues = denseCoefficientMatrix.values + val coefficientMean = denseValues.sum / denseValues.length + denseCoefficientMatrix.update(_ - coefficientMean) } - val denseCoefficientMatrix = - new DenseMatrix(numCoefficientSets, numFeatures, coefficientArray, isTransposed = true) // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471 val compressedCoefficientMatrix = if (isMultinomial) { denseCoefficientMatrix } else { - val compressedVector = Vectors.dense(coefficientArray).compressed + val compressedVector = Vectors.dense(denseCoefficientMatrix.values).compressed compressedVector match { case dv: DenseVector => denseCoefficientMatrix case sv: SparseVector => @@ -626,25 +625,13 @@ class LogisticRegression @Since("1.2.0") ( } } - val interceptsArray: Array[Double] = if ($(fitIntercept)) { - Array.tabulate(numCoefficientSets) { i => - val coefIndex = numFeatures * numCoefficientSets + i - rawCoefficients(coefIndex) - } - } else { - Array.empty[Double] - } - val interceptVector = if (interceptsArray.nonEmpty && isMultinomial) { - // The intercepts are never regularized, so we always center the mean. 
- val interceptMean = interceptsArray.sum / numClasses - interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean } - Vectors.dense(interceptsArray) - } else if (interceptsArray.length == 1) { - Vectors.dense(interceptsArray) - } else { - Vectors.sparse(numCoefficientSets, Seq()) + // center the intercepts when using multinomial algorithm + if ($(fitIntercept) && isMultinomial) { + val interceptArray = interceptVec.toArray + val interceptMean = interceptArray.sum / interceptArray.length + (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } } - (compressedCoefficientMatrix, interceptVector.compressed, arrayBuilder.result()) + (compressedCoefficientMatrix, interceptVec.compressed, arrayBuilder.result()) } } @@ -1424,6 +1411,7 @@ private class LogisticAggregator( private val numFeatures = bcFeaturesStd.value.length private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures private val coefficientSize = bcCoefficients.value.size + private val numCoefficientSets = if (multinomial) numClasses else 1 if (multinomial) { require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " + s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize") @@ -1633,12 +1621,12 @@ private class LogisticAggregator( lossSum / weightSum } - def gradient: Vector = { + def gradient: Matrix = { require(weightSum > 0.0, s"The effective number of instances should be " + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) scal(1.0 / weightSum, result) - result + new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, result.toArray) } } @@ -1664,6 +1652,7 @@ private class LogisticCostFun( val featuresStd = bcFeaturesStd.value val numFeatures = featuresStd.length val numCoefficientSets = if (multinomial) numClasses else 1 + val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures val logisticAggregator = { val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) @@ -1675,24 +1664,25 @@ private class LogisticCostFun( )(seqOp, combOp, aggregationDepth) } - val totalGradientArray = logisticAggregator.gradient.toArray + val totalGradientMatrix = logisticAggregator.gradient + val coefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, coeffs.toArray) // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { var sum = 0.0 - coeffs.foreachActive { case (index, value) => + coefMatrix.foreachActive { case (classIndex, featureIndex, value) => // We do not apply regularization to the intercepts - val isIntercept = fitIntercept && index >= numCoefficientSets * numFeatures + val isIntercept = fitIntercept && (featureIndex == numFeatures) if (!isIntercept) { // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. 
sum += { if (standardization) { - totalGradientArray(index) += regParamL2 * value + val gradValue = totalGradientMatrix(classIndex, featureIndex) + totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * value) value * value } else { - val featureIndex = index / numCoefficientSets if (featuresStd(featureIndex) != 0.0) { // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to @@ -1700,7 +1690,8 @@ private class LogisticCostFun( // differently to get effectively the same objective function when // the training dataset is not standardized. val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex)) - totalGradientArray(index) += regParamL2 * temp + val gradValue = totalGradientMatrix(classIndex, featureIndex) + totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * temp) value * temp } else { 0.0 @@ -1713,6 +1704,6 @@ private class LogisticCostFun( } bcCoeffs.destroy(blocking = false) - (logisticAggregator.loss + regVal, new BDV(totalGradientArray)) + (logisticAggregator.loss + regVal, new BDV(totalGradientMatrix.toArray)) } } From b0b2f10817f38d9cebd2e436a07d4dd3e41e9328 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 19 Nov 2016 21:50:20 -0800 Subject: [PATCH 152/534] [SPARK-18458][CORE] Fix signed integer overflow problem at an expression in RadixSort.java ## What changes were proposed in this pull request? This PR avoids that a result of an expression is negative due to signed integer overflow (e.g. 0x10?????? * 8 < 0). This PR casts each operand to `long` before executing a calculation. Since the result is interpreted as long, the result of the expression is positive. ## How was this patch tested? Manually executed query82 of TPC-DS with 100TB Author: Kazuaki Ishizaki Closes #15907 from kiszk/SPARK-18458. (cherry picked from commit d93b6552473468df297a08c0bef9ea0bf0f5c13a) Signed-off-by: Reynold Xin --- .../collection/unsafe/sort/RadixSort.java | 48 ++++++++++--------- .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../unsafe/sort/RadixSortSuite.scala | 28 +++++------ 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 404361734a55b..3dd318471008b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -40,14 +42,14 @@ public class RadixSort { * of always copying the data back to position zero for efficiency. 
*/ public static int sort( - LongArray array, int numRecords, int startByteIndex, int endByteIndex, + LongArray array, long numRecords, int startByteIndex, int endByteIndex, boolean desc, boolean signed) { assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 2 <= array.size(); - int inIndex = 0; - int outIndex = numRecords; + long inIndex = 0; + long outIndex = numRecords; if (numRecords > 0) { long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { @@ -55,13 +57,13 @@ public static int sort( sortAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -78,14 +80,14 @@ public static int sort( * @param signed whether this is a signed (two's complement) sort (only applies to last byte). */ private static void sortAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8; - long maxOffset = baseOffset + numRecords * 8; + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; for (long offset = baseOffset; offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); @@ -106,13 +108,13 @@ private static void sortAtByte( * significant byte. If the byte does not need sorting the array will be null. */ private static long[][] getCounts( - LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. // If all the byte values at a particular index are the same we don't need to count it. long bitwiseMax = 0; long bitwiseMin = -1L; - long maxOffset = array.getBaseOffset() + numRecords * 8; + long maxOffset = array.getBaseOffset() + numRecords * 8L; Object baseObject = array.getBaseObject(); for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); @@ -146,18 +148,18 @@ private static long[][] getCounts( * @return the input counts array. */ private static long[] transformCountsToOffsets( - long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, boolean desc, boolean signed) { assert counts.length == 256; int start = signed ? 128 : 0; // output the negative records first (values 129-255). 
if (desc) { - int pos = numRecords; + long pos = numRecords; for (int i = start; i < start + 256; i++) { pos -= counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; } } else { - int pos = 0; + long pos = 0; for (int i = start; i < start + 256; i++) { long tmp = counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; @@ -176,8 +178,8 @@ private static long[] transformCountsToOffsets( */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, - int numRecords, + long startIndex, + long numRecords, int startByteIndex, int endByteIndex, boolean desc, @@ -186,8 +188,8 @@ public static int sortKeyPrefixArray( assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + long inIndex = startIndex; + long outIndex = startIndex + numRecords * 2L; if (numRecords > 0) { long[][] counts = getKeyPrefixArrayCounts( array, startIndex, numRecords, startByteIndex, endByteIndex); @@ -196,13 +198,13 @@ public static int sortKeyPrefixArray( sortKeyPrefixArrayAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -210,7 +212,7 @@ public static int sortKeyPrefixArray( * getCounts with some added parameters but that seems to hurt in benchmarks. */ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; @@ -238,11 +240,11 @@ private static long[][] getKeyPrefixArrayCounts( * Specialization of sortAtByte() for key-prefix arrays. 
*/ private static void sortKeyPrefixArrayAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed); Object baseObject = array.getBaseObject(); long baseOffset = array.getBaseOffset() + inIndex * 8L; long maxOffset = baseOffset + numRecords * 16L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 2a71e68adafad..252a35ec6bdf5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -322,7 +322,7 @@ public UnsafeSorterIterator getSortedIterator() { if (sortComparator != null) { if (this.radixSortSupport != null) { offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index 366ffda7788d3..d5956ea32096a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator} import scala.util.Random +import com.google.common.primitives.Ints + import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray @@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom class RadixSortSuite extends SparkFunSuite with Logging { - private val N = 10000 // scale this down for more readable results + private val N = 10000L // scale this down for more readable results /** * Describes a type of sort to test, e.g. two's complement descending. 
Each sort type has @@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging { }, 2, 4, false, false, true)) - private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = { - val ref = Array.tabulate[Long](size) { i => rand } - val extended = ref ++ Array.fill[Long](size)(0) + private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { - val ref = Array.tabulate[Long](size * 2) { i => rand } - val extended = ref ++ Array.fill[Long](size * 2)(0) + private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) (new LongArray(MemoryBlock.fromLongArray(ref)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = { + private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { var i = 0 - val out = new Array[Long](length) + val out = new Array[Long](Ints.checkedCast(length)) while (i < length) { out(i) = array.get(offset + i) i += 1 @@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging { } } - private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { + private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( - buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { + buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( r1: RecordPointerAndKeyPrefix, - r2: RecordPointerAndKeyPrefix): Int = { - refCmp.compare(r1.keyPrefix, r2.keyPrefix) - } + r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix) }) } From 94a9eed11a11510a91dc4c8adb793dc3cbdef8f5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Nov 2016 21:57:09 -0800 Subject: [PATCH 153/534] [SPARK-18508][SQL] Fix documentation error for DateDiff ## What changes were proposed in this pull request? The previous documentation and example for DateDiff was wrong. ## How was this patch tested? Doc only change. Author: Reynold Xin Closes #15937 from rxin/datediff-doc. 
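As a quick illustration of the corrected semantics (a sketch only, assuming a running `SparkSession` named `spark`, e.g. in spark-shell): `datediff(endDate, startDate)` counts days from `startDate` to `endDate`, so swapping the arguments flips the sign.

```scala
// datediff takes (endDate, startDate), matching the corrected docs below.
spark.sql("SELECT datediff('2009-07-31', '2009-07-30')").show()  // 1
spark.sql("SELECT datediff('2009-07-30', '2009-07-31')").show()  // -1
```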
(cherry picked from commit bce9a03677f931d52491e7768aba9e4a19a7e696) Signed-off-by: Reynold Xin --- .../sql/catalyst/expressions/datetimeExpressions.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 9cec6be841de0..1db1d1995d942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1101,11 +1101,14 @@ case class TruncDate(date: Expression, format: Expression) * Returns the number of days from startDate to endDate. */ @ExpressionDescription( - usage = "_FUNC_(date1, date2) - Returns the number of days between `date1` and `date2`.", + usage = "_FUNC_(endDate, startDate) - Returns the number of days from `startDate` to `endDate`.", extended = """ Examples: - > SELECT _FUNC_('2009-07-30', '2009-07-31'); + > SELECT _FUNC_('2009-07-31', '2009-07-30'); 1 + + > SELECT _FUNC_('2009-07-30', '2009-07-31'); + -1 """) case class DateDiff(endDate: Expression, startDate: Expression) extends BinaryExpression with ImplicitCastInputTypes { From 063da0c8d4e82a47cf7841578dcf968080c3d89d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Nov 2016 21:57:49 -0800 Subject: [PATCH 154/534] [SQL] Fix documentation for Concat and ConcatWs (cherry picked from commit a64f25d8b403b17ff68c9575f6f35b22e5b62427) Signed-off-by: Reynold Xin --- .../sql/catalyst/expressions/stringExpressions.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e74ef9a08750e..908aa44f81c97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -40,15 +40,13 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} * An expression that concatenates multiple input strings into a single string. * If any input is null, concat returns null. 
*/ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of `str1`, `str2`, ..., `strN`.", + usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", extended = """ Examples: - > SELECT _FUNC_('Spark','SQL'); + > SELECT _FUNC_('Spark', 'SQL'); SparkSQL """) -// scalastyle:on line.size.limit case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) @@ -89,8 +87,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas usage = "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by `sep`.", extended = """ Examples: - > SELECT _FUNC_(' ', Spark', 'SQL'); - Spark SQL + > SELECT _FUNC_(' ', 'Spark', 'SQL'); + Spark SQL """) // scalastyle:on line.size.limit case class ConcatWs(children: Seq[Expression]) From bc3e7b3b8a0dfc00d22bf5ee168f308a6ef5d78b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 20 Nov 2016 09:52:03 +0000 Subject: [PATCH 155/534] [SPARK-3359][BUILD][DOCS] Print examples and disable group and tparam tags in javadoc ## What changes were proposed in this pull request? This PR proposes/fixes two things. - Remove many errors to generate javadoc with Java8 from unrecognisable tags, `tparam` and `group`. ``` [error] .../spark/mllib/target/java/org/apache/spark/ml/classification/Classifier.java:18: error: unknown tag: group [error] /** group setParam */ [error] ^ [error] .../spark/mllib/target/java/org/apache/spark/ml/classification/Classifier.java:8: error: unknown tag: tparam [error] * tparam FeaturesType Type of input features. E.g., Vector [error] ^ ... ``` It does not fully resolve the problem but remove many errors. It seems both `group` and `tparam` are unrecognisable in javadoc. It seems we can't print them pretty in javadoc in a way of `example` here because they appear differently (both examples can be found in http://spark.apache.org/docs/2.0.2/api/scala/index.html#org.apache.spark.ml.classification.Classifier). - Print `example` in javadoc. Currently, there are few `example` tag in several places. ``` ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This operation might be used to evaluate a graph ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example We might use this operation to change the vertex values ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example We can use this function to compute the in-degree of each ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function is used to update the vertices with new values based on external data. 
./graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala: * example Loads a file in the following format: ./graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala: * example This function is used to update the vertices with new ./graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala: * example This function can be used to filter the graph based on some property, without ./graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala: * example We can use the Pregel abstraction to implement PageRank: ./graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala: * example Construct a `VertexRDD` from a plain RDD: ./repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala: * example new SparkCommandLine(Nil).settings ./repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala: * example addImports("org.apache.spark.SparkContext") ./sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala: * example {{{ ``` **Before** (screenshot of the generated javadoc) **After** (screenshot of the generated javadoc) ## How was this patch tested? Manually tested by `jekyll build` with Java 7 and 8 ``` java version "1.7.0_80" Java(TM) SE Runtime Environment (build 1.7.0_80-b15) Java HotSpot(TM) 64-Bit Server VM (build 24.80-b11, mixed mode) ``` ``` java version "1.8.0_45" Java(TM) SE Runtime Environment (build 1.8.0_45-b14) Java HotSpot(TM) 64-Bit Server VM (build 25.45-b02, mixed mode) ``` Note: this does not make sbt unidoc succeed with Java 8 yet, but it reduces the number of errors with Java 8. Author: hyukjinkwon Closes #15939 from HyukjinKwon/SPARK-3359-javadoc. (cherry picked from commit c528812ce770fd8a6626e7f9d2f8ca9d1e84642b) Signed-off-by: Sean Owen --- pom.xml | 13 +++++++++++++ project/SparkBuild.scala | 5 ++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 024b2850d0a3d..7c0b0b59dc62b 100644 --- a/pom.xml +++ b/pom.xml @@ -2477,11 +2477,24 @@ -Xdoclint:all -Xdoclint:-missing + + example + a + Example: + note a Note: + + group + X + + + tparam + X + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 92b45657210e1..429a163d22a6d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -742,7 +742,10 @@ object Unidoc { "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc", "-public", "-noqualifier", "java.lang", - "-tag", """note:a:Note\:""" + "-tag", """example:a:Example\:""", + "-tag", """note:a:Note\:""", + "-tag", "group:X", + "-tag", "tparam:X" ), // Use GitHub repository for Scaladoc source links From cffaf5035816fa6ffc4dadd47bede1eff6371fee Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 20 Nov 2016 12:46:29 -0800 Subject: [PATCH 156/534] [SPARK-17732][SQL] Revert ALTER TABLE DROP PARTITION should support comparators This reverts commit 1126c3194ee1c79015cf1d3808bc963aa93dcadf. Author: Herman van Hovell Closes #15948 from hvanhovell/SPARK-17732.
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 6 +- .../sql/catalyst/parser/AstBuilder.scala | 30 +---- .../spark/sql/execution/SparkSqlParser.scala | 2 +- .../spark/sql/execution/command/ddl.scala | 51 ++------- .../datasources/DataSourceStrategy.scala | 8 +- .../execution/command/DDLCommandSuite.scala | 9 +- .../sql/hive/execution/HiveDDLSuite.scala | 103 ------------------ 7 files changed, 24 insertions(+), 185 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index fcca11c69f0a3..b599a884957a8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -239,7 +239,11 @@ partitionSpecLocation ; partitionSpec - : PARTITION '(' expression (',' expression)* ')' + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? ; describeFuncName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 97056bba9d763..2006844923cf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -194,15 +194,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitPartitionSpec( ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { - val parts = ctx.expression.asScala.map { pVal => - expression(pVal) match { - case UnresolvedAttribute(name :: Nil) => - name -> None - case cmp @ EqualTo(UnresolvedAttribute(name :: Nil), constant: Literal) => - name -> Option(constant.toString) - case _ => - throw new ParseException("Invalid partition filter specification", ctx) - } + val parts = ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText + val value = Option(pVal.constant).map(visitStringConstant) + name -> value } // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for @@ -211,23 +206,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { parts.toMap } - /** - * Create a partition filter specification. - */ - def visitPartitionFilterSpec(ctx: PartitionSpecContext): Expression = withOrigin(ctx) { - val parts = ctx.expression.asScala.map { pVal => - expression(pVal) match { - case EqualNullSafe(_, _) => - throw new ParseException("'<=>' operator is not allowed in partition specification.", ctx) - case cmp @ BinaryComparison(UnresolvedAttribute(name :: Nil), constant: Literal) => - cmp.withNewChildren(Seq(AttributeReference(name, StringType)(), constant)) - case _ => - throw new ParseException("Invalid partition filter specification", ctx) - } - } - parts.reduceLeft(And) - } - /** * Create a partition specification map without optional values. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 112d812cb6c76..b8be3d17ba444 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -813,7 +813,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } AlterTableDropPartitionCommand( visitTableIdentifier(ctx.tableIdentifier), - ctx.partitionSpec.asScala.map(visitPartitionFilterSpec), + ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), ctx.EXISTS != null, ctx.PURGE != null) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 588aa05c37b49..570a9967871e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryComparison} -import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -419,55 +418,27 @@ case class AlterTableRenamePartitionCommand( */ case class AlterTableDropPartitionCommand( tableName: TableIdentifier, - specs: Seq[Expression], + specs: Seq[TablePartitionSpec], ifExists: Boolean, purge: Boolean) - extends RunnableCommand with PredicateHelper { - - private def isRangeComparison(expr: Expression): Boolean = { - expr.find(e => e.isInstanceOf[BinaryComparison] && !e.isInstanceOf[EqualTo]).isDefined - } + extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) - val resolver = sparkSession.sessionState.conf.resolver DDLUtils.verifyAlterTableType(catalog, table, isView = false) DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, "ALTER TABLE DROP PARTITION") - specs.foreach { expr => - expr.references.foreach { attr => - if (!table.partitionColumnNames.exists(resolver(_, attr.name))) { - throw new AnalysisException(s"${attr.name} is not a valid partition column " + - s"in table ${table.identifier.quotedString}.") - } - } + val normalizedSpecs = specs.map { spec => + PartitioningUtils.normalizePartitionSpec( + spec, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) } - if (specs.exists(isRangeComparison)) { - val partitionSet = specs.flatMap { spec => - val partitions = catalog.listPartitionsByFilter(table.identifier, Seq(spec)).map(_.spec) - if (partitions.isEmpty && !ifExists) { - throw new AnalysisException(s"There is no partition for ${spec.sql}") - } - partitions - }.distinct - catalog.dropPartitions( - table.identifier, partitionSet, ignoreIfNotExists = ifExists, purge = purge) - } else { - val normalizedSpecs = specs.map { expr => - val spec = 
splitConjunctivePredicates(expr).map { - case BinaryComparison(AttributeReference(name, _, _, _), right) => name -> right.toString - }.toMap - PartitioningUtils.normalizePartitionSpec( - spec, - table.partitionColumnNames, - table.identifier.quotedString, - resolver) - } - catalog.dropPartitions( - table.identifier, normalizedSpecs, ignoreIfNotExists = ifExists, purge = purge) - } + catalog.dropPartitions( + table.identifier, normalizedSpecs, ignoreIfNotExists = ifExists, purge = purge) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e81512d1abf84..4f19a2d00b0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -215,14 +215,8 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { if (overwrite.enabled) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { - import org.apache.spark.sql.catalyst.expressions._ - val expressions = deletedPartitions.map { specs => - specs.map { case (key, value) => - EqualTo(AttributeReference(key, StringType)(), Literal.create(value, StringType)) - }.reduceLeft(And) - }.toSeq AlterTableDropPartitionCommand( - l.catalogTable.get.identifier, expressions, + l.catalogTable.get.identifier, deletedPartitions.toSeq, ifExists = true, purge = true).run(t.sparkSession) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 057528bef5084..d31e7aeb3a78a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -21,7 +21,6 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Literal} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.Project @@ -613,12 +612,8 @@ class DDLCommandSuite extends PlanTest { val expected1_table = AlterTableDropPartitionCommand( tableIdent, Seq( - And( - EqualTo(AttributeReference("dt", StringType)(), Literal.create("2008-08-08", StringType)), - EqualTo(AttributeReference("country", StringType)(), Literal.create("us", StringType))), - And( - EqualTo(AttributeReference("dt", StringType)(), Literal.create("2009-09-09", StringType)), - EqualTo(AttributeReference("country", StringType)(), Literal.create("uk", StringType)))), + Map("dt" -> "2008-08-08", "country" -> "us"), + Map("dt" -> "2009-09-09", "country" -> "uk")), ifExists = true, purge = false) val expected2_table = expected1_table.copy(ifExists = false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 15e3927b755af..951e0704148b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ 
-26,7 +26,6 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -226,108 +225,6 @@ class HiveDDLSuite } } - test("SPARK-17732: Drop partitions by filter") { - withTable("sales") { - sql("CREATE TABLE sales(id INT) PARTITIONED BY (country STRING, quarter STRING)") - - for (country <- Seq("US", "CA", "KR")) { - for (quarter <- 1 to 4) { - sql(s"ALTER TABLE sales ADD PARTITION (country = '$country', quarter = '$quarter')") - } - } - - sql("ALTER TABLE sales DROP PARTITION (country < 'KR', quarter > '2')") - checkAnswer(sql("SHOW PARTITIONS sales"), - Row("country=CA/quarter=1") :: - Row("country=CA/quarter=2") :: - Row("country=KR/quarter=1") :: - Row("country=KR/quarter=2") :: - Row("country=KR/quarter=3") :: - Row("country=KR/quarter=4") :: - Row("country=US/quarter=1") :: - Row("country=US/quarter=2") :: - Row("country=US/quarter=3") :: - Row("country=US/quarter=4") :: Nil) - - sql("ALTER TABLE sales DROP PARTITION (country < 'KR'), PARTITION (quarter <= '1')") - checkAnswer(sql("SHOW PARTITIONS sales"), - Row("country=KR/quarter=2") :: - Row("country=KR/quarter=3") :: - Row("country=KR/quarter=4") :: - Row("country=US/quarter=2") :: - Row("country=US/quarter=3") :: - Row("country=US/quarter=4") :: Nil) - - sql("ALTER TABLE sales DROP PARTITION (country='KR', quarter='4')") - sql("ALTER TABLE sales DROP PARTITION (country='US', quarter='3')") - checkAnswer(sql("SHOW PARTITIONS sales"), - Row("country=KR/quarter=2") :: - Row("country=KR/quarter=3") :: - Row("country=US/quarter=2") :: - Row("country=US/quarter=4") :: Nil) - - sql("ALTER TABLE sales DROP PARTITION (quarter <= 2), PARTITION (quarter >= '4')") - checkAnswer(sql("SHOW PARTITIONS sales"), - Row("country=KR/quarter=3") :: Nil) - - // According to the declarative partition spec definitions, this drops the union of target - // partitions without exceptions. Hive raises exceptions because it handles them sequentially. 
- sql("ALTER TABLE sales DROP PARTITION (quarter <= 4), PARTITION (quarter <= '3')") - checkAnswer(sql("SHOW PARTITIONS sales"), Nil) - } - } - - test("SPARK-17732: Error handling for drop partitions by filter") { - withTable("sales") { - sql("CREATE TABLE sales(id INT) PARTITIONED BY (country STRING, quarter STRING)") - - val m = intercept[AnalysisException] { - sql("ALTER TABLE sales DROP PARTITION (unknown = 'KR')") - }.getMessage - assert(m.contains("unknown is not a valid partition column in table")) - - val m2 = intercept[AnalysisException] { - sql("ALTER TABLE sales DROP PARTITION (unknown < 'KR')") - }.getMessage - assert(m2.contains("unknown is not a valid partition column in table")) - - val m3 = intercept[AnalysisException] { - sql("ALTER TABLE sales DROP PARTITION (unknown <=> 'KR')") - }.getMessage - assert(m3.contains("'<=>' operator is not allowed in partition specification")) - - val m4 = intercept[ParseException] { - sql("ALTER TABLE sales DROP PARTITION (unknown <=> upper('KR'))") - }.getMessage - assert(m4.contains("'<=>' operator is not allowed in partition specification")) - - val m5 = intercept[ParseException] { - sql("ALTER TABLE sales DROP PARTITION (country < 'KR', quarter)") - }.getMessage - assert(m5.contains("Invalid partition filter specification")) - - sql(s"ALTER TABLE sales ADD PARTITION (country = 'KR', quarter = '3')") - val m6 = intercept[AnalysisException] { - sql("ALTER TABLE sales DROP PARTITION (quarter <= '4'), PARTITION (quarter <= '2')") - }.getMessage - // The query is not executed because `PARTITION (quarter <= '2')` is invalid. - checkAnswer(sql("SHOW PARTITIONS sales"), - Row("country=KR/quarter=3") :: Nil) - assert(m6.contains("There is no partition for (`quarter` <= '2')")) - } - } - - test("SPARK-17732: Partition filter is not allowed in ADD PARTITION") { - withTable("sales") { - sql("CREATE TABLE sales(id INT) PARTITIONED BY (country STRING, quarter STRING)") - - val m = intercept[ParseException] { - sql("ALTER TABLE sales ADD PARTITION (country = 'US', quarter < '1')") - }.getMessage() - assert(m.contains("Invalid partition filter specification")) - } - } - test("drop views") { withTable("tab1") { val tabName = "tab1" From f8662db72815b9c89f2448511d117e6d224e0b11 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 20 Nov 2016 20:00:59 -0800 Subject: [PATCH 157/534] [HOTFIX][SQL] Fix DDLSuite failure. 
(cherry picked from commit b625a36ebc59cbacc223fc03005bc0f6d296b6e7) Signed-off-by: Reynold Xin --- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index a01073987423e..02d9d15684904 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1426,8 +1426,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("DESCRIBE FUNCTION 'concat'"), Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) " + - "- Returns the concatenation of `str1`, `str2`, ..., `strN`.") :: Nil + Row("Usage: concat(str1, str2, ..., strN) - " + + "Returns the concatenation of str1, str2, ..., strN.") :: Nil ) // extended mode checkAnswer( From fb4e6359d1fdb9e4f05fcfa03839024e8b91b47a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 21 Nov 2016 12:05:01 +0800 Subject: [PATCH 158/534] [SPARK-18467][SQL] Extracts method for preparing arguments from StaticInvoke, Invoke and NewInstance and modify to short circuit if arguments have null when `needNullCheck == true`. ## What changes were proposed in this pull request? This pr extracts method for preparing arguments from `StaticInvoke`, `Invoke` and `NewInstance` and modify to short circuit if arguments have `null` when `propageteNull == true`. The steps are as follows: 1. Introduce `InvokeLike` to extract common logic from `StaticInvoke`, `Invoke` and `NewInstance` to prepare arguments. `StaticInvoke` and `Invoke` had a risk to exceed 64kb JVM limit to prepare arguments but after this patch they can handle them because they share the preparing code of NewInstance, which handles the limit well. 2. Remove unneeded null checking and fix nullability of `NewInstance`. Avoid some of nullabilty checking which are not needed because the expression is not nullable. 3. Modify to short circuit if arguments have `null` when `needNullCheck == true`. If `needNullCheck == true`, preparing arguments can be skipped if we found one of them is `null`, so modified to short circuit in the case. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15901 from ueshin/issues/SPARK-18467. (cherry picked from commit 658547974915ebcaae83e13e4c3bdf68d5426fda) Signed-off-by: Wenchen Fan --- .../expressions/objects/objects.scala | 163 +++++++++++------- 1 file changed, 101 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0e3d99127ed56..0b36091ece1bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -32,6 +32,78 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ +/** + * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. 
+ */ +trait InvokeLike extends Expression with NonSQLExpression { + + def arguments: Seq[Expression] + + def propagateNull: Boolean + + protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable) + + /** + * Prepares codes for arguments. + * + * - generate codes for argument. + * - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments. + * - avoid some of nullabilty checking which are not needed because the expression is not + * nullable. + * - when needNullCheck == true, short circuit if we found one of arguments is null because + * preparing rest of arguments can be skipped in the case. + * + * @param ctx a [[CodegenContext]] + * @return (code to prepare arguments, argument string, result of argument null check) + */ + def prepareArguments(ctx: CodegenContext): (String, String, String) = { + + val resultIsNull = if (needNullCheck) { + val resultIsNull = ctx.freshName("resultIsNull") + ctx.addMutableState("boolean", resultIsNull, "") + resultIsNull + } else { + "false" + } + val argValues = arguments.map { e => + val argValue = ctx.freshName("argValue") + ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") + argValue + } + + val argCodes = if (needNullCheck) { + val reset = s"$resultIsNull = false;" + val argCodes = arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + val updateResultIsNull = if (e.nullable) { + s"$resultIsNull = ${expr.isNull};" + } else { + "" + } + s""" + if (!$resultIsNull) { + ${expr.code} + $updateResultIsNull + ${argValues(i)} = ${expr.value}; + } + """ + } + reset +: argCodes + } else { + arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + s""" + ${expr.code} + ${argValues(i)} = ${expr.value}; + """ + } + } + val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + + (argCode, argValues.mkString(", "), resultIsNull) + } +} + /** * Invokes a static function, returning the result. By default, any of the arguments being null * will result in returning null instead of calling the function. @@ -50,7 +122,7 @@ case class StaticInvoke( dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression with NonSQLExpression { + propagateNull: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") @@ -62,16 +134,10 @@ case class StaticInvoke( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") - val callFunc = s"$objectName.$functionName($argString)" + val (argCode, argString, resultIsNull) = prepareArguments(ctx) - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};" - } else { - s"boolean ${ev.isNull} = false;" - } + val callFunc = s"$objectName.$functionName($argString)" // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. @@ -82,9 +148,9 @@ case class StaticInvoke( } val code = s""" - ${argGen.map(_.code).mkString("\n")} - $setIsNull - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc; + $argCode + boolean ${ev.isNull} = $resultIsNull; + final $javaType ${ev.value} = $resultIsNull ? 
${ctx.defaultValue(dataType)} : $callFunc; $postNullCheck """ ev.copy(code = code) @@ -103,13 +169,15 @@ case class StaticInvoke( * @param functionName The name of the method to call. * @param dataType The expected return type of the function. * @param arguments An optional list of expressions, whos evaluation will be passed to the function. + * @param propagateNull When true, and any of the arguments is null, null will be returned instead + * of calling the function. */ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression with NonSQLExpression { + propagateNull: Boolean = true) extends InvokeLike { override def nullable: Boolean = true override def children: Seq[Expression] = targetObject +: arguments @@ -131,8 +199,8 @@ case class Invoke( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val obj = targetObject.genCode(ctx) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") + + val (argCode, argString, resultIsNull) = prepareArguments(ctx) val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty @@ -164,12 +232,6 @@ case class Invoke( """ } - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" - } else { - s"boolean ${ev.isNull} = ${obj.isNull};" - } - // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. val postNullCheck = if (ctx.defaultValue(dataType) == "null") { @@ -177,15 +239,19 @@ case class Invoke( } else { "" } + val code = s""" ${obj.code} - ${argGen.map(_.code).mkString("\n")} - $setIsNull + boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - $evaluate + if (!${obj.isNull}) { + $argCode + ${ev.isNull} = $resultIsNull; + if (!${ev.isNull}) { + $evaluate + } + $postNullCheck } - $postNullCheck """ ev.copy(code = code) } @@ -223,10 +289,10 @@ case class NewInstance( arguments: Seq[Expression], propagateNull: Boolean, dataType: DataType, - outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression { + outerPointer: Option[() => AnyRef]) extends InvokeLike { private val className = cls.getName - override def nullable: Boolean = propagateNull + override def nullable: Boolean = needNullCheck override def children: Seq[Expression] = arguments @@ -245,52 +311,25 @@ case class NewInstance( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argIsNulls = ctx.freshName("argIsNulls") - ctx.addMutableState("boolean[]", argIsNulls, - s"$argIsNulls = new boolean[${arguments.size}];") - val argValues = arguments.zipWithIndex.map { case (e, i) => - val argValue = ctx.freshName("argValue") - ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") - argValue - } - val argCodes = arguments.zipWithIndex.map { case (e, i) => - val expr = e.genCode(ctx) - expr.code + s""" - $argIsNulls[$i] = ${expr.isNull}; - ${argValues(i)} = ${expr.value}; - """ - } - val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + val (argCode, argString, resultIsNull) = prepareArguments(ctx) val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) - var isNull = ev.isNull - val setIsNull 
= if (propagateNull && arguments.nonEmpty) { - s""" - boolean $isNull = false; - for (int idx = 0; idx < ${arguments.length}; idx++) { - if ($argIsNulls[idx]) { $isNull = true; break; } - } - """ - } else { - isNull = "false" - "" - } + ev.isNull = resultIsNull val constructorCall = outer.map { gen => - s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})""" + s"${gen.value}.new ${cls.getSimpleName}($argString)" }.getOrElse { - s"new $className(${argValues.mkString(", ")})" + s"new $className($argString)" } val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - $setIsNull - final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall; - """ - ev.copy(code = code, isNull = isNull) + final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + """ + ev.copy(code = code) } override def toString: String = s"newInstance($cls)" From 31002e4a77ca56492f41bf35e7c8f263d767d3aa Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 21 Nov 2016 05:36:49 -0800 Subject: [PATCH 159/534] [SPARK-18282][ML][PYSPARK] Add python clustering summaries for GMM and BKM ## What changes were proposed in this pull request? Add model summary APIs for `GaussianMixtureModel` and `BisectingKMeansModel` in pyspark. ## How was this patch tested? Unit tests. Author: sethah Closes #15777 from sethah/pyspark_cluster_summaries. (cherry picked from commit e811fbf9ed131bccbc46f3c5701c4ff317222fd9) Signed-off-by: Yanbo Liang --- .../classification/LogisticRegression.scala | 11 +- .../spark/ml/clustering/BisectingKMeans.scala | 9 +- .../spark/ml/clustering/GaussianMixture.scala | 9 +- .../apache/spark/ml/clustering/KMeans.scala | 9 +- .../GeneralizedLinearRegression.scala | 11 +- .../ml/regression/LinearRegression.scala | 14 +- .../LogisticRegressionSuite.scala | 2 + .../ml/clustering/BisectingKMeansSuite.scala | 3 + .../ml/clustering/GaussianMixtureSuite.scala | 3 + .../spark/ml/clustering/KMeansSuite.scala | 3 + .../GeneralizedLinearRegressionSuite.scala | 2 + .../ml/regression/LinearRegressionSuite.scala | 2 + python/pyspark/ml/classification.py | 15 +- python/pyspark/ml/clustering.py | 162 +++++++++++++++++- python/pyspark/ml/regression.py | 16 +- python/pyspark/ml/tests.py | 32 ++++ 16 files changed, 256 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f58efd36a1c66..d07b4adebb08f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -648,7 +648,7 @@ class LogisticRegression @Since("1.2.0") ( $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(logRegSummary) + model.setSummary(Some(logRegSummary)) } else { model } @@ -790,9 +790,9 @@ class LogisticRegressionModel private[spark] ( } } - private[classification] def setSummary( - summary: LogisticRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[classification] + def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -887,8 +887,7 @@ class LogisticRegressionModel private[spark] ( override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, isMultinomial), extra) - if 
(trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f8a606d60b2aa..e6ca3aedffd9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -95,8 +95,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -132,8 +131,8 @@ class BisectingKMeansModel private[ml] ( private var trainingSummary: Option[BisectingKMeansSummary] = None - private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -265,7 +264,7 @@ class BisectingKMeans @Since("2.0.0") ( val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index c6035cc4c9647..92d0b7d085f12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -90,8 +90,7 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -150,8 +149,8 @@ class GaussianMixtureModel private[ml] ( private var trainingSummary: Option[GaussianMixtureSummary] = None - private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = { + this.trainingSummary = summary this } @@ -340,7 +339,7 @@ class GaussianMixture @Since("2.0.0") ( .setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logNumFeatures(model.gaussians.head.mean.size) instr.logSuccess(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 26505b4cc1501..152bd13b7a17a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -110,8 
+110,7 @@ class KMeansModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { val copied = copyValues(new KMeansModel(uid, parentModel), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } /** @group setParam */ @@ -165,8 +164,8 @@ class KMeansModel private[ml] ( private var trainingSummary: Option[KMeansSummary] = None - private[clustering] def setSummary(summary: KMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 736fd3b9e0f64..3f9de1fe74c9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -270,7 +270,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, wlsModel.diagInvAtWA.toArray, 1, getSolver) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). 
@@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("2.0.0") @@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.nonEmpty private[regression] - def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] ( override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(parent) + copied.setSummary(trainingSummary).setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index da7ce6b46f2ab..8ea5e1e6c453a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -225,7 +225,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) - return lrModel.setSummary(trainingSummary) + return lrModel.setSummary(Some(trainingSummary)) } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -278,7 +278,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), Array(0D)) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. 
" + "Model cannot be regularized.") @@ -400,7 +400,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), objectiveHistory) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("1.4.0") @@ -446,8 +446,9 @@ class LinearRegressionModel private[ml] ( throw new SparkException("No training summary available for this LinearRegressionModel") } - private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[regression] + def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -490,8 +491,7 @@ class LinearRegressionModel private[ml] ( @Since("1.4.0") override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 2877285eb4d59..e360542eae2ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -147,6 +147,8 @@ class LogisticRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) } test("empty probabilityCol") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 49797d938d751..fc491cd6161fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -109,6 +109,9 @@ class BisectingKMeansSuite assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 7165b63ed3b96..07299123f8a47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -111,6 +111,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 73972557d2631..c1b7242e11a8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) 
assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("KMeansModel transform with non-default feature and prediction cols") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 6a4ac1735b2cb..9b0fa67630d2e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -197,6 +197,8 @@ class GeneralizedLinearRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index df97d0b2ae7ad..0be82742a33be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -146,6 +146,8 @@ class LinearRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) model.transform(datasetWithDenseFeature) .select("label", "prediction") diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 56c8c62259e79..83e1e89347660 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -309,13 +309,16 @@ def interceptVector(self): @since("2.0.0") def summary(self): """ - Gets summary (e.g. residuals, mse, r-squared ) of model on - training set. An exception is thrown if - `trainingSummary is None`. + Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. 
""" - java_blrt_summary = self._call_java("summary") - # Note: Once multiclass is added, update this to return correct summary - return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + if self.hasSummary: + java_blrt_summary = self._call_java("summary") + # Note: Once multiclass is added, update this to return correct summary + return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 7632f05c3b68c..e58ec1e7ac296 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -17,16 +17,74 @@ from pyspark import since, keyword_only from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc -__all__ = ['BisectingKMeans', 'BisectingKMeansModel', +__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', 'KMeans', 'KMeansModel', - 'GaussianMixture', 'GaussianMixtureModel', + 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] +class ClusteringSummary(JavaWrapper): + """ + .. note:: Experimental + + Clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + + @property + @since("2.1.0") + def predictionCol(self): + """ + Name for column of predicted clusters in `predictions`. + """ + return self._call_java("predictionCol") + + @property + @since("2.1.0") + def predictions(self): + """ + DataFrame produced by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.1.0") + def featuresCol(self): + """ + Name for column of features in `predictions`. + """ + return self._call_java("featuresCol") + + @property + @since("2.1.0") + def k(self): + """ + The number of clusters the model was trained with. + """ + return self._call_java("k") + + @property + @since("2.1.0") + def cluster(self): + """ + DataFrame of predicted cluster centers for each training data point. + """ + return self._call_java("cluster") + + @property + @since("2.1.0") + def clusterSizes(self): + """ + Size of (number of data points in) each cluster. + """ + return self._call_java("clusterSizes") + + class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -56,6 +114,28 @@ def gaussiansDF(self): """ return self._call_java("gaussiansDF") + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + return GaussianMixtureSummary(self._call_java("summary")) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + @inherit_doc class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, @@ -92,6 +172,13 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> gm = GaussianMixture(k=3, tol=0.0001, ... 
maxIter=10, seed=10) >>> model = gm.fit(df) + >>> model.hasSummary + True + >>> summary = model.summary + >>> summary.k + 3 + >>> summary.clusterSizes + [2, 2, 2] >>> weights = model.weights >>> len(weights) 3 @@ -118,6 +205,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> model_path = temp_path + "/gmm_model" >>> model.save(model_path) >>> model2 = GaussianMixtureModel.load(model_path) + >>> model2.hasSummary + False >>> model2.weights == model.weights True >>> model2.gaussiansDF.show() @@ -181,6 +270,32 @@ def getK(self): return self.getOrDefault(self.k) +class GaussianMixtureSummary(ClusteringSummary): + """ + .. note:: Experimental + + Gaussian mixture clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + + @property + @since("2.1.0") + def probabilityCol(self): + """ + Name for column of predicted probability of each cluster in `predictions`. + """ + return self._call_java("probabilityCol") + + @property + @since("2.1.0") + def probability(self): + """ + DataFrame of probabilities of each cluster for each training data point. + """ + return self._call_java("probability") + + class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by KMeans. @@ -346,6 +461,27 @@ def computeCost(self, dataset): """ return self._call_java("computeCost", dataset) + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + return BisectingKMeansSummary(self._call_java("summary")) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + @inherit_doc class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, @@ -373,6 +509,13 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte 2 >>> model.computeCost(df) 2.000... + >>> model.hasSummary + True + >>> summary = model.summary + >>> summary.k + 2 + >>> summary.clusterSizes + [2, 2] >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction @@ -387,6 +530,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) + >>> model2.hasSummary + False >>> model.clusterCenters()[0] == model2.clusterCenters()[0] array([ True, True], dtype=bool) >>> model.clusterCenters()[1] == model2.clusterCenters()[1] @@ -460,6 +605,17 @@ def _create_model(self, java_model): return BisectingKMeansModel(java_model) +class BisectingKMeansSummary(ClusteringSummary): + """ + .. note:: Experimental + + Bisecting KMeans clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + pass + + @inherit_doc class LDAModel(JavaModel): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0bc319ca4d601..385391ba53fd4 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -160,8 +160,12 @@ def summary(self): training set. An exception is thrown if `trainingSummary is None`. 
""" - java_lrt_summary = self._call_java("summary") - return LinearRegressionTrainingSummary(java_lrt_summary) + if self.hasSummary: + java_lrt_summary = self._call_java("summary") + return LinearRegressionTrainingSummary(java_lrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") @@ -1459,8 +1463,12 @@ def summary(self): training set. An exception is thrown if `trainingSummary is None`. """ - java_glrt_summary = self._call_java("summary") - return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + if self.hasSummary: + java_glrt_summary = self._call_java("summary") + return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9d46cc3b4ae64..c0f0d4073564e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1097,6 +1097,38 @@ def test_logistic_regression_summary(self): sameSummary = model.evaluate(df) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + def test_gaussian_mixture_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + gmm = GaussianMixture(k=2) + model = gmm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertTrue(isinstance(s.probability, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + + def test_bisecting_kmeans_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + bkm = BisectingKMeans(k=2) + model = bkm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + class OneVsRestTests(SparkSessionTestCase): From 251a9927646f367ca2cf75a87e80ce1c061a8f27 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 21 Nov 2016 05:50:35 -0800 Subject: [PATCH 160/534] [SPARK-18398][SQL] Fix nullabilities of MapObjects and ExternalMapToCatalyst. ## What changes were proposed in this pull request? The nullabilities of `MapObject` can be made more strict by relying on `inputObject.nullable` and `lambdaFunction.nullable`. Also `ExternalMapToCatalyst.dataType` can be made more strict by relying on `valueConverter.nullable`. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15840 from ueshin/issues/SPARK-18398. 
(cherry picked from commit 9f262ae163b6dca6526665b3ad12b3b2ea8fb873) Signed-off-by: Herman van Hovell --- .../spark/sql/catalyst/expressions/objects/objects.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0b36091ece1bf..5c27179ec3b46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -461,14 +461,15 @@ case class MapObjects private( lambdaFunction: Expression, inputData: Expression) extends Expression with NonSQLExpression { - override def nullable: Boolean = true + override def nullable: Boolean = inputData.nullable override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def dataType: DataType = ArrayType(lambdaFunction.dataType) + override def dataType: DataType = + ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) @@ -642,7 +643,8 @@ case class ExternalMapToCatalyst private( override def foldable: Boolean = false - override def dataType: MapType = MapType(keyConverter.dataType, valueConverter.dataType) + override def dataType: MapType = MapType( + keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") From b0a73c9be3b691f95d2f6ace3d6304db7f69705f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 21 Nov 2016 16:14:59 -0500 Subject: [PATCH 161/534] [SPARK-18517][SQL] DROP TABLE IF EXISTS should not warn for non-existing tables ## What changes were proposed in this pull request? Currently, `DROP TABLE IF EXISTS` shows warning for non-existing tables. However, it had better be quiet for this case by definition of the command. **BEFORE** ```scala scala> sql("DROP TABLE IF EXISTS nonexist") 16/11/20 20:48:26 WARN DropTableCommand: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: Table or view 'nonexist' not found in database 'default'; ``` **AFTER** ```scala scala> sql("DROP TABLE IF EXISTS nonexist") res0: org.apache.spark.sql.DataFrame = [] ``` ## How was this patch tested? Manual because this is related to the warning messages instead of exceptions. Author: Dongjoon Hyun Closes #15953 from dongjoon-hyun/SPARK-18517. 
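
Since the change was only verified manually, the quiet path could also be covered by a regression test along the following lines. This is a sketch only; the suite shown here is a hypothetical addition, not part of this patch:

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical check: DROP TABLE IF EXISTS on a missing table should complete
// without raising (the absence of the WARN line still has to be confirmed in the
// logs), while the plain DROP TABLE keeps failing as before.
class DropTableIfExistsSuite extends SparkFunSuite with SharedSQLContext {
  test("SPARK-18517: DROP TABLE IF EXISTS is quiet for a non-existing table") {
    spark.sql("DROP TABLE IF EXISTS nonexist")
    intercept[AnalysisException] {
      spark.sql("DROP TABLE nonexist")
    }
  }
}
```
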
(cherry picked from commit ddd02f50bb7458410d65427321efc75da5e65224) Signed-off-by: Andrew Or --- .../scala/org/apache/spark/sql/execution/command/ddl.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 570a9967871e9..0f126d0200eff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -202,6 +202,7 @@ case class DropTableCommand( sparkSession.sharedState.cacheManager.uncacheQuery( sparkSession.table(tableName.quotedString)) } catch { + case _: NoSuchTableException if ifExists => case NonFatal(e) => log.warn(e.toString, e) } catalog.refreshTable(tableName) From 406f33987ac078fb20d2f5e81b7e1f646ea53fed Mon Sep 17 00:00:00 2001 From: Gabriel Huang Date: Mon, 21 Nov 2016 16:08:34 -0500 Subject: [PATCH 162/534] [SPARK-18361][PYSPARK] Expose RDD localCheckpoint in PySpark ## What changes were proposed in this pull request? Expose RDD's localCheckpoint() and associated functions in PySpark. ## How was this patch tested? I added a UnitTest in python/pyspark/tests.py which passes. I certify that this is my original work, and I license it to the project under the project's open source license. Gabriel HUANG Developer at Cardabel (http://cardabel.com/) Author: Gabriel Huang Closes #15811 from gabrielhuang/pyspark-localcheckpoint. --- python/pyspark/rdd.py | 33 ++++++++++++++++++++++++++++++++- python/pyspark/tests.py | 17 +++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 641787ee20e0c..f21a364df9100 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -263,13 +263,44 @@ def checkpoint(self): def isCheckpointed(self): """ - Return whether this RDD has been checkpointed or not + Return whether this RDD is checkpointed and materialized, either reliably or locally. """ return self._jrdd.rdd().isCheckpointed() + def localCheckpoint(self): + """ + Mark this RDD for local checkpointing using Spark's existing caching layer. + + This method is for users who wish to truncate RDD lineages while skipping the expensive + step of replicating the materialized data in a reliable distributed file system. This is + useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX). + + Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed + data is written to ephemeral local storage in the executors instead of to a reliable, + fault-tolerant storage. The effect is that if an executor fails during the computation, + the checkpointed data may no longer be accessible, causing an irrecoverable job failure. + + This is NOT safe to use with dynamic allocation, which removes executors along + with their cached blocks. 
If you must use both features, you are advised to set + L{spark.dynamicAllocation.cachedExecutorIdleTimeout} to a high value. + + The checkpoint directory set through L{SparkContext.setCheckpointDir()} is not used. + """ + self._jrdd.rdd().localCheckpoint() + + def isLocallyCheckpointed(self): + """ + Return whether this RDD is marked for local checkpointing. + + Exposed for testing. + """ + return self._jrdd.rdd().isLocallyCheckpointed() + def getCheckpointFile(self): """ Gets the name of the file to which this RDD was checkpointed + + Not defined if RDD is checkpointed locally. """ checkpointFile = self._jrdd.rdd().getCheckpointFile() if checkpointFile.isDefined(): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 3e0bd16d85ca4..ab4bef8329cd0 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -390,6 +390,23 @@ def test_checkpoint_and_restore(self): self.assertEqual([1, 2, 3, 4], recovered.collect()) +class LocalCheckpointTests(ReusedPySparkTestCase): + + def test_basic_localcheckpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) + + flatMappedRDD.localCheckpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + + class AddFileTests(PySparkTestCase): def test_add_py_file(self): From 2afc18be23150d283361d374caf8cbfd3da63c9c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 21 Nov 2016 13:23:32 -0800 Subject: [PATCH 163/534] [SPARK-17765][SQL] Support for writing out user-defined type in ORC datasource ## What changes were proposed in this pull request? This PR adds the support for `UserDefinedType` when writing out instead of throwing `ClassCastException` in ORC data source. In more details, `OrcStruct` is being created based on string from`DataType.catalogString`. For user-defined type, it seems it returns `sqlType.simpleString` for `catalogString` by default[1]. However, during type-dispatching to match the output with the schema, it tries to cast to, for example, `StructType`[2]. So, running the codes below (`MyDenseVector` was borrowed[3]) : ``` scala val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))) val udtDF = data.toDF("id", "vectors") udtDF.write.orc("/tmp/test.orc") ``` ends up throwing an exception as below: ``` java.lang.ClassCastException: org.apache.spark.sql.UDT$MyDenseVectorUDT cannot be cast to org.apache.spark.sql.types.ArrayType at org.apache.spark.sql.hive.HiveInspectors$class.wrapperFor(HiveInspectors.scala:381) at org.apache.spark.sql.hive.orc.OrcSerializer.wrapperFor(OrcFileFormat.scala:164) ... ``` So, this PR uses `UserDefinedType.sqlType` during finding the correct converter when writing out in ORC data source. [1]https://github.com/apache/spark/blob/dfdcab00c7b6200c22883baa3ebc5818be09556f/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala#L95 [2]https://github.com/apache/spark/blob/d2dc8c4a162834818190ffd82894522c524ca3e5/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala#L326 [3]https://github.com/apache/spark/blob/2bfed1a0c5be7d0718fd574a4dad90f4f6b44be7/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala#L38-L70 ## How was this patch tested? Unit tests in `OrcQuerySuite`. 
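
For reference, the round trip exercised by the new test looks roughly like this in a spark-shell session; `MyDenseVector` is the example UDT borrowed from Spark's test code (as in the snippet above), and the output path is arbitrary:

```scala
import spark.implicits._

// Write a DataFrame with a UDT column to ORC and read it back with an explicit
// schema; before this patch the write itself failed with the ClassCastException
// shown above.
val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))))
val udtDF = data.toDF("id", "vectors")
udtDF.write.orc("/tmp/test.orc")

val readBack = spark.read.schema(udtDF.schema).orc("/tmp/test.orc")
readBack.show()   // expected to return the original id/vectors rows
```
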
Author: hyukjinkwon Closes #15361 from HyukjinKwon/SPARK-17765. (cherry picked from commit a2d464770cd183daa7d727bf377bde9c21e29e6a) Signed-off-by: Reynold Xin --- .../org/apache/spark/sql/hive/HiveInspectors.scala | 3 +++ .../org/apache/spark/sql/hive/orc/OrcQuerySuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index e303065127c3b..52aa1088acd4a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -246,6 +246,9 @@ private[hive] trait HiveInspectors { * Wraps with Hive types based on object inspector. */ protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { + case _ if dataType.isInstanceOf[UserDefinedType[_]] => + val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType + wrapperFor(oi, sqlType) case x: ConstantObjectInspector => (o: Any) => x.getWritableConstantValue diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index a628977af2f4e..b8761e9de2886 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -93,6 +93,16 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("Read/write UserDefinedType") { + withTempPath { path => + val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))) + val udtDF = data.toDF("id", "vectors") + udtDF.write.orc(path.getAbsolutePath) + val readBack = spark.read.schema(udtDF.schema).orc(path.getAbsolutePath) + checkAnswer(udtDF, readBack) + } + } + test("Creating case class RDD table") { val data = (1 to 100).map(i => (i, s"val_$i")) sparkContext.parallelize(data).toDF().createOrReplaceTempView("t") From 6dbe44891458b497c1ad4df8d8358e326fb3f795 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 21 Nov 2016 17:24:02 -0800 Subject: [PATCH 164/534] [SPARK-18493] Add missing python APIs: withWatermark and checkpoint to dataframe ## What changes were proposed in this pull request? This PR adds two of the newly added methods of `Dataset`s to Python: `withWatermark` and `checkpoint` ## How was this patch tested? Doc tests Author: Burak Yavuz Closes #15921 from brkyvz/py-watermark. (cherry picked from commit 97a8239a625df455d2c439f3628a529d6d9413ca) Signed-off-by: Shixiong Zhu --- python/pyspark/sql/dataframe.py | 57 ++++++++++++++++++- .../scala/org/apache/spark/sql/Dataset.scala | 10 +++- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 38998900837cf..6fe622643291e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -322,6 +322,54 @@ def show(self, n=20, truncate=True): def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + @since(2.1) + def checkpoint(self, eager=True): + """Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the + logical plan of this DataFrame, which is especially useful in iterative algorithms where the + plan may grow exponentially. It will be saved to files inside the checkpoint + directory set with L{SparkContext.setCheckpointDir()}. 
+ + :param eager: Whether to checkpoint this DataFrame immediately + + .. note:: Experimental + """ + jdf = self._jdf.checkpoint(eager) + return DataFrame(jdf, self.sql_ctx) + + @since(2.1) + def withWatermark(self, eventTime, delayThreshold): + """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point + in time before which we assume no more late data is going to arrive. + + Spark will use this watermark for several purposes: + - To know when a given time window aggregation can be finalized and thus can be emitted + when using output modes that do not allow updates. + + - To minimize the amount of state that we need to keep for on-going aggregations. + + The current watermark is computed by looking at the `MAX(eventTime)` seen across + all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost + of coordinating this value across partitions, the actual watermark used is only guaranteed + to be at least `delayThreshold` behind the actual event time. In some cases we may still + process records that arrive more than `delayThreshold` late. + + :param eventTime: the name of the column that contains the event time of the row. + :param delayThreshold: the minimum delay to wait to data to arrive late, relative to the + latest record that has been processed in the form of an interval + (e.g. "1 minute" or "5 hours"). + + .. note:: Experimental + + >>> sdf.select('name', sdf.time.cast('timestamp')).withWatermark('time', '10 minutes') + DataFrame[name: string, time: timestamp] + """ + if not eventTime or type(eventTime) is not str: + raise TypeError("eventTime should be provided as a string") + if not delayThreshold or type(delayThreshold) is not str: + raise TypeError("delayThreshold should be provided as a string interval") + jdf = self._jdf.withWatermark(eventTime, delayThreshold) + return DataFrame(jdf, self.sql_ctx) + @since(1.3) def count(self): """Returns the number of rows in this :class:`DataFrame`. @@ -1626,6 +1674,7 @@ def _test(): from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext, SparkSession import pyspark.sql.dataframe + from pyspark.sql.functions import from_unixtime globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc @@ -1638,9 +1687,11 @@ def _test(): globs['df3'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), - Row(name='Bob', age=5, height=None), - Row(name='Tom', age=None, height=None), - Row(name=None, age=None, height=None)]).toDF() + Row(name='Bob', age=5, height=None), + Row(name='Tom', age=None, height=None), + Row(name=None, age=None, height=None)]).toDF() + globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846), + Row(name='Bob', time=1479442946)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3c75a6a45ec86..7ba6ffce278cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -485,7 +485,10 @@ class Dataset[T] private[sql]( def isStreaming: Boolean = logicalPlan.isStreaming /** - * Returns a checkpointed version of this Dataset. + * Eagerly checkpoint a Dataset and return the new Dataset. 
Checkpointing can be used to truncate + * the logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. * * @group basic * @since 2.1.0 @@ -495,7 +498,10 @@ class Dataset[T] private[sql]( def checkpoint(): Dataset[T] = checkpoint(eager = true) /** - * Returns a checkpointed version of this Dataset. + * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the + * logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. * * @group basic * @since 2.1.0 From aaa2a173a81868a92d61bcc9420961aaa7eaeb57 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Mon, 21 Nov 2016 21:14:13 -0800 Subject: [PATCH 165/534] [SPARK-18425][STRUCTURED STREAMING][TESTS] Test `CompactibleFileStreamLog` directly ## What changes were proposed in this pull request? Right now we are testing the most of `CompactibleFileStreamLog` in `FileStreamSinkLogSuite` (because `FileStreamSinkLog` once was the only subclass of `CompactibleFileStreamLog`, but now it's not the case any more). Let's refactor the tests so that `CompactibleFileStreamLog` is directly tested, making future changes (like https://github.com/apache/spark/pull/15828, https://github.com/apache/spark/pull/15827) to `CompactibleFileStreamLog` much easier to test and much easier to review. ## How was this patch tested? the PR itself is about tests Author: Liwei Lin Closes #15870 from lw-lin/test-compact-1113. (cherry picked from commit ebeb0830a3a4837c7354a0eee667b9f5fad389c5) Signed-off-by: Shixiong Zhu --- .../CompactibleFileStreamLogSuite.scala | 216 +++++++++++++++++- .../streaming/FileStreamSinkLogSuite.scala | 68 ------ 2 files changed, 214 insertions(+), 70 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 2cd2157b293cb..e511fda57912c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -17,12 +17,79 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.SparkFunSuite +import java.io._ +import java.nio.charset.StandardCharsets._ -class CompactibleFileStreamLogSuite extends SparkFunSuite { +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.execution.streaming.FakeFileSystem._ +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SharedSQLContext + +class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { + + /** To avoid caching of FS objects */ + override protected val sparkConf = + new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") import CompactibleFileStreamLog._ + /** -- testing of `object CompactibleFileStreamLog` begins -- */ + + test("getBatchIdFromFileName") { + assert(1234L === getBatchIdFromFileName("1234")) + assert(1234L === getBatchIdFromFileName("1234.compact")) + intercept[NumberFormatException] { + getBatchIdFromFileName("1234a") + } + } + + test("isCompactionBatch") { + assert(false === isCompactionBatch(0, 
compactInterval = 3)) + assert(false === isCompactionBatch(1, compactInterval = 3)) + assert(true === isCompactionBatch(2, compactInterval = 3)) + assert(false === isCompactionBatch(3, compactInterval = 3)) + assert(false === isCompactionBatch(4, compactInterval = 3)) + assert(true === isCompactionBatch(5, compactInterval = 3)) + } + + test("nextCompactionBatchId") { + assert(2 === nextCompactionBatchId(0, compactInterval = 3)) + assert(2 === nextCompactionBatchId(1, compactInterval = 3)) + assert(5 === nextCompactionBatchId(2, compactInterval = 3)) + assert(5 === nextCompactionBatchId(3, compactInterval = 3)) + assert(5 === nextCompactionBatchId(4, compactInterval = 3)) + assert(8 === nextCompactionBatchId(5, compactInterval = 3)) + } + + test("getValidBatchesBeforeCompactionBatch") { + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(0, compactInterval = 3) + } + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(1, compactInterval = 3) + } + assert(Seq(0, 1) === getValidBatchesBeforeCompactionBatch(2, compactInterval = 3)) + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(3, compactInterval = 3) + } + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(4, compactInterval = 3) + } + assert(Seq(2, 3, 4) === getValidBatchesBeforeCompactionBatch(5, compactInterval = 3)) + } + + test("getAllValidBatches") { + assert(Seq(0) === getAllValidBatches(0, compactInterval = 3)) + assert(Seq(0, 1) === getAllValidBatches(1, compactInterval = 3)) + assert(Seq(2) === getAllValidBatches(2, compactInterval = 3)) + assert(Seq(2, 3) === getAllValidBatches(3, compactInterval = 3)) + assert(Seq(2, 3, 4) === getAllValidBatches(4, compactInterval = 3)) + assert(Seq(5) === getAllValidBatches(5, compactInterval = 3)) + assert(Seq(5, 6) === getAllValidBatches(6, compactInterval = 3)) + assert(Seq(5, 6, 7) === getAllValidBatches(7, compactInterval = 3)) + assert(Seq(8) === getAllValidBatches(8, compactInterval = 3)) + } + test("deriveCompactInterval") { // latestCompactBatchId(4) + 1 <= default(5) // then use latestestCompactBatchId + 1 === 5 @@ -30,4 +97,149 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite { // First divisor of 10 greater than 4 === 5 assert(5 === deriveCompactInterval(4, 9)) } + + /** -- testing of `object CompactibleFileStreamLog` ends -- */ + + test("batchIdToPath") { + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = Long.MaxValue, + defaultCompactInterval = 3, + compactibleLog => { + assert("0" === compactibleLog.batchIdToPath(0).getName) + assert("1" === compactibleLog.batchIdToPath(1).getName) + assert("2.compact" === compactibleLog.batchIdToPath(2).getName) + assert("3" === compactibleLog.batchIdToPath(3).getName) + assert("4" === compactibleLog.batchIdToPath(4).getName) + assert("5.compact" === compactibleLog.batchIdToPath(5).getName) + }) + } + + test("serialize") { + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = Long.MaxValue, + defaultCompactInterval = 3, + compactibleLog => { + val logs = Array("entry_1", "entry_2", "entry_3") + val expected = s"""${FakeCompactibleFileStreamLog.VERSION} + |"entry_1" + |"entry_2" + |"entry_3"""".stripMargin + val baos = new ByteArrayOutputStream() + compactibleLog.serialize(logs, baos) + assert(expected === baos.toString(UTF_8.name())) + + baos.reset() + compactibleLog.serialize(Array(), baos) + assert(FakeCompactibleFileStreamLog.VERSION === baos.toString(UTF_8.name())) + }) + } + + test("deserialize") { + withFakeCompactibleFileStreamLog( + 
fileCleanupDelayMs = Long.MaxValue, + defaultCompactInterval = 3, + compactibleLog => { + val logs = s"""${FakeCompactibleFileStreamLog.VERSION} + |"entry_1" + |"entry_2" + |"entry_3"""".stripMargin + val expected = Array("entry_1", "entry_2", "entry_3") + assert(expected === + compactibleLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8)))) + + assert(Nil === + compactibleLog.deserialize( + new ByteArrayInputStream(FakeCompactibleFileStreamLog.VERSION.getBytes(UTF_8)))) + }) + } + + testWithUninterruptibleThread("compact") { + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = Long.MaxValue, + defaultCompactInterval = 3, + compactibleLog => { + for (batchId <- 0 to 10) { + compactibleLog.add(batchId, Array("some_path_" + batchId)) + val expectedFiles = (0 to batchId).map { id => "some_path_" + id } + assert(compactibleLog.allFiles() === expectedFiles) + if (isCompactionBatch(batchId, 3)) { + // Since batchId is a compaction batch, the batch log file should contain all logs + assert(compactibleLog.get(batchId).getOrElse(Nil) === expectedFiles) + } + } + }) + } + + testWithUninterruptibleThread("delete expired file") { + // Set `fileCleanupDelayMs` to 0 so that we can detect the deleting behaviour deterministically + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = 0, + defaultCompactInterval = 3, + compactibleLog => { + val fs = compactibleLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) + + def listBatchFiles(): Set[String] = { + fs.listStatus(compactibleLog.metadataPath).map(_.getPath.getName).filter { fileName => + try { + getBatchIdFromFileName(fileName) + true + } catch { + case _: NumberFormatException => false + } + }.toSet + } + + compactibleLog.add(0, Array("some_path_0")) + assert(Set("0") === listBatchFiles()) + compactibleLog.add(1, Array("some_path_1")) + assert(Set("0", "1") === listBatchFiles()) + compactibleLog.add(2, Array("some_path_2")) + assert(Set("2.compact") === listBatchFiles()) + compactibleLog.add(3, Array("some_path_3")) + assert(Set("2.compact", "3") === listBatchFiles()) + compactibleLog.add(4, Array("some_path_4")) + assert(Set("2.compact", "3", "4") === listBatchFiles()) + compactibleLog.add(5, Array("some_path_5")) + assert(Set("5.compact") === listBatchFiles()) + }) + } + + private def withFakeCompactibleFileStreamLog( + fileCleanupDelayMs: Long, + defaultCompactInterval: Int, + f: FakeCompactibleFileStreamLog => Unit + ): Unit = { + withTempDir { file => + val compactibleLog = new FakeCompactibleFileStreamLog( + fileCleanupDelayMs, + defaultCompactInterval, + spark, + file.getCanonicalPath) + f(compactibleLog) + } + } +} + +object FakeCompactibleFileStreamLog { + val VERSION = "test_version" +} + +class FakeCompactibleFileStreamLog( + _fileCleanupDelayMs: Long, + _defaultCompactInterval: Int, + sparkSession: SparkSession, + path: String) + extends CompactibleFileStreamLog[String]( + FakeCompactibleFileStreamLog.VERSION, + sparkSession, + path + ) { + + override protected def fileCleanupDelayMs: Long = _fileCleanupDelayMs + + override protected def isDeletingExpiredLog: Boolean = true + + override protected def defaultCompactInterval: Int = _defaultCompactInterval + + override def compactLogs(logs: Seq[String]): Seq[String] = logs } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index e1bc674a28071..e046fee0c04d3 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -29,61 +29,6 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { import CompactibleFileStreamLog._ import FileStreamSinkLog._ - test("getBatchIdFromFileName") { - assert(1234L === getBatchIdFromFileName("1234")) - assert(1234L === getBatchIdFromFileName("1234.compact")) - intercept[NumberFormatException] { - getBatchIdFromFileName("1234a") - } - } - - test("isCompactionBatch") { - assert(false === isCompactionBatch(0, compactInterval = 3)) - assert(false === isCompactionBatch(1, compactInterval = 3)) - assert(true === isCompactionBatch(2, compactInterval = 3)) - assert(false === isCompactionBatch(3, compactInterval = 3)) - assert(false === isCompactionBatch(4, compactInterval = 3)) - assert(true === isCompactionBatch(5, compactInterval = 3)) - } - - test("nextCompactionBatchId") { - assert(2 === nextCompactionBatchId(0, compactInterval = 3)) - assert(2 === nextCompactionBatchId(1, compactInterval = 3)) - assert(5 === nextCompactionBatchId(2, compactInterval = 3)) - assert(5 === nextCompactionBatchId(3, compactInterval = 3)) - assert(5 === nextCompactionBatchId(4, compactInterval = 3)) - assert(8 === nextCompactionBatchId(5, compactInterval = 3)) - } - - test("getValidBatchesBeforeCompactionBatch") { - intercept[AssertionError] { - getValidBatchesBeforeCompactionBatch(0, compactInterval = 3) - } - intercept[AssertionError] { - getValidBatchesBeforeCompactionBatch(1, compactInterval = 3) - } - assert(Seq(0, 1) === getValidBatchesBeforeCompactionBatch(2, compactInterval = 3)) - intercept[AssertionError] { - getValidBatchesBeforeCompactionBatch(3, compactInterval = 3) - } - intercept[AssertionError] { - getValidBatchesBeforeCompactionBatch(4, compactInterval = 3) - } - assert(Seq(2, 3, 4) === getValidBatchesBeforeCompactionBatch(5, compactInterval = 3)) - } - - test("getAllValidBatches") { - assert(Seq(0) === getAllValidBatches(0, compactInterval = 3)) - assert(Seq(0, 1) === getAllValidBatches(1, compactInterval = 3)) - assert(Seq(2) === getAllValidBatches(2, compactInterval = 3)) - assert(Seq(2, 3) === getAllValidBatches(3, compactInterval = 3)) - assert(Seq(2, 3, 4) === getAllValidBatches(4, compactInterval = 3)) - assert(Seq(5) === getAllValidBatches(5, compactInterval = 3)) - assert(Seq(5, 6) === getAllValidBatches(6, compactInterval = 3)) - assert(Seq(5, 6, 7) === getAllValidBatches(7, compactInterval = 3)) - assert(Seq(8) === getAllValidBatches(8, compactInterval = 3)) - } - test("compactLogs") { withFileStreamSinkLog { sinkLog => val logs = Seq( @@ -184,19 +129,6 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("batchIdToPath") { - withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { - withFileStreamSinkLog { sinkLog => - assert("0" === sinkLog.batchIdToPath(0).getName) - assert("1" === sinkLog.batchIdToPath(1).getName) - assert("2.compact" === sinkLog.batchIdToPath(2).getName) - assert("3" === sinkLog.batchIdToPath(3).getName) - assert("4" === sinkLog.batchIdToPath(4).getName) - assert("5.compact" === sinkLog.batchIdToPath(5).getName) - } - } - } - testWithUninterruptibleThread("compact") { withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { withFileStreamSinkLog { sinkLog => From c7021407597480bddf226ffa6d1d3f682408dfeb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Nov 2016 00:05:30 -0800 
Subject: [PATCH 166/534] [SPARK-18444][SPARKR] SparkR running in yarn-cluster mode should not download Spark package. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When running SparkR job in yarn-cluster mode, it will download Spark package from apache website which is not necessary. ``` ./bin/spark-submit --master yarn-cluster ./examples/src/main/r/dataframe.R ``` The following is output: ``` Attaching package: ‘SparkR’ The following objects are masked from ‘package:stats’: cov, filter, lag, na.omit, predict, sd, var, window The following objects are masked from ‘package:base’: as.data.frame, colnames, colnames<-, drop, endsWith, intersect, rank, rbind, sample, startsWith, subset, summary, transform, union Spark not found in SPARK_HOME: Spark not found in the cache directory. Installation will start. MirrorUrl not provided. Looking for preferred site from apache website... ...... ``` There's no ```SPARK_HOME``` in yarn-cluster mode since the R process is in a remote host of the yarn cluster rather than in the client host. The JVM comes up first and the R process then connects to it. So in such cases we should never have to download Spark as Spark is already running. ## How was this patch tested? Offline test. Author: Yanbo Liang Closes #15888 from yanboliang/spark-18444. (cherry picked from commit acb97157796231fef74aba985825b05b607b9279) Signed-off-by: Yanbo Liang --- R/pkg/R/sparkR.R | 20 +++++++---- R/pkg/R/utils.R | 4 +++ R/pkg/inst/tests/testthat/test_sparkR.R | 46 +++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 R/pkg/inst/tests/testthat/test_sparkR.R diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 6b4a2f2fdc85c..a7152b4313993 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -373,8 +373,13 @@ sparkR.session <- function( overrideEnvs(sparkConfigMap, paramMap) } + deployMode <- "" + if (exists("spark.submit.deployMode", envir = sparkConfigMap)) { + deployMode <- sparkConfigMap[["spark.submit.deployMode"]] + } + if (!exists(".sparkRjsc", envir = .sparkREnv)) { - retHome <- sparkCheckInstall(sparkHome, master) + retHome <- sparkCheckInstall(sparkHome, master, deployMode) if (!is.null(retHome)) sparkHome <- retHome sparkExecutorEnvMap <- new.env() sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorEnvMap, @@ -550,24 +555,27 @@ processSparkPackages <- function(packages) { # # @param sparkHome directory to find Spark package. # @param master the Spark master URL, used to check local or remote mode. +# @param deployMode whether to deploy your driver on the worker nodes (cluster) +# or locally as an external client (client). # @return NULL if no need to update sparkHome, and new sparkHome otherwise. 
-sparkCheckInstall <- function(sparkHome, master) { +sparkCheckInstall <- function(sparkHome, master, deployMode) { if (!isSparkRShell()) { if (!is.na(file.info(sparkHome)$isdir)) { msg <- paste0("Spark package found in SPARK_HOME: ", sparkHome) message(msg) NULL } else { - if (!nzchar(master) || isMasterLocal(master)) { - msg <- paste0("Spark not found in SPARK_HOME: ", - sparkHome) + if (isMasterLocal(master)) { + msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome) message(msg) packageLocalDir <- install.spark() packageLocalDir - } else { + } else if (isClientMode(master) || deployMode == "client") { msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome, "\n", installInstruction("remote")) stop(msg) + } else { + NULL } } } else { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 20004549cc037..098c0e3e31e95 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -777,6 +777,10 @@ isMasterLocal <- function(master) { grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE) } +isClientMode <- function(master) { + grepl("([a-z]+)-client$", master, perl = TRUE) +} + isSparkRShell <- function() { grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) } diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R new file mode 100644 index 0000000000000..f73fc6baeccef --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in sparkR.R") + +test_that("sparkCheckInstall", { + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- paste0(tempdir(), "/", "sparkHome") + dir.create(sparkHome) + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + unlink(sparkHome, recursive = TRUE) + + # "yarn-cluster, mesos-cluster" mode, SPARK_HOME was not set, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- "" + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + + # "yarn-client, mesos-client" mode, SPARK_HOME was not set + sparkHome <- "" + master <- "yarn-client" + deployMode <- "" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) + sparkHome <- "" + master <- "" + deployMode <- "client" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) +}) From 63aa01ffe06e49af032b57ba2eb28dfb8f14f779 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 22 Nov 2016 11:26:10 +0000 Subject: [PATCH 167/534] [SPARK-18514][DOCS] Fix the markdown for `Note:`/`NOTE:`/`Note that` across R API documentation ## What changes were proposed in this pull request? 
It seems in R, there are - `Note:` - `NOTE:` - `Note that` This PR proposes to fix those to `Note:` to be consistent. **Before** ![2016-11-21 11 30 07](https://cloud.githubusercontent.com/assets/6477701/20468848/2f27b0fa-afde-11e6-89e3-993701269dbe.png) **After** ![2016-11-21 11 29 44](https://cloud.githubusercontent.com/assets/6477701/20468851/39469664-afde-11e6-9929-ad80be7fc405.png) ## How was this patch tested? The notes were found via ```bash grep -r "NOTE: " . grep -r "Note that " . ``` And then fixed one by one comparing with API documentation. After that, manually tested via `sh create-docs.sh` under `./R`. Author: hyukjinkwon Closes #15952 from HyukjinKwon/SPARK-18514. (cherry picked from commit 4922f9cdcac8b7c10320ac1fb701997fffa45d46) Signed-off-by: Sean Owen --- R/pkg/R/DataFrame.R | 6 ++++-- R/pkg/R/functions.R | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 4e3d97bb3ad07..9a51d530f120a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2541,7 +2541,8 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. -#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame @@ -2584,7 +2585,8 @@ setMethod("unionAll", #' Union two or more SparkDataFrames #' #' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. -#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. #' #' @param x a SparkDataFrame. #' @param ... additional SparkDataFrame(s). diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f8a9d3ce5d918..bf5c96373c632 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2296,7 +2296,7 @@ setMethod("n", signature(x = "Column"), #' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All #' pattern letters of \code{java.text.SimpleDateFormat} can be used. #' -#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a +#' Note: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' #' @param y Column to compute on. @@ -2341,7 +2341,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' Locate the position of the first occurrence of substr column in the given string. #' Returns null if either of the arguments are null. #' -#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr +#' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param y column to check @@ -2779,7 +2779,8 @@ setMethod("window", signature(x = "Column"), #' locate #' #' Locate the position of the first occurrence of substr. -#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr +#' +#' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. 
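For reference, the behaviors that the corrected notes above describe can be observed directly: substring positions from `instr`/`locate` are 1-based and return 0 when the substring is absent, and `unionAll` does not drop duplicate rows. The snippet below is an illustrative sketch only, not part of any patch in this series; it assumes a local Spark 2.1-style PySpark environment, and the app name and variable names are hypothetical.

```python
# Illustrative sketch (not from the patches above): demonstrates the two
# documented behaviors -- 1-based substring positions that return 0 when
# the substring is missing, and unionAll() keeping duplicate rows.
from pyspark.sql import SparkSession
from pyspark.sql.functions import instr, locate

spark = (SparkSession.builder
         .master("local[1]")
         .appName("note-behaviors")   # hypothetical app name
         .getOrCreate())

df = spark.createDataFrame([("abcd",)], ["s"])

# instr/locate are 1-based: 'b' is at position 2; a missing substring gives 0.
df.select(instr(df.s, "b").alias("pos_b"),
          locate("z", df.s).alias("pos_missing")).show()
# -> pos_b = 2, pos_missing = 0

# unionAll does not remove duplicate rows across the two DataFrames.
print(df.unionAll(df).count())  # -> 2

spark.stop()
```

The same semantics apply to the SparkR `instr`/`locate` and `unionAll` documented in the preceding patch; only the note formatting, not the behavior, changed there.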
From 36cd10d19d95418cec4b789545afc798088be315 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 22 Nov 2016 11:40:18 +0000 Subject: [PATCH 168/534] [SPARK-18447][DOCS] Fix the markdown for `Note:`/`NOTE:`/`Note that` across Python API documentation ## What changes were proposed in this pull request? It seems in Python, there are - `Note:` - `NOTE:` - `Note that` - `.. note::` This PR proposes to fix those to `.. note::` to be consistent. **Before** 2016-11-21 1 18 49 2016-11-21 12 42 43 **After** 2016-11-21 1 18 42 2016-11-21 12 42 51 ## How was this patch tested? The notes were found via ```bash grep -r "Note: " . grep -r "NOTE: " . grep -r "Note that " . ``` And then fixed one by one comparing with API documentation. After that, manually tested via `make html` under `./python/docs`. Author: hyukjinkwon Closes #15947 from HyukjinKwon/SPARK-18447. (cherry picked from commit 933a6548d423cf17448207a99299cf36fc1a95f6) Signed-off-by: Sean Owen --- python/pyspark/conf.py | 4 +- python/pyspark/context.py | 8 ++-- python/pyspark/ml/classification.py | 45 +++++++++--------- python/pyspark/ml/clustering.py | 8 ++-- python/pyspark/ml/feature.py | 13 +++--- python/pyspark/ml/linalg/__init__.py | 11 +++-- python/pyspark/ml/regression.py | 32 ++++++------- python/pyspark/mllib/clustering.py | 6 +-- python/pyspark/mllib/feature.py | 24 +++++----- python/pyspark/mllib/linalg/__init__.py | 11 +++-- python/pyspark/mllib/linalg/distributed.py | 15 +++--- python/pyspark/mllib/regression.py | 2 +- python/pyspark/mllib/stat/_statistics.py | 3 +- python/pyspark/mllib/tree.py | 12 ++--- python/pyspark/rdd.py | 54 +++++++++++----------- python/pyspark/sql/dataframe.py | 28 ++++++----- python/pyspark/sql/functions.py | 11 +++-- python/pyspark/sql/streaming.py | 10 ++-- python/pyspark/streaming/context.py | 2 +- python/pyspark/streaming/kinesis.py | 4 +- 20 files changed, 157 insertions(+), 146 deletions(-) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 64b6f238e9c32..491b3a81972bc 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -90,8 +90,8 @@ class SparkConf(object): All setter methods in this class support chaining. For example, you can write C{conf.setMaster("local").setAppName("My app")}. - Note that once a SparkConf object is passed to Spark, it is cloned - and can no longer be modified by the user. + .. note:: Once a SparkConf object is passed to Spark, it is cloned + and can no longer be modified by the user. """ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2c2cf6a373bb7..2fd3aee01d76c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -520,8 +520,8 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): ... (a-hdfs-path/part-nnnnn, its content) - NOTE: Small files are preferred, as each file will be loaded - fully in memory. + .. note:: Small files are preferred, as each file will be loaded + fully in memory. >>> dirPath = os.path.join(tempdir, "files") >>> os.mkdir(dirPath) @@ -547,8 +547,8 @@ def binaryFiles(self, path, minPartitions=None): in a key-value pair, where the key is the path of each file, the value is the content of each file. - Note: Small files are preferred, large file is also allowable, but - may cause bad performance. + .. note:: Small files are preferred, large file is also allowable, but + may cause bad performance. 
""" minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.binaryFiles(path, minPartitions), self, diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 83e1e89347660..8054a34db30f2 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -440,9 +440,9 @@ def roc(self): .. seealso:: `Wikipedia reference \ `_ - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("roc") @@ -453,9 +453,9 @@ def areaUnderROC(self): Computes the area under the receiver operating characteristic (ROC) curve. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("areaUnderROC") @@ -467,9 +467,9 @@ def pr(self): containing two fields recall, precision with (0.0, 1.0) prepended to it. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("pr") @@ -480,9 +480,9 @@ def fMeasureByThreshold(self): Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("fMeasureByThreshold") @@ -494,9 +494,9 @@ def precisionByThreshold(self): Every possible probability obtained in transforming the dataset are used as thresholds used in calculating the precision. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("precisionByThreshold") @@ -508,9 +508,9 @@ def recallByThreshold(self): Every possible probability obtained in transforming the dataset are used as thresholds used in calculating the recall. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("recallByThreshold") @@ -695,9 +695,9 @@ def featureImportances(self): where gain is scaled by the number of instances passing through node - Normalize importances for tree to sum to 1. - Note: Feature importance for single decision trees can have high variance due to - correlated predictor variables. Consider using a :py:class:`RandomForestClassifier` - to determine feature importance instead. + .. 
note:: Feature importance for single decision trees can have high variance due to + correlated predictor variables. Consider using a :py:class:`RandomForestClassifier` + to determine feature importance instead. """ return self._call_java("featureImportances") @@ -839,7 +839,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol `Gradient-Boosted Trees (GBTs) `_ learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features. - Note: Multiclass labels are not currently supported. The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. @@ -851,6 +850,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol - We expect to implement TreeBoost in the future: `SPARK-4240 `_ + .. note:: Multiclass labels are not currently supported. + >>> from numpy import allclose >>> from pyspark.ml.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index e58ec1e7ac296..b29b5ac70e6fe 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -155,7 +155,7 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte While this process is generally guaranteed to converge, it is not guaranteed to find a global optimum. - Note: For high-dimensional data (with many features), this algorithm may perform poorly. + .. note:: For high-dimensional data (with many features), this algorithm may perform poorly. This is due to high-dimensional data (a) making it difficult to cluster at all (based on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. @@ -749,9 +749,9 @@ def getCheckpointFiles(self): If using checkpointing and :py:attr:`LDA.keepLastCheckpoint` is set to true, then there may be saved checkpoint files. This method is provided so that users can manage those files. - Note that removing the checkpoints can cause failures if a partition is lost and is needed - by certain :py:class:`DistributedLDAModel` methods. Reference counting will clean up the - checkpoints when this model and derivative data go out of scope. + .. note:: Removing the checkpoints can cause failures if a partition is lost and is needed + by certain :py:class:`DistributedLDAModel` methods. Reference counting will clean up + the checkpoints when this model and derivative data go out of scope. :return List of checkpoint files from training """ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 635cf1304588e..40b63d4d31d4b 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -742,8 +742,8 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min) - Note that since zero values will probably be transformed to non-zero values, output of the - transformer will be DenseVector even for sparse input. + .. note:: Since zero values will probably be transformed to non-zero values, output of the + transformer will be DenseVector even for sparse input. >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) @@ -1014,9 +1014,9 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, :py:attr:`dropLast`) because it makes the vector entries sum up to one, and hence linearly dependent. 
So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - Note that this is different from scikit-learn's OneHotEncoder, - which keeps all categories. - The output vectors are sparse. + + .. note:: This is different from scikit-learn's OneHotEncoder, + which keeps all categories. The output vectors are sparse. .. seealso:: @@ -1698,7 +1698,8 @@ def getLabels(self): class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ A feature transformer that filters out stop words from input. - Note: null values from input array are preserved unless adding null to stopWords explicitly. + + .. note:: null values from input array are preserved unless adding null to stopWords explicitly. >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"]) >>> remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=["b"]) diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index a5df727fdb418..1705c156ce4c8 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -746,11 +746,12 @@ def __hash__(self): class Vectors(object): """ - Factory methods for working with vectors. Note that dense vectors - are simply represented as NumPy array objects, so there is no need - to covert them for use in MLlib. For sparse vectors, the factory - methods in this class create an MLlib-compatible type, or users - can pass in SciPy's C{scipy.sparse} column vectors. + Factory methods for working with vectors. + + .. note:: Dense vectors are simply represented as NumPy array objects, + so there is no need to covert them for use in MLlib. For sparse vectors, + the factory methods in this class create an MLlib-compatible type, or users + can pass in SciPy's C{scipy.sparse} column vectors. """ @staticmethod diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 385391ba53fd4..b42e807069802 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -245,9 +245,9 @@ def explainedVariance(self): .. seealso:: `Wikipedia explain variation \ `_ - Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("explainedVariance") @@ -259,9 +259,9 @@ def meanAbsoluteError(self): corresponding to the expected value of the absolute error loss or l1-norm loss. - Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("meanAbsoluteError") @@ -273,9 +273,9 @@ def meanSquaredError(self): corresponding to the expected value of the squared error loss or quadratic loss. - Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("meanSquaredError") @@ -286,9 +286,9 @@ def rootMeanSquaredError(self): Returns the root mean squared error, which is defined as the square root of the mean squared error. 
- Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("rootMeanSquaredError") @@ -301,9 +301,9 @@ def r2(self): .. seealso:: `Wikipedia coefficient of determination \ ` - Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("r2") @@ -822,7 +822,7 @@ def featureImportances(self): where gain is scaled by the number of instances passing through node - Normalize importances for tree to sum to 1. - Note: Feature importance for single decision trees can have high variance due to + .. note:: Feature importance for single decision trees can have high variance due to correlated predictor variables. Consider using a :py:class:`RandomForestRegressor` to determine feature importance instead. """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 2036168e456fd..91123ace3387e 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -699,9 +699,9 @@ class StreamingKMeansModel(KMeansModel): * n_t+1: New number of weights. * a: Decay Factor, which gives the forgetfulness. - Note that if a is set to 1, it is the weighted mean of the previous - and new data. If it set to zero, the old centroids are completely - forgotten. + .. note:: If a is set to 1, it is the weighted mean of the previous + and new data. If it set to zero, the old centroids are completely + forgotten. :param clusterCenters: Initial cluster centers. diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7eaa2282cb8bb..bde0f67be775c 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -114,9 +114,9 @@ def transform(self, vector): """ Applies transformation on a vector or an RDD[Vector]. - Note: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + .. note:: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. :param vector: Vector or RDD of Vector to be transformed. """ @@ -139,9 +139,9 @@ def transform(self, vector): """ Applies standardization transformation on a vector. - Note: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + .. note:: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. :param vector: Vector or RDD of Vector to be standardized. :return: Standardized vector. If the variance of a column is @@ -407,7 +407,7 @@ class HashingTF(object): Maps a sequence of terms to their term frequencies using the hashing trick. - Note: the terms must be hashable (can not be dict/set/list...). + .. note:: The terms must be hashable (can not be dict/set/list...). :param numFeatures: number of features (default: 2^20) @@ -469,9 +469,9 @@ def transform(self, x): the terms which occur in fewer than `minDocFreq` documents will have an entry of 0. 
- Note: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + .. note:: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. :param x: an RDD of term frequency vectors or a term frequency vector @@ -551,7 +551,7 @@ def transform(self, word): """ Transforms a word to its vector representation - Note: local use only + .. note:: Local use only :param word: a word :return: vector representation of word(s) @@ -570,7 +570,7 @@ def findSynonyms(self, word, num): :param num: number of synonyms to find :return: array of (word, cosineSimilarity) - Note: local use only + .. note:: Local use only """ if not isinstance(word, basestring): word = _convert_to_vector(word) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index d37e715c8d8ec..031f22c02098e 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -835,11 +835,12 @@ def __hash__(self): class Vectors(object): """ - Factory methods for working with vectors. Note that dense vectors - are simply represented as NumPy array objects, so there is no need - to covert them for use in MLlib. For sparse vectors, the factory - methods in this class create an MLlib-compatible type, or users - can pass in SciPy's C{scipy.sparse} column vectors. + Factory methods for working with vectors. + + .. note:: Dense vectors are simply represented as NumPy array objects, + so there is no need to covert them for use in MLlib. For sparse vectors, + the factory methods in this class create an MLlib-compatible type, or users + can pass in SciPy's C{scipy.sparse} column vectors. """ @staticmethod diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 538cada7d163d..600655c912ca6 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -171,8 +171,9 @@ def computeColumnSummaryStatistics(self): def computeCovariance(self): """ Computes the covariance matrix, treating each row as an - observation. Note that this cannot be computed on matrices - with more than 65535 columns. + observation. + + .. note:: This cannot be computed on matrices with more than 65535 columns. >>> rows = sc.parallelize([[1, 2], [2, 1]]) >>> mat = RowMatrix(rows) @@ -185,8 +186,9 @@ def computeCovariance(self): @since('2.0.0') def computeGramianMatrix(self): """ - Computes the Gramian matrix `A^T A`. Note that this cannot be - computed on matrices with more than 65535 columns. + Computes the Gramian matrix `A^T A`. + + .. note:: This cannot be computed on matrices with more than 65535 columns. >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]]) >>> mat = RowMatrix(rows) @@ -458,8 +460,9 @@ def columnSimilarities(self): @since('2.0.0') def computeGramianMatrix(self): """ - Computes the Gramian matrix `A^T A`. Note that this cannot be - computed on matrices with more than 65535 columns. + Computes the Gramian matrix `A^T A`. + + .. note:: This cannot be computed on matrices with more than 65535 columns. >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... 
IndexedRow(1, [4, 5, 6])]) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 705022934e41b..1b66f5b51044b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -44,7 +44,7 @@ class LabeledPoint(object): Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). - Note: 'label' and 'features' are accessible as class attributes. + .. note:: 'label' and 'features' are accessible as class attributes. .. versionadded:: 1.0.0 """ diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 67d5f0e44f41c..49b26446dbc32 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -164,7 +164,6 @@ def chiSqTest(observed, expected=None): of fit test of the observed data against the expected distribution, or againt the uniform distribution (by default), with each category having an expected frequency of `1 / len(observed)`. - (Note: `observed` cannot contain negative values) If `observed` is matrix, conduct Pearson's independence test on the input contingency matrix, which cannot contain negative entries or @@ -176,6 +175,8 @@ def chiSqTest(observed, expected=None): contingency matrix for which the chi-squared statistic is computed. All label and feature values must be categorical. + .. note:: `observed` cannot contain negative values + :param observed: it could be a vector containing the observed categorical counts/relative frequencies, or the contingency matrix (containing either counts or relative frequencies), diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index b3011d42e56af..a6089fc8b9d32 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -40,9 +40,9 @@ def predict(self, x): Predict values for a single data point or an RDD of points using the model trained. - Note: In Python, predict cannot currently be used within an RDD - transformation or action. - Call predict directly on the RDD instead. + .. note:: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -85,9 +85,9 @@ def predict(self, x): """ Predict the label of one or more examples. - Note: In Python, predict cannot currently be used within an RDD - transformation or action. - Call predict directly on the RDD instead. + .. note:: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. :param x: Data point (feature vector), or an RDD of data points (feature diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f21a364df9100..9e05da89af082 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -417,10 +417,8 @@ def sample(self, withReplacement, fraction, seed=None): with replacement: expected number of times each element is chosen; fraction must be >= 0 :param seed: seed for the random number generator - .. note:: - - This is not guaranteed to provide exactly the fraction specified of the total count - of the given :class:`DataFrame`. + .. note:: This is not guaranteed to provide exactly the fraction specified of the total + count of the given :class:`DataFrame`. 
>>> rdd = sc.parallelize(range(100), 4) >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 @@ -460,8 +458,8 @@ def takeSample(self, withReplacement, num, seed=None): """ Return a fixed-size sampled subset of this RDD. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: This method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. >>> rdd = sc.parallelize(range(0, 10)) >>> len(rdd.takeSample(True, 20, 1)) @@ -572,7 +570,7 @@ def intersection(self, other): Return the intersection of this RDD and another one. The output will not contain any duplicate elements, even if the input RDDs did. - Note that this method performs a shuffle internally. + .. note:: This method performs a shuffle internally. >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5]) >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8]) @@ -803,8 +801,9 @@ def func(it): def collect(self): """ Return a list that contains all of the elements in this RDD. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + + .. note:: This method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) @@ -1251,10 +1250,10 @@ def top(self, num, key=None): """ Get the top N elements from an RDD. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: This method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. - Note: It returns the list sorted in descending order. + .. note:: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] @@ -1276,8 +1275,8 @@ def takeOrdered(self, num, key=None): Get the N elements from an RDD ordered in ascending order or as specified by the optional key function. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] @@ -1298,11 +1297,11 @@ def take(self, num): that partition to estimate the number of additional partitions needed to satisfy the limit. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. - Translated from the Scala implementation in RDD#take(). + .. note:: this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) [2, 3] >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) @@ -1366,8 +1365,9 @@ def first(self): def isEmpty(self): """ - Returns true if and only if the RDD contains no elements at all. Note that an RDD - may be empty even when it has at least 1 partition. + Returns true if and only if the RDD contains no elements at all. + + .. note:: an RDD may be empty even when it has at least 1 partition. 
>>> sc.parallelize([]).isEmpty() True @@ -1558,8 +1558,8 @@ def collectAsMap(self): """ Return the key-value pairs in this RDD to the master as a dictionary. - Note that this method should only be used if the resulting data is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: this method should only be used if the resulting data is expected + to be small, as all the data is loaded into the driver's memory. >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() >>> m[1] @@ -1796,8 +1796,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, set of aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined - type" C. Note that V and C can be different -- for example, one might - group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). + type" C. Users provide three functions: @@ -1809,6 +1808,9 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, In addition, users can control the partitioning of the output RDD. + .. note:: V and C can be different -- for example, one might group an RDD of type + (Int, Int) into an RDD of type (Int, List[Int]). + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> def add(a, b): return a + str(b) >>> sorted(x.combineByKey(str, add, add).collect()) @@ -1880,9 +1882,9 @@ def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. - Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey or aggregateByKey will - provide much better performance. + .. note:: If you are grouping in order to perform an aggregation (such as a + sum or average) over each key, using reduceByKey or aggregateByKey will + provide much better performance. >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.groupByKey().mapValues(len).collect()) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6fe622643291e..b9d90384e3e2c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -457,7 +457,7 @@ def foreachPartition(self, f): def cache(self): """Persists the :class:`DataFrame` with the default storage level (C{MEMORY_AND_DISK}). - .. note:: the default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. + .. note:: The default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. """ self.is_cached = True self._jdf.cache() @@ -470,7 +470,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK): a new storage level if the :class:`DataFrame` does not have a storage level set yet. If no storage level is specified defaults to (C{MEMORY_AND_DISK}). - .. note:: the default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. + .. note:: The default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. """ self.is_cached = True javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) @@ -597,10 +597,8 @@ def distinct(self): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. - .. note:: - - This is not guaranteed to provide exactly the fraction specified of the total count - of the given :class:`DataFrame`. + .. note:: This is not guaranteed to provide exactly the fraction specified of the total + count of the given :class:`DataFrame`. 
>>> df.sample(False, 0.5, 42).count() 2 @@ -866,8 +864,8 @@ def describe(self, *cols): This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical or string columns. - .. note:: This function is meant for exploratory data analysis, as we make no \ - guarantee about the backward compatibility of the schema of the resulting DataFrame. + .. note:: This function is meant for exploratory data analysis, as we make no + guarantee about the backward compatibility of the schema of the resulting DataFrame. >>> df.describe(['age']).show() +-------+------------------+ @@ -900,8 +898,8 @@ def describe(self, *cols): def head(self, n=None): """Returns the first ``n`` rows. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: This method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. :param n: int, default 1. Number of rows to return. :return: If n is greater than 1, return a list of :class:`Row`. @@ -1462,8 +1460,8 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. - .. note:: This function is meant for exploratory data analysis, as we make no \ - guarantee about the backward compatibility of the schema of the resulting DataFrame. + .. note:: This function is meant for exploratory data analysis, as we make no + guarantee about the backward compatibility of the schema of the resulting DataFrame. :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. @@ -1564,11 +1562,11 @@ def toDF(self, *cols): def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. - Note that this method should only be used if the resulting Pandas's DataFrame is expected - to be small, as all the data is loaded into the driver's memory. - This is only available if Pandas is installed and available. + .. note:: This method should only be used if the resulting Pandas's DataFrame is expected + to be small, as all the data is loaded into the driver's memory. + >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 46a092f16d4fc..d8abafcde3846 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -359,7 +359,7 @@ def grouping_id(*cols): (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) - .. note:: the list of columns should match with grouping columns exactly, or empty (means all + .. note:: The list of columns should match with grouping columns exactly, or empty (means all the grouping columns). >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() @@ -547,7 +547,7 @@ def shiftRightUnsigned(col, numBits): def spark_partition_id(): """A column for partition ID. - Note that this is indeterministic because it depends on data partitioning and task scheduling. + .. note:: This is indeterministic because it depends on data partitioning and task scheduling. 
>>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] @@ -1852,9 +1852,10 @@ def __call__(self, *cols): @since(1.3) def udf(f, returnType=StringType()): """Creates a :class:`Column` expression representing a user defined function (UDF). - Note that the user-defined functions must be deterministic. Due to optimization, - duplicate invocations may be eliminated or the function may even be invoked more times than - it is present in the query. + + .. note:: The user-defined functions must be deterministic. Due to optimization, + duplicate invocations may be eliminated or the function may even be invoked more times than + it is present in the query. :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 0e4589be976ea..9c3a237699f96 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -90,10 +90,12 @@ def awaitTermination(self, timeout=None): @since(2.0) def processAllAvailable(self): """Blocks until all available data in the source has been processed and committed to the - sink. This method is intended for testing. Note that in the case of continually arriving - data, this method may block forever. Additionally, this method is only guaranteed to block - until data that has been synchronously appended data to a stream source prior to invocation. - (i.e. `getOffset` must immediately reflect the addition). + sink. This method is intended for testing. + + .. note:: In the case of continually arriving data, this method may block forever. + Additionally, this method is only guaranteed to block until data that has been + synchronously appended data to a stream source prior to invocation. + (i.e. `getOffset` must immediately reflect the addition). """ return self._jsq.processAllAvailable() diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ec3ad9933cf60..17c34f8a1c54c 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -304,7 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): Create an input stream from an queue of RDDs or list. In each batch, it will process either one or all of the RDDs returned by the queue. - NOTE: changes to the queue after the stream is created will not be recognized. + .. note:: Changes to the queue after the stream is created will not be recognized. @param rdds: Queue of RDDs @param oneAtATime: pick one rdd each time or pick all of them once. diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index 434ce83e1e6f9..3a8d8b819fd37 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -42,8 +42,8 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, Create an input stream that pulls messages from a Kinesis stream. This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - Note: The given AWS credentials will get saved in DStream checkpoints if checkpointing is - enabled. Make sure that your checkpoint directory is secure. + .. note:: The given AWS credentials will get saved in DStream checkpoints if checkpointing + is enabled. Make sure that your checkpoint directory is secure. 
:param ssc: StreamingContext object :param kinesisAppName: Kinesis application name used by the Kinesis Client Library (KCL) to From 0e60e4b88014fcdd54acc650bfd3a1683f06f09e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Nov 2016 09:16:20 -0800 Subject: [PATCH 169/534] [SPARK-18519][SQL] map type can not be used in EqualTo ## What changes were proposed in this pull request? Technically map type is not orderable, but can be used in equality comparison. However, due to the limitation of the current implementation, map type can't be used in equality comparison so that it can't be join key or grouping key. This PR makes this limitation explicit, to avoid wrong result. ## How was this patch tested? updated tests. Author: Wenchen Fan Closes #15956 from cloud-fan/map-type. (cherry picked from commit bb152cdfbb8d02130c71d2326ae81939725c2cf0) Signed-off-by: Herman van Hovell --- .../sql/catalyst/analysis/CheckAnalysis.scala | 15 ------- .../sql/catalyst/expressions/predicates.scala | 30 +++++++++++++ .../analysis/AnalysisErrorSuite.scala | 44 +++++++------------ .../ExpressionTypeCheckingSuite.scala | 2 + 4 files changed, 48 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 98e50d0d3c674..80e577e5c4c79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -183,21 +183,6 @@ trait CheckAnalysis extends PredicateHelper { s"join condition '${condition.sql}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") - case j @ Join(_, _, _, Some(condition)) => - def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { - case p: Predicate => - p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) - case e if e.dataType.isInstanceOf[BinaryType] => - failAnalysis(s"binary type expression ${e.sql} cannot be used " + - "in join conditions") - case e if e.dataType.isInstanceOf[MapType] => - failAnalysis(s"map type expression ${e.sql} cannot be used " + - "in join conditions") - case _ => // OK - } - - checkValidJoinConditionExprs(condition) - case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case aggExpr: AggregateExpression => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7946c201f4ffc..2ad452b6a90ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -412,6 +412,21 @@ case class EqualTo(left: Expression, right: Expression) override def inputType: AbstractDataType = AnyDataType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + // TODO: although map type is not orderable, technically map type should be able to be used + // in equality comparison, remove this type check once we support it. 
+ if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " + + s"input type is ${left.dataType.catalogString}.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure + } + } + override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -440,6 +455,21 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def inputType: AbstractDataType = AnyDataType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + // TODO: although map type is not orderable, technically map type should be able to be used + // in equality comparison, remove this type check once we support it. + if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " + + s"input type is ${left.dataType.catalogString}.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure + } + } + override def symbol: String = "<=>" override def nullable: Boolean = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 21afe9fec5944..8c1faea2394c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -465,34 +465,22 @@ class AnalysisErrorSuite extends AnalysisTest { "another aggregate function." :: Nil) } - test("Join can't work on binary and map types") { - val plan = - Join( - LocalRelation( - AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1))), - LocalRelation( - AttributeReference("c", BinaryType)(exprId = ExprId(4)), - AttributeReference("d", IntegerType)(exprId = ExprId(3))), - Cross, - Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - - assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil) - - val plan2 = - Join( - LocalRelation( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1))), - LocalRelation( - AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)), - AttributeReference("d", IntegerType)(exprId = ExprId(3))), - Cross, - Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), - AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - - assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil) + test("Join can work on binary types but can't work on map types") { + val left = LocalRelation('a.binary, 'b.map(StringType, StringType)) + val right = LocalRelation('c.binary, 'd.map(StringType, StringType)) + + val plan1 = left.join( + right, + joinType = Cross, + condition = Some('a === 'c)) + + assertAnalysisSuccess(plan1) + + val plan2 = left.join( + right, + joinType = Cross, + condition = Some('b === 'd)) + assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 542e654bbce12..744057b7c5f4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -111,6 +111,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo") + assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe") assertError(LessThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") assertError(LessThanOrEqual('mapField, 'mapField), From 0e624e990b3b426dba0a6149ad6340f85d214a58 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 22 Nov 2016 12:06:21 -0800 Subject: [PATCH 170/534] [SPARK-18504][SQL] Scalar subquery with extra group by columns returning incorrect result ## What changes were proposed in this pull request? This PR blocks an incorrect result scenario in scalar subquery where there are GROUP BY column(s) that are not part of the correlated predicate(s). Example: // Incorrect result Seq(1).toDF("c1").createOrReplaceTempView("t1") Seq((1,1),(1,2)).toDF("c1","c2").createOrReplaceTempView("t2") sql("select (select sum(-1) from t2 where t1.c1=t2.c1 group by t2.c2) from t1").show // How can selecting a scalar subquery from a 1-row table return 2 rows? ## How was this patch tested? sql/test, catalyst/test new test case covering the reported problem is added to SubquerySuite.scala Author: Nattavut Sutyanyong Closes #15936 from nsyca/scalarSubqueryIncorrect-1. 
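For reference, the rule this patch enforces can be stated compactly: the GROUP BY columns of a correlated scalar subquery must be a subset of the columns referenced by its correlated predicates. In the example above the subquery groups by t2.c2 while the only correlated predicate touches t2.c1, so for t1.c1 = 1 two groups (t2.c2 = 1 and t2.c2 = 2) survive and the "scalar" subquery produces two rows. A minimal standalone Scala sketch of the subset check, illustrative only and using plain column-name strings instead of Catalyst expression sets:

```scala
// Each GROUP BY column must also appear in a correlated predicate; otherwise one outer
// row can match several groups and the scalar subquery result is no longer unique.
def groupByIsValid(groupByCols: Set[String], correlatedCols: Set[String]): Boolean =
  groupByCols.subsetOf(correlatedCols)

assert(!groupByIsValid(Set("t2.c2"), Set("t2.c1"))) // the reported query: rejected
assert(groupByIsValid(Set("t2.c1"), Set("t2.c1")))  // grouping only by the correlated column: allowed
```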
(cherry picked from commit 45ea46b7b397f023b4da878eb11e21b08d931115) Signed-off-by: Herman van Hovell --- .../sql/catalyst/analysis/Analyzer.scala | 3 -- .../sql/catalyst/analysis/CheckAnalysis.scala | 30 +++++++++++++++---- .../org/apache/spark/sql/SubquerySuite.scala | 12 ++++++++ 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b7e167557c559..2918e9d158829 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1182,9 +1182,6 @@ class Analyzer( */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { - case s @ ScalarSubquery(sub, conditions, exprId) - if sub.resolved && conditions.isEmpty && sub.output.size != 1 => - failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}") case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, exprId) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 80e577e5c4c79..26d26385904f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -117,19 +117,37 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis(s"Window specification $s is not valid because $m") case None => w } + case s @ ScalarSubquery(query, conditions, _) + // If no correlation, the output must be exactly one column + if (conditions.isEmpty && query.output.size != 1) => + failAnalysis( + s"Scalar subquery must return only one column, but got ${query.output.size}") case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => - // Make sure correlated scalar subqueries contain one row for every outer row by - // enforcing that they are aggregates which contain exactly one aggregate expressions. - // The analyzer has already checked that subquery contained only one output column, and - // added all the grouping expressions to the aggregate. - def checkAggregate(a: Aggregate): Unit = { - val aggregates = a.expressions.flatMap(_.collect { + def checkAggregate(agg: Aggregate): Unit = { + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates which contain exactly one aggregate expressions. + // The analyzer has already checked that subquery contained only one output column, + // and added all the grouping expressions to the aggregate. 
+ val aggregates = agg.expressions.flatMap(_.collect { case a: AggregateExpression => a }) if (aggregates.isEmpty) { failAnalysis("The output of a correlated scalar subquery must be aggregated") } + + // SPARK-18504: block cases where GROUP BY columns + // are not part of the correlated columns + val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references)) + val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references)) + val invalidCols = groupByCols.diff(predicateCols) + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "a GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } } // Skip projects and subquery aliases added by the Analyzer and the SQLBuilder. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index c84a6f161893c..f1dd1c620e660 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -483,6 +483,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil) } + test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not permitted") { + withTempView("t") { + Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + + val errMsg = intercept[AnalysisException] { + sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1") + } + assert(errMsg.getMessage.contains( + "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) + } + } + test("non-aggregated correlated scalar subquery") { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") From fa360134d06e5bfb423f0bd769edb47dbda1d9af Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Nov 2016 15:25:22 -0500 Subject: [PATCH 171/534] [SPARK-18507][SQL] HiveExternalCatalog.listPartitions should only call getTable once ## What changes were proposed in this pull request? HiveExternalCatalog.listPartitions should only call `getTable` once, instead of calling it for every partitions. ## How was this patch tested? N/A Author: Wenchen Fan Closes #15978 from cloud-fan/perf. 
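The change boils down to hoisting a loop-invariant metastore lookup: the table metadata is fetched once and reused, instead of issuing one getTable round trip per partition. A minimal Scala sketch of the pattern, with hypothetical simplified types rather than the real HiveExternalCatalog/HiveClient API:

```scala
// Hypothetical stand-ins for CatalogTable / CatalogTablePartition.
final case class TableMeta(partitionColumnNames: Seq[String])
final case class Partition(spec: Map[String, String])

// Toy version of restorePartitionSpec: restore the original case of partition column names.
def restoreSpec(spec: Map[String, String], colNames: Seq[String]): Map[String, String] =
  spec.map { case (k, v) => colNames.find(_.equalsIgnoreCase(k)).getOrElse(k) -> v }

def listPartitions(fetchTable: () => TableMeta, parts: Seq[Partition]): Seq[Partition] = {
  val partColNames = fetchTable().partitionColumnNames // single metastore call, hoisted out of the map
  parts.map(p => p.copy(spec = restoreSpec(p.spec, partColNames)))
}
```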
(cherry picked from commit 702cd403fc8e5ce8281fe8828197ead46bdb8832) Signed-off-by: Andrew Or --- .../scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 5dbb4024bbee0..ff0923f04893d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -907,8 +907,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat db: String, table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient { + val actualPartColNames = getTable(db, table).partitionColumnNames client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part => - part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + part.copy(spec = restorePartitionSpec(part.spec, actualPartColNames)) } } From fb2ea54a69b521463b93b270b63081da726ee036 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 22 Nov 2016 13:03:50 -0800 Subject: [PATCH 172/534] [SPARK-18465] Add 'IF EXISTS' clause to 'UNCACHE' to not throw exceptions when table doesn't exist ## What changes were proposed in this pull request? While this behavior is debatable, consider the following use case: ```sql UNCACHE TABLE foo; CACHE TABLE foo AS SELECT * FROM bar ``` The command above fails the first time you run it. But I want to run the command above over and over again, and I don't want to change my code just for the first run of it. The issue is that subsequent `CACHE TABLE` commands do not overwrite the existing table. Now we can do: ```sql UNCACHE TABLE IF EXISTS foo; CACHE TABLE foo AS SELECT * FROM bar ``` ## How was this patch tested? Unit tests Author: Burak Yavuz Closes #15896 from brkyvz/uncache. (cherry picked from commit bdc8153e8689262708c7fade5c065bd7fc8a84fc) Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../apache/spark/sql/execution/SparkSqlParser.scala | 2 +- .../apache/spark/sql/execution/command/cache.scala | 12 ++++++++++-- .../org/apache/spark/sql/hive/CachedTableSuite.scala | 5 ++++- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index b599a884957a8..0aa2a97407c53 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -142,7 +142,7 @@ statement | REFRESH TABLE tableIdentifier #refreshTable | REFRESH .*? #refreshResource | CACHE LAZY? TABLE tableIdentifier (AS? query)? #cacheTable - | UNCACHE TABLE tableIdentifier #uncacheTable + | UNCACHE TABLE (IF EXISTS)? tableIdentifier #uncacheTable | CLEAR CACHE #clearCache | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE tableIdentifier partitionSpec? 
#loadData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index b8be3d17ba444..47610453ac23a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -233,7 +233,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create an [[UncacheTableCommand]] logical plan. */ override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { - UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier)) + UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier), ctx.EXISTS != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index c31f4dc9aba4b..336f14dd97aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -49,10 +50,17 @@ case class CacheTableCommand( } -case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableCommand { +case class UncacheTableCommand( + tableIdent: TableIdentifier, + ifExists: Boolean) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.catalog.uncacheTable(tableIdent.quotedString) + val tableId = tableIdent.quotedString + try { + sparkSession.catalog.uncacheTable(tableId) + } catch { + case _: NoSuchTableException if ifExists => // don't throw + } Seq.empty[Row] } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index fc35304c80ecc..3871b3d785882 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -101,13 +101,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("DROP TABLE IF EXISTS nonexistantTable") } - test("correct error on uncache of nonexistant tables") { + test("uncache of nonexistant tables") { + // make sure table doesn't exist + intercept[NoSuchTableException](spark.table("nonexistantTable")) intercept[NoSuchTableException] { spark.catalog.uncacheTable("nonexistantTable") } intercept[NoSuchTableException] { sql("UNCACHE TABLE nonexistantTable") } + sql("UNCACHE TABLE IF EXISTS nonexistantTable") } test("no error on uncache of non-cached table") { From bd338f60d7f30f0cb735dffb39b3a6ec60766301 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 22 Nov 2016 14:15:57 -0800 Subject: [PATCH 173/534] [SPARK-18373][SPARK-18529][SS][KAFKA] Make failOnDataLoss=false work with Spark jobs ## What changes were proposed in this pull request? This PR adds `CachedKafkaConsumer.getAndIgnoreLostData` to handle corner cases of `failOnDataLoss=false`. 
It also resolves [SPARK-18529](https://issues.apache.org/jira/browse/SPARK-18529) after refactoring codes: Timeout will throw a TimeoutException. ## How was this patch tested? Because I cannot find any way to manually control the Kafka server to clean up logs, it's impossible to write unit tests for each corner case. Therefore, I just created `test("stress test for failOnDataLoss=false")` which should cover most of corner cases. I also modified some existing tests to test for both `failOnDataLoss=false` and `failOnDataLoss=true` to make sure it doesn't break existing logic. Author: Shixiong Zhu Closes #15820 from zsxwing/failOnDataLoss. (cherry picked from commit 2fd101b2f0028e005fbb0bdd29e59af37aa637da) Signed-off-by: Tathagata Das --- .../sql/kafka010/CachedKafkaConsumer.scala | 236 ++++++++++++-- .../spark/sql/kafka010/KafkaSource.scala | 23 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 42 ++- .../spark/sql/kafka010/KafkaSourceSuite.scala | 297 +++++++++++++++--- .../spark/sql/kafka010/KafkaTestUtils.scala | 20 +- 5 files changed, 523 insertions(+), 95 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 3b5a96534f9b6..3f438e99185b5 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -18,12 +18,16 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.util.concurrent.TimeoutException -import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer, OffsetOutOfRangeException} import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.KafkaSource._ /** @@ -34,10 +38,18 @@ import org.apache.spark.internal.Logging private[kafka010] case class CachedKafkaConsumer private( topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) extends Logging { + import CachedKafkaConsumer._ private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - private val consumer = { + private var consumer = createConsumer + + /** Iterator to the already fetch data */ + private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + private var nextOffsetInFetchedData = UNKNOWN_OFFSET + + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) val tps = new ju.ArrayList[TopicPartition]() tps.add(topicPartition) @@ -45,42 +57,193 @@ private[kafka010] case class CachedKafkaConsumer private( c } - /** Iterator to the already fetch data */ - private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - private var nextOffsetInFetchedData = -2L - /** - * Get the record for the given offset, waiting up to timeout ms if IO is necessary. - * Sequential forward access will use buffers, but random access will be horribly inefficient. + * Get the record for the given offset if available. 
Otherwise it will either throw error + * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), + * or null. + * + * @param offset the offset to fetch. + * @param untilOffset the max offset to fetch. Exclusive. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at + * offset if available, or throw exception.when `failOnDataLoss` is `false`, + * this method will either return record at offset if available, or return + * the next earliest available record less than untilOffset, or null. It + * will not throw any exception. */ - def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + def get( + offset: Long, + untilOffset: Long, + pollTimeoutMs: Long, + failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + require(offset < untilOffset, + s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") - if (offset != nextOffsetInFetchedData) { - logInfo(s"Initial fetch for $topicPartition $offset") - seek(offset) - poll(pollTimeoutMs) + // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is + // `false`, first, we will try to fetch the record at `offset`. If no such record exists, then + // we will move to the next available offset within `[offset, untilOffset)` and retry. + // If `failOnDataLoss` is `true`, the loop body will be executed only once. + var toFetchOffset = offset + while (toFetchOffset != UNKNOWN_OFFSET) { + try { + return fetchData(toFetchOffset, pollTimeoutMs) + } catch { + case e: OffsetOutOfRangeException => + // When there is some error thrown, it's better to use a new consumer to drop all cached + // states in the old consumer. We don't need to worry about the performance because this + // is not a common path. + resetConsumer() + reportDataLoss(failOnDataLoss, s"Cannot fetch offset $toFetchOffset", e) + toFetchOffset = getEarliestAvailableOffsetBetween(toFetchOffset, untilOffset) + } } + resetFetchedData() + null + } - if (!fetchedData.hasNext()) { poll(pollTimeoutMs) } - assert(fetchedData.hasNext(), - s"Failed to get records for $groupId $topicPartition $offset " + - s"after polling for $pollTimeoutMs") - var record = fetchedData.next() + /** + * Return the next earliest available offset in [offset, untilOffset). If all offsets in + * [offset, untilOffset) are invalid (e.g., the topic is deleted and recreated), it will return + * `UNKNOWN_OFFSET`. + */ + private def getEarliestAvailableOffsetBetween(offset: Long, untilOffset: Long): Long = { + val (earliestOffset, latestOffset) = getAvailableOffsetRange() + logWarning(s"Some data may be lost. Recovering from the earliest offset: $earliestOffset") + if (offset >= latestOffset || earliestOffset >= untilOffset) { + // [offset, untilOffset) and [earliestOffset, latestOffset) have no overlap, + // either + // -------------------------------------------------------- + // ^ ^ ^ ^ + // | | | | + // earliestOffset latestOffset offset untilOffset + // + // or + // -------------------------------------------------------- + // ^ ^ ^ ^ + // | | | | + // offset untilOffset earliestOffset latestOffset + val warningMessage = + s""" + |The current available offset range is [$earliestOffset, $latestOffset). 
+ | Offset ${offset} is out of range, and records in [$offset, $untilOffset) will be + | skipped ${additionalMessage(failOnDataLoss = false)} + """.stripMargin + logWarning(warningMessage) + UNKNOWN_OFFSET + } else if (offset >= earliestOffset) { + // ----------------------------------------------------------------------------- + // ^ ^ ^ ^ + // | | | | + // earliestOffset offset min(untilOffset,latestOffset) max(untilOffset, latestOffset) + // + // This will happen when a topic is deleted and recreated, and new data are pushed very fast, + // then we will see `offset` disappears first then appears again. Although the parameters + // are same, the state in Kafka cluster is changed, so the outer loop won't be endless. + logWarning(s"Found a disappeared offset $offset. " + + s"Some data may be lost ${additionalMessage(failOnDataLoss = false)}") + offset + } else { + // ------------------------------------------------------------------------------ + // ^ ^ ^ ^ + // | | | | + // offset earliestOffset min(untilOffset,latestOffset) max(untilOffset, latestOffset) + val warningMessage = + s""" + |The current available offset range is [$earliestOffset, $latestOffset). + | Offset ${offset} is out of range, and records in [$offset, $earliestOffset) will be + | skipped ${additionalMessage(failOnDataLoss = false)} + """.stripMargin + logWarning(warningMessage) + earliestOffset + } + } - if (record.offset != offset) { - logInfo(s"Buffer miss for $groupId $topicPartition $offset") + /** + * Get the record at `offset`. + * + * @throws OffsetOutOfRangeException if `offset` is out of range + * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds. + */ + private def fetchData( + offset: Long, + pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) { + // This is the first fetch, or the last pre-fetched data has been drained. + // Seek to the offset because we may call seekToBeginning or seekToEnd before this. seek(offset) poll(pollTimeoutMs) - assert(fetchedData.hasNext(), - s"Failed to get records for $groupId $topicPartition $offset " + - s"after polling for $pollTimeoutMs") - record = fetchedData.next() + } + + if (!fetchedData.hasNext()) { + // We cannot fetch anything after `poll`. Two possible cases: + // - `offset` is out of range so that Kafka returns nothing. Just throw + // `OffsetOutOfRangeException` to let the caller handle it. + // - Cannot fetch any data before timeout. TimeoutException will be thrown. + val (earliestOffset, latestOffset) = getAvailableOffsetRange() + if (offset < earliestOffset || offset >= latestOffset) { + throw new OffsetOutOfRangeException( + Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) + } else { + throw new TimeoutException( + s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") + } + } else { + val record = fetchedData.next() + nextOffsetInFetchedData = record.offset + 1 + // `seek` is always called before "poll". So "record.offset" must be same as "offset". 
assert(record.offset == offset, - s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset") + s"The fetched data has a different offset: expected $offset but was ${record.offset}") + record } + } + + /** Create a new consumer and reset cached states */ + private def resetConsumer(): Unit = { + consumer.close() + consumer = createConsumer + resetFetchedData() + } - nextOffsetInFetchedData = offset + 1 - record + /** Reset the internal pre-fetched data. */ + private def resetFetchedData(): Unit = { + nextOffsetInFetchedData = UNKNOWN_OFFSET + fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + } + + /** + * Return an addition message including useful message and instruction. + */ + private def additionalMessage(failOnDataLoss: Boolean): String = { + if (failOnDataLoss) { + s"(GroupId: $groupId, TopicPartition: $topicPartition). " + + s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE" + } else { + s"(GroupId: $groupId, TopicPartition: $topicPartition). " + + s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE" + } + } + + /** + * Throw an exception or log a warning as per `failOnDataLoss`. + */ + private def reportDataLoss( + failOnDataLoss: Boolean, + message: String, + cause: Throwable = null): Unit = { + val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}" + if (failOnDataLoss) { + if (cause != null) { + throw new IllegalStateException(finalMessage) + } else { + throw new IllegalStateException(finalMessage, cause) + } + } else { + if (cause != null) { + logWarning(finalMessage) + } else { + logWarning(finalMessage, cause) + } + } } private def close(): Unit = consumer.close() @@ -96,10 +259,24 @@ private[kafka010] case class CachedKafkaConsumer private( logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") fetchedData = r.iterator } + + /** + * Return the available offset range of the current partition. It's a pair of the earliest offset + * and the latest offset. + */ + private def getAvailableOffsetRange(): (Long, Long) = { + consumer.seekToBeginning(Set(topicPartition).asJava) + val earliestOffset = consumer.position(topicPartition) + consumer.seekToEnd(Set(topicPartition).asJava) + val latestOffset = consumer.position(topicPartition) + (earliestOffset, latestOffset) + } } private[kafka010] object CachedKafkaConsumer extends Logging { + private val UNKNOWN_OFFSET = -2L + private case class CacheKey(groupId: String, topicPartition: TopicPartition) private lazy val cache = { @@ -140,7 +317,10 @@ private[kafka010] object CachedKafkaConsumer extends Logging { // If this is reattempt at running the task, then invalidate cache and start with // a new consumer if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { - cache.remove(key) + val removedConsumer = cache.remove(key) + if (removedConsumer != null) { + removedConsumer.close() + } new CachedKafkaConsumer(topicPartition, kafkaParams) } else { if (!cache.containsKey(key)) { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 341081a338c0e..1d0d402b82a35 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -281,7 +281,7 @@ private[kafka010] case class KafkaSource( // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. 
val rdd = new KafkaSourceRDD( - sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => + sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr => Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) } @@ -463,10 +463,9 @@ private[kafka010] case class KafkaSource( */ private def reportDataLoss(message: String): Unit = { if (failOnDataLoss) { - throw new IllegalStateException(message + - ". Set the source option 'failOnDataLoss' to 'false' if you want to ignore these checks.") + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") } else { - logWarning(message) + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") } } } @@ -475,6 +474,22 @@ private[kafka010] case class KafkaSource( /** Companion object for the [[KafkaSource]]. */ private[kafka010] object KafkaSource { + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you want your streaming query to fail on such cases, set the source + | option "failOnDataLoss" to "true". + """.stripMargin + + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you don't want your streaming query to fail on such cases, set the + | source option "failOnDataLoss" to "false". + """.stripMargin + def kafkaSchema: StructType = StructType(Seq( StructField("key", BinaryType), StructField("value", BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 802dd040aed93..244cd2c225bdd 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -28,6 +28,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.NextIterator /** Offset range that one partition of the KafkaSourceRDD has to read */ @@ -61,7 +62,8 @@ private[kafka010] class KafkaSourceRDD( sc: SparkContext, executorKafkaParams: ju.Map[String, Object], offsetRanges: Seq[KafkaSourceRDDOffsetRange], - pollTimeoutMs: Long) + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) { override def persist(newLevel: StorageLevel): this.type = { @@ -130,23 +132,31 @@ private[kafka010] class KafkaSourceRDD( logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " + s"skipping ${range.topic} ${range.partition}") Iterator.empty - } else { - - val consumer = CachedKafkaConsumer.getOrCreate( - range.topic, range.partition, executorKafkaParams) - var requestOffset = range.fromOffset - - logDebug(s"Creating iterator for $range") - - new Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { - override def hasNext(): Boolean = requestOffset < range.untilOffset - override def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { - 
assert(hasNext(), "Can't call next() once untilOffset has been reached") - val r = consumer.get(requestOffset, pollTimeoutMs) - requestOffset += 1 - r + new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { + val consumer = CachedKafkaConsumer.getOrCreate( + range.topic, range.partition, executorKafkaParams) + var requestOffset = range.fromOffset + + override def getNext(): ConsumerRecord[Array[Byte], Array[Byte]] = { + if (requestOffset >= range.untilOffset) { + // Processed all offsets in this partition. + finished = true + null + } else { + val r = consumer.get(requestOffset, range.untilOffset, pollTimeoutMs, failOnDataLoss) + if (r == null) { + // Losing some data. Skip the rest offsets in this partition. + finished = true + null + } else { + requestOffset = r.offset + 1 + r + } + } } + + override protected def close(): Unit = {} } } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 89e713f92df46..cd52fd93d10a4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.kafka010 +import java.util.Properties +import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata @@ -27,8 +31,9 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ +import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.streaming.{ ProcessingTime, StreamTest } +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.test.SharedSQLContext abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -202,7 +207,7 @@ class KafkaSourceSuite extends KafkaSourceTest { test("cannot stop Kafka stream") { val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 5) + testUtils.createTopic(topic, partitions = 5) testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray) val reader = spark @@ -223,52 +228,85 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } - test("assign from latest offsets") { - val topic = newTopic() - testFromLatestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4)) - } + for (failOnDataLoss <- Seq(true, false)) { + test(s"assign from latest offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromLatestOffsets( + topic, + addPartitions = false, + failOnDataLoss = failOnDataLoss, + "assign" -> assignString(topic, 0 to 4)) + } - test("assign from earliest offsets") { - val topic = newTopic() - testFromEarliestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4)) - } + test(s"assign from earliest offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromEarliestOffsets( + topic, + addPartitions = false, + failOnDataLoss = failOnDataLoss, + "assign" -> assignString(topic, 0 to 4)) + } - test("assign from specific offsets") { - val topic = newTopic() - testFromSpecificOffsets(topic, "assign" -> assignString(topic, 0 to 4)) - } + test(s"assign from specific 
offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromSpecificOffsets( + topic, + failOnDataLoss = failOnDataLoss, + "assign" -> assignString(topic, 0 to 4), + "failOnDataLoss" -> failOnDataLoss.toString) + } - test("subscribing topic by name from latest offsets") { - val topic = newTopic() - testFromLatestOffsets(topic, true, "subscribe" -> topic) - } + test(s"subscribing topic by name from latest offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromLatestOffsets( + topic, + addPartitions = true, + failOnDataLoss = failOnDataLoss, + "subscribe" -> topic) + } - test("subscribing topic by name from earliest offsets") { - val topic = newTopic() - testFromEarliestOffsets(topic, true, "subscribe" -> topic) - } + test(s"subscribing topic by name from earliest offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromEarliestOffsets( + topic, + addPartitions = true, + failOnDataLoss = failOnDataLoss, + "subscribe" -> topic) + } - test("subscribing topic by name from specific offsets") { - val topic = newTopic() - testFromSpecificOffsets(topic, "subscribe" -> topic) - } + test(s"subscribing topic by name from specific offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromSpecificOffsets(topic, failOnDataLoss = failOnDataLoss, "subscribe" -> topic) + } - test("subscribing topic by pattern from latest offsets") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-suffix" - testFromLatestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*") - } + test(s"subscribing topic by pattern from latest offsets (failOnDataLoss: $failOnDataLoss)") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromLatestOffsets( + topic, + addPartitions = true, + failOnDataLoss = failOnDataLoss, + "subscribePattern" -> s"$topicPrefix-.*") + } - test("subscribing topic by pattern from earliest offsets") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-suffix" - testFromEarliestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*") - } + test(s"subscribing topic by pattern from earliest offsets (failOnDataLoss: $failOnDataLoss)") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromEarliestOffsets( + topic, + addPartitions = true, + failOnDataLoss = failOnDataLoss, + "subscribePattern" -> s"$topicPrefix-.*") + } - test("subscribing topic by pattern from specific offsets") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-suffix" - testFromSpecificOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + test(s"subscribing topic by pattern from specific offsets (failOnDataLoss: $failOnDataLoss)") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromSpecificOffsets( + topic, + failOnDataLoss = failOnDataLoss, + "subscribePattern" -> s"$topicPrefix-.*") + } } test("subscribing topic by pattern with topic deletions") { @@ -413,13 +451,59 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("delete a topic when a Spark job is running") { + KafkaSourceSuite.collectedData.clear() + + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribe", topic) + // If a topic is deleted and we try to poll data starting from offset 0, + // 
the Kafka consumer will just block until timeout and return an empty result. + // So set the timeout to 1 second to make this test fast. + .option("kafkaConsumer.pollTimeoutMs", "1000") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + KafkaSourceSuite.globalTestUtils = testUtils + // The following ForeachWriter will delete the topic before fetching data from Kafka + // in executors. + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + override def open(partitionId: Long, version: Long): Boolean = { + KafkaSourceSuite.globalTestUtils.deleteTopic(topic) + true + } + + override def process(value: Int): Unit = { + KafkaSourceSuite.collectedData.add(value) + } + + override def close(errorOrNull: Throwable): Unit = {} + }).start() + query.processAllAvailable() + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + assert(query.exception.isEmpty) + } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } - private def testFromSpecificOffsets(topic: String, options: (String, String)*): Unit = { + private def testFromSpecificOffsets( + topic: String, + failOnDataLoss: Boolean, + options: (String, String)*): Unit = { val partitionOffsets = Map( new TopicPartition(topic, 0) -> -2L, new TopicPartition(topic, 1) -> -1L, @@ -448,6 +532,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option("startingOffsets", startingOffsets) .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("failOnDataLoss", failOnDataLoss.toString) options.foreach { case (k, v) => reader.option(k, v) } val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -469,6 +554,7 @@ class KafkaSourceSuite extends KafkaSourceTest { private def testFromLatestOffsets( topic: String, addPartitions: Boolean, + failOnDataLoss: Boolean, options: (String, String)*): Unit = { testUtils.createTopic(topic, partitions = 5) testUtils.sendMessages(topic, Array("-1")) @@ -480,6 +566,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option("startingOffsets", s"latest") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("failOnDataLoss", failOnDataLoss.toString) options.foreach { case (k, v) => reader.option(k, v) } val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -513,6 +600,7 @@ class KafkaSourceSuite extends KafkaSourceTest { private def testFromEarliestOffsets( topic: String, addPartitions: Boolean, + failOnDataLoss: Boolean, options: (String, String)*): Unit = { testUtils.createTopic(topic, partitions = 5) testUtils.sendMessages(topic, (1 to 3).map { _.toString }.toArray) @@ -524,6 +612,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option("startingOffsets", s"earliest") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("failOnDataLoss", failOnDataLoss.toString) options.foreach { case (k, v) => reader.option(k, v) } val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -552,6 +641,11 @@ class KafkaSourceSuite extends KafkaSourceTest { } } +object KafkaSourceSuite { + @volatile 
var globalTestUtils: KafkaTestUtils = _ + val collectedData = new ConcurrentLinkedQueue[Any]() +} + class KafkaSourceStressSuite extends KafkaSourceTest { @@ -615,7 +709,7 @@ class KafkaSourceStressSuite extends KafkaSourceTest { } }) case 2 => // Add new partitions - AddKafkaData(topics.toSet, d: _*)(message = "Add partitiosn", + AddKafkaData(topics.toSet, d: _*)(message = "Add partition", topicAction = (topic, partition) => { testUtils.addPartitions(topic, partition.get + nextInt(1, 6)) }) @@ -626,3 +720,122 @@ class KafkaSourceStressSuite extends KafkaSourceTest { iterations = 50) } } + +class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + private var testUtils: KafkaTestUtils = _ + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils { + override def brokerConfiguration: Properties = { + val props = super.brokerConfiguration + // Try to make Kafka clean up messages as fast as possible. However, there is a hard-code + // 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) so this test should run at + // least 30 seconds. + props.put("log.cleaner.backoff.ms", "100") + props.put("log.segment.bytes", "40") + props.put("log.retention.bytes", "40") + props.put("log.retention.check.interval.ms", "100") + props.put("delete.retention.ms", "10") + props.put("log.flush.scheduler.interval.ms", "10") + props + } + } + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = { + true + } + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. + Thread.sleep(Random.nextInt(500)) + } + + override def close(errorOrNull: Throwable): Unit = { + } + }).start() + + val testTime = 1.minutes + val startTime = System.currentTimeMillis() + // Track the current existing topics + val topics = mutable.ArrayBuffer[String]() + // Track topics that have been deleted + val deletedTopics = mutable.Set[String]() + while (System.currentTimeMillis() - testTime.toMillis < startTime) { + Random.nextInt(10) match { + case 0 => // Create a new topic + val topic = newTopic() + topics += topic + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. + // Hence, always overwrite to handle this race condition. 
+ testUtils.createTopic(topic, partitions = 1, overwrite = true) + logInfo(s"Create topic $topic") + case 1 if topics.nonEmpty => // Delete an existing topic + val topic = topics.remove(Random.nextInt(topics.size)) + testUtils.deleteTopic(topic) + logInfo(s"Delete topic $topic") + deletedTopics += topic + case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted. + val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size)) + deletedTopics -= topic + topics += topic + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. + // Hence, always overwrite to handle this race condition. + testUtils.createTopic(topic, partitions = 1, overwrite = true) + logInfo(s"Create topic $topic") + case 3 => + Thread.sleep(1000) + case _ => // Push random messages + for (topic <- topics) { + val size = Random.nextInt(10) + for (_ <- 0 until size) { + testUtils.sendMessages(topic, Array(Random.nextInt(10).toString)) + } + } + } + // `failOnDataLoss` is `false`, we should not fail the query + if (query.exception.nonEmpty) { + throw query.exception.get + } + } + + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + if (query.exception.nonEmpty) { + throw query.exception.get + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 9b24ccdd560e8..f43917e151c57 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -155,8 +155,16 @@ class KafkaTestUtils extends Logging { } /** Create a Kafka topic and wait until it is propagated to the whole cluster */ - def createTopic(topic: String, partitions: Int): Unit = { - AdminUtils.createTopic(zkUtils, topic, partitions, 1) + def createTopic(topic: String, partitions: Int, overwrite: Boolean = false): Unit = { + var created = false + while (!created) { + try { + AdminUtils.createTopic(zkUtils, topic, partitions, 1) + created = true + } catch { + case e: kafka.common.TopicExistsException if overwrite => deleteTopic(topic) + } + } // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) @@ -244,7 +252,7 @@ class KafkaTestUtils extends Logging { offsets } - private def brokerConfiguration: Properties = { + protected def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") @@ -302,9 +310,11 @@ class KafkaTestUtils extends Logging { } checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) }) - deletePath && topicPath && replicaManager && logManager && cleaner + // ensure the topic is gone + val deleted = !zkUtils.getAllTopics().contains(topic) + deletePath && topicPath && replicaManager && logManager && cleaner && deleted } - eventually(timeout(10.seconds)) { + eventually(timeout(60.seconds)) { assert(isDeleted, s"$topic not deleted after timeout") } } From 64b9de9c079672eff49dc38e55749d9a26c743a6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Nov 2016 15:10:49 -0800 Subject: [PATCH 174/534] [SPARK-16803][SQL] SaveAsTable does not work when target table is a Hive serde table ### What changes were proposed in this pull request? 
In Spark 2.0, `SaveAsTable` does not work when the target table is a Hive serde table, but Spark 1.6 works. **Spark 1.6** ``` Scala scala> sql("create table sample.sample stored as SEQUENCEFILE as select 1 as key, 'abc' as value") res2: org.apache.spark.sql.DataFrame = [] scala> val df = sql("select key, value as value from sample.sample") df: org.apache.spark.sql.DataFrame = [key: int, value: string] scala> df.write.mode("append").saveAsTable("sample.sample") scala> sql("select * from sample.sample").show() +---+-----+ |key|value| +---+-----+ | 1| abc| | 1| abc| +---+-----+ ``` **Spark 2.0** ``` Scala scala> df.write.mode("append").saveAsTable("sample.sample") org.apache.spark.sql.AnalysisException: Saving data in MetastoreRelation sample, sample is not supported.; ``` So far, we do not plan to support it in Spark 2.1 due to the risk. Spark 1.6 works because it internally uses insertInto. But, if we change it back it will break the semantic of saveAsTable (this method uses by-name resolution instead of using by-position resolution used by insertInto). More extra changes are needed to support `hive` as a `format` in DataFrameWriter. Instead, users should use insertInto API. This PR corrects the error messages. Users can understand how to bypass it before we support it in a separate PR. ### How was this patch tested? Test cases are added Author: gatorsmile Closes #15926 from gatorsmile/saveAsTableFix5. (cherry picked from commit 9c42d4a76ca8046fcca2e20067f2aa461977e65a) Signed-off-by: gatorsmile --- .../command/createDataSourceTables.scala | 4 ++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 7e16e43f2bb0e..add732c1afc16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -175,6 +175,10 @@ case class CreateDataSourceTableAsSelectCommand( existingSchema = Some(l.schema) case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) => existingSchema = Some(s.metadata.schema) + case c: CatalogRelation if c.catalogTable.provider == Some(DDLUtils.HIVE_PROVIDER) => + throw new AnalysisException("Saving data in the Hive serde table " + + s"${c.catalogTable.identifier} is not supported yet. 
Please use the " + + "insertInto() API as an alternative..") case o => throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 4ab1a54edc46d..c7cc75fbc8a07 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -413,6 +413,26 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + test("saveAsTable(CTAS) using append and insertInto when the target table is Hive serde") { + val tableName = "tab1" + withTable(tableName) { + sql(s"CREATE TABLE $tableName STORED AS SEQUENCEFILE AS SELECT 1 AS key, 'abc' AS value") + + val df = sql(s"SELECT key, value FROM $tableName") + val e = intercept[AnalysisException] { + df.write.mode(SaveMode.Append).saveAsTable(tableName) + }.getMessage + assert(e.contains("Saving data in the Hive serde table `default`.`tab1` is not supported " + + "yet. Please use the insertInto() API as an alternative.")) + + df.write.insertInto(tableName) + checkAnswer( + sql(s"SELECT * FROM $tableName"), + Row(1, "abc") :: Row(1, "abc") :: Nil + ) + } + } + test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { withTable("savedJsonTable") { // Save the df as a managed table (by not specifying the path). From 4b96ffb13a5171ef422aed955fd6b50354ae4253 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 22 Nov 2016 15:57:07 -0800 Subject: [PATCH 175/534] [SPARK-18533] Raise correct error upon specification of schema for datasource tables created using CTAS ## What changes were proposed in this pull request? Fixes the inconsistency of error raised between data source and hive serde tables when schema is specified in CTAS scenario. In the process the grammar for create table (datasource) is simplified. **before:** ``` SQL spark-sql> create table t2 (c1 int, c2 int) using parquet as select * from t1; Error in query: mismatched input 'as' expecting {, '.', 'OPTIONS', 'CLUSTERED', 'PARTITIONED'}(line 1, pos 64) == SQL == create table t2 (c1 int, c2 int) using parquet as select * from t1 ----------------------------------------------------------------^^^ ``` **After:** ```SQL spark-sql> create table t2 (c1 int, c2 int) using parquet as select * from t1 > ; Error in query: Operation not allowed: Schema may not be specified in a Create Table As Select (CTAS) statement(line 1, pos 0) == SQL == create table t2 (c1 int, c2 int) using parquet as select * from t1 ^^^ ``` ## How was this patch tested? Added a new test in CreateTableAsSelectSuite Author: Dilip Biswal Closes #15968 from dilipbiswal/ctas. 
(cherry picked from commit 39a1d30636857715247c82d551b200e1c331ad69) Signed-off-by: gatorsmile --- .../spark/sql/catalyst/parser/SqlBase.g4 | 6 +---- .../spark/sql/execution/SparkSqlParser.scala | 24 +++++++++++++++++-- .../sources/CreateTableAsSelectSuite.scala | 9 +++++++ 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 0aa2a97407c53..df85c70c6cdea 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -71,11 +71,7 @@ statement | createTableHeader ('(' colTypeList ')')? tableProvider (OPTIONS tablePropertyList)? (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? #createTableUsing - | createTableHeader tableProvider - (OPTIONS tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? AS? query #createTableUsing + bucketSpec? (AS? query)? #createTableUsing | createTableHeader ('(' columns=colTypeList ')')? (COMMENT STRING)? (PARTITIONED BY '(' partitionColumns=colTypeList ')')? diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 47610453ac23a..5f89a229d6242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -322,7 +322,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a [[CreateTable]] logical plan. + * Create a data source table, returning a [[CreateTable]] logical plan. + * + * Expected format: + * {{{ + * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * USING table_provider + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [AS select_statement]; + * }}} */ override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) @@ -371,6 +384,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx) } + // Don't allow explicit specification of schema for CTAS + if (schema.nonEmpty) { + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + } CreateTable(tableDesc, mode, Some(query)) } else { if (temp) { @@ -1052,7 +1071,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { "CTAS statement." operationNotAllowed(errorMessage, ctx) } - // Just use whatever is projected in the select statement as our schema + + // Don't allow explicit specification of schema for CTAS. 
if (schema.nonEmpty) { operationNotAllowed( "Schema may not be specified in a Create Table As Select (CTAS) statement", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 5cc9467395adc..61939fe5ef5b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -249,4 +249,13 @@ class CreateTableAsSelectSuite } } } + + test("specifying the column list for CTAS") { + withTable("t") { + val e = intercept[ParseException] { + sql("CREATE TABLE t (a int, b int) USING parquet AS SELECT 1, 2") + }.getMessage + assert(e.contains("Schema may not be specified in a Create Table As Select (CTAS)")) + } + } } From 3be2d1e0b52bf15ac28a9f96b03ae048e680b035 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 22 Nov 2016 16:49:15 -0800 Subject: [PATCH 176/534] [SPARK-18530][SS][KAFKA] Change Kafka timestamp column type to TimestampType ## What changes were proposed in this pull request? Changed Kafka timestamp column type to TimestampType. ## How was this patch tested? `test("Kafka column types")`. Author: Shixiong Zhu Closes #15969 from zsxwing/SPARK-18530. (cherry picked from commit d0212eb0f22473ee5482fe98dafc24e16ffcfc63) Signed-off-by: Shixiong Zhu --- .../spark/sql/kafka010/KafkaSource.scala | 16 +++- .../spark/sql/kafka010/KafkaSourceSuite.scala | 81 ++++++++++++++++++- 2 files changed, 93 insertions(+), 4 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1d0d402b82a35..d9ab4bb4f873d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,9 +32,12 @@ import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.UninterruptibleThread /** @@ -282,7 +285,14 @@ private[kafka010] case class KafkaSource( // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. val rdd = new KafkaSourceRDD( sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr => - Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) + InternalRow( + cr.key, + cr.value, + UTF8String.fromString(cr.topic), + cr.partition, + cr.offset, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)), + cr.timestampType.id) } logInfo("GetBatch generating RDD of offset range: " + @@ -293,7 +303,7 @@ private[kafka010] case class KafkaSource( currentPartitionOffsets = Some(untilPartitionOffsets) } - sqlContext.createDataFrame(rdd, schema) + sqlContext.internalCreateDataFrame(rdd, schema) } /** Stop this source and free any resources it has allocated. 
*/ @@ -496,7 +506,7 @@ private[kafka010] object KafkaSource { StructField("topic", StringType), StructField("partition", IntegerType), StructField("offset", LongType), - StructField("timestamp", LongType), + StructField("timestamp", TimestampType), StructField("timestampType", IntegerType) )) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index cd52fd93d10a4..f9f62581a3066 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.kafka010 +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Properties import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random @@ -33,6 +33,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.test.SharedSQLContext @@ -551,6 +552,84 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("Kafka column types") { + val now = System.currentTimeMillis() + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") + .option("subscribe", topic) + .load() + + val query = kafka + .writeStream + .format("memory") + .outputMode("append") + .queryName("kafkaColumnTypes") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaColumnTypes").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") + assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") + assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row") + assert(row.getAs[Int]("partition") === 0, s"Unexpected results: $row") + assert(row.getAs[Long]("offset") === 0L, s"Unexpected results: $row") + // We cannot check the exact timestamp as it's the time that messages were inserted by the + // producer. So here we just use a low bound to make sure the internal conversion works. 
+ assert(row.getAs[java.sql.Timestamp]("timestamp").getTime >= now, s"Unexpected results: $row") + assert(row.getAs[Int]("timestampType") === 0, s"Unexpected results: $row") + query.stop() + } + + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") + .option("subscribe", topic) + .load() + + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() + } + private def testFromLatestOffsets( topic: String, addPartitions: Boolean, From fc5fee83e363bc6df22459a9b1ba2ba11bfdfa20 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Nov 2016 19:17:48 -0800 Subject: [PATCH 177/534] [SPARK-18501][ML][SPARKR] Fix spark.glm errors when fitting on collinear data ## What changes were proposed in this pull request? * Fix SparkR ```spark.glm``` errors when fitting on collinear data, since ```standard error of coefficients, t value and p value``` are not available in this condition. * Scala/Python GLM summary should throw exception if users get ```standard error of coefficients, t value and p value``` but the underlying WLS was solved by local "l-bfgs". ## How was this patch tested? Add unit tests. Author: Yanbo Liang Closes #15930 from yanboliang/spark-18501. (cherry picked from commit 982b82e32e0fc7d30c5d557944a79eb3e6d2da59) Signed-off-by: Yanbo Liang --- R/pkg/R/mllib.R | 21 ++++++-- R/pkg/inst/tests/testthat/test_mllib.R | 9 ++++ .../GeneralizedLinearRegressionWrapper.scala | 54 +++++++++++-------- .../GeneralizedLinearRegression.scala | 46 +++++++++++++--- .../GeneralizedLinearRegressionSuite.scala | 21 ++++++++ 5 files changed, 115 insertions(+), 36 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 265e64e7466fa..02bc6456de4d0 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -278,8 +278,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' @param object a fitted generalized linear model. #' @return \code{summary} returns a summary object of the fitted model, a list of components -#' including at least the coefficients, null/residual deviance, null/residual degrees -#' of freedom, AIC and number of iterations IRLS takes. 
+#' including at least the coefficients matrix (which includes coefficients, standard error +#' of coefficients, t value and p value), null/residual deviance, null/residual degrees of +#' freedom, AIC and number of iterations IRLS takes. If there are collinear columns +#' in you data, the coefficients matrix only provides coefficients. #' #' @rdname spark.glm #' @export @@ -303,9 +305,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), } else { dataFrame(callJMethod(jobj, "rDevianceResiduals")) } - coefficients <- matrix(coefficients, ncol = 4) - colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") - rownames(coefficients) <- unlist(features) + # If the underlying WeightedLeastSquares using "normal" solver, we can provide + # coefficients, standard error of coefficients, t value and p value. Otherwise, + # it will be fitted by local "l-bfgs", we can only provide coefficients. + if (length(features) == length(coefficients)) { + coefficients <- matrix(coefficients, ncol = 1) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + } else { + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + } ans <- list(deviance.resid = deviance.resid, coefficients = coefficients, dispersion = dispersion, null.deviance = null.deviance, deviance = deviance, df.null = df.null, df.residual = df.residual, diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 70a033de5308e..b05be476a3fa8 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -169,6 +169,15 @@ test_that("spark.glm summary", { df <- suppressWarnings(createDataFrame(data)) regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) expect_equal(regStats$aic, 14.00976, tolerance = 1e-4) # 14.00976 is from summary() result + + # Test spark.glm works on collinear data + A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) + b <- c(1, 2, 3, 4) + data <- as.data.frame(cbind(A, b)) + df <- createDataFrame(data) + stats <- summary(spark.glm(df, b ~ . 
- 1)) + coefs <- unlist(stats$coefficients) + expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4)) }) test_that("spark.glm save/load", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index add4d49110d16..8bcc9fe5d1b85 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -144,30 +144,38 @@ private[r] object GeneralizedLinearRegressionWrapper features } - val rCoefficientStandardErrors = if (glm.getFitIntercept) { - Array(summary.coefficientStandardErrors.last) ++ - summary.coefficientStandardErrors.dropRight(1) + val rCoefficients: Array[Double] = if (summary.isNormalSolver) { + val rCoefficientStandardErrors = if (glm.getFitIntercept) { + Array(summary.coefficientStandardErrors.last) ++ + summary.coefficientStandardErrors.dropRight(1) + } else { + summary.coefficientStandardErrors + } + + val rTValues = if (glm.getFitIntercept) { + Array(summary.tValues.last) ++ summary.tValues.dropRight(1) + } else { + summary.tValues + } + + val rPValues = if (glm.getFitIntercept) { + Array(summary.pValues.last) ++ summary.pValues.dropRight(1) + } else { + summary.pValues + } + + if (glm.getFitIntercept) { + Array(glm.intercept) ++ glm.coefficients.toArray ++ + rCoefficientStandardErrors ++ rTValues ++ rPValues + } else { + glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues + } } else { - summary.coefficientStandardErrors - } - - val rTValues = if (glm.getFitIntercept) { - Array(summary.tValues.last) ++ summary.tValues.dropRight(1) - } else { - summary.tValues - } - - val rPValues = if (glm.getFitIntercept) { - Array(summary.pValues.last) ++ summary.pValues.dropRight(1) - } else { - summary.pValues - } - - val rCoefficients: Array[Double] = if (glm.getFitIntercept) { - Array(glm.intercept) ++ glm.coefficients.toArray ++ - rCoefficientStandardErrors ++ rTValues ++ rPValues - } else { - glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues + if (glm.getFitIntercept) { + Array(glm.intercept) ++ glm.coefficients.toArray + } else { + glm.coefficients.toArray + } } val rDispersion: Double = summary.dispersion diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 3f9de1fe74c9c..f33dd0fd294ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1063,45 +1063,75 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] ( import GeneralizedLinearRegression._ + /** + * Whether the underlying [[WeightedLeastSquares]] using the "normal" solver. + */ + private[ml] val isNormalSolver: Boolean = { + diagInvAtWA.length != 1 || diagInvAtWA(0) != 0 + } + /** * Standard error of estimated coefficients and intercept. + * This value is only available when the underlying [[WeightedLeastSquares]] + * using the "normal" solver. * * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, * then the last element returned corresponds to the intercept. 
*/ @Since("2.0.0") lazy val coefficientStandardErrors: Array[Double] = { - diagInvAtWA.map(_ * dispersion).map(math.sqrt) + if (isNormalSolver) { + diagInvAtWA.map(_ * dispersion).map(math.sqrt) + } else { + throw new UnsupportedOperationException( + "No Std. Error of coefficients available for this GeneralizedLinearRegressionModel") + } } /** * T-statistic of estimated coefficients and intercept. + * This value is only available when the underlying [[WeightedLeastSquares]] + * using the "normal" solver. * * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val tValues: Array[Double] = { - val estimate = if (model.getFitIntercept) { - Array.concat(model.coefficients.toArray, Array(model.intercept)) + if (isNormalSolver) { + val estimate = if (model.getFitIntercept) { + Array.concat(model.coefficients.toArray, Array(model.intercept)) + } else { + model.coefficients.toArray + } + estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } } else { - model.coefficients.toArray + throw new UnsupportedOperationException( + "No t-statistic available for this GeneralizedLinearRegressionModel") } - estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } } /** * Two-sided p-value of estimated coefficients and intercept. + * This value is only available when the underlying [[WeightedLeastSquares]] + * using the "normal" solver. * * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val pValues: Array[Double] = { - if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) { - tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } + if (isNormalSolver) { + if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) { + tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } + } else { + tValues.map { x => + 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) + } + } } else { - tValues.map { x => 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) } + throw new UnsupportedOperationException( + "No p-value available for this GeneralizedLinearRegressionModel") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 9b0fa67630d2e..4fab2160339c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1048,6 +1048,27 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } + test("glm handle collinear features") { + val collinearInstances = Seq( + Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)), + Instance(2.0, 1.0, Vectors.dense(2.0, 4.0)), + Instance(3.0, 1.0, Vectors.dense(3.0, 6.0)), + Instance(4.0, 1.0, Vectors.dense(4.0, 8.0)) + ).toDF() + val trainer = new GeneralizedLinearRegression() + val model = trainer.fit(collinearInstances) + // to make it clear that underlying WLS did not solve analytically + intercept[UnsupportedOperationException] { + model.summary.coefficientStandardErrors + } + intercept[UnsupportedOperationException] { + model.summary.pValues + } + intercept[UnsupportedOperationException] { + model.summary.tValues + } + } + test("read/write") { def 
checkModelData( model: GeneralizedLinearRegressionModel, From fabb5aeaf62e5c18d5d489e769e998e52379ba20 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 22 Nov 2016 22:25:27 -0800 Subject: [PATCH 178/534] [SPARK-18179][SQL] Throws analysis exception with a proper message for unsupported argument types in reflect/java_method function ## What changes were proposed in this pull request? This PR proposes throwing an `AnalysisException` with a proper message rather than `NoSuchElementException` with the message ` key not found: TimestampType` when unsupported types are given to `reflect` and `java_method` functions. ```scala spark.range(1).selectExpr("reflect('java.lang.String', 'valueOf', cast('1990-01-01' as timestamp))") ``` produces **Before** ``` java.util.NoSuchElementException: key not found: TimestampType at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) at org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection$$anonfun$findMethod$1$$anonfun$apply$1.apply(CallMethodViaReflection.scala:159) ... ``` **After** ``` cannot resolve 'reflect('java.lang.String', 'valueOf', CAST('1990-01-01' AS TIMESTAMP))' due to data type mismatch: arguments from the third require boolean, byte, short, integer, long, float, double or string expressions; line 1 pos 0; 'Project [unresolvedalias(reflect(java.lang.String, valueOf, cast(1990-01-01 as timestamp)), Some())] +- Range (0, 1, step=1, splits=Some(2)) ... ``` Added message is, ``` arguments from the third require boolean, byte, short, integer, long, float, double or string expressions ``` ## How was this patch tested? Tests added in `CallMethodViaReflection`. Author: hyukjinkwon Closes #15694 from HyukjinKwon/SPARK-18179. 
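A minimal sketch of the resulting behaviour, assuming the same `spark` session used in the snippets above (illustration only; the failing call is the one quoted in the description, and the supported argument types are the ones listed in the new error message):

```scala
// Supported argument types (boolean/byte/short/integer/long/float/double/string) still resolve;
// String.valueOf(long) matches the LongType `id` column.
spark.range(1).selectExpr("reflect('java.lang.String', 'valueOf', id)").show()

// An unsupported type such as timestamp now fails analysis with the new message,
// instead of surfacing a NoSuchElementException.
spark.range(1).selectExpr(
  "reflect('java.lang.String', 'valueOf', cast('1990-01-01' as timestamp))")
```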
(cherry picked from commit 2559fb4b40c9f42f7b3ed2b77de14461f68b6fa5) Signed-off-by: Reynold Xin --- .../catalyst/expressions/CallMethodViaReflection.scala | 4 ++++ .../expressions/CallMethodViaReflectionSuite.scala | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 40f1b148f9287..4859e0c537610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -65,6 +65,10 @@ case class CallMethodViaReflection(children: Seq[Expression]) TypeCheckFailure("first two arguments should be string literals") } else if (!classExists) { TypeCheckFailure(s"class $className not found") + } else if (children.slice(2, children.length) + .exists(e => !CallMethodViaReflection.typeMapping.contains(e.dataType))) { + TypeCheckFailure("arguments from the third require boolean, byte, short, " + + "integer, long, float, double or string expressions") } else if (method == null) { TypeCheckFailure(s"cannot find a static method that matches the argument types in $className") } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala index 43367c7e14c34..88d4d460751b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.types.{IntegerType, StringType} @@ -85,6 +87,13 @@ class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelp assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess) } + test("unsupported type checking") { + val ret = createExpr(staticClassName, "method1", new Timestamp(1)).checkInputDataTypes() + assert(ret.isFailure) + val errorMsg = ret.asInstanceOf[TypeCheckFailure].message + assert(errorMsg.contains("arguments from the third require boolean, byte, short")) + } + test("invoking methods using acceptable types") { checkEvaluation(createExpr(staticClassName, "method1"), "m1") checkEvaluation(createExpr(staticClassName, "method2", 2), "m2") From 5f198d200d47703f6ab770e592c0a1d9f8d7b0dc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 23 Nov 2016 11:25:47 +0000 Subject: [PATCH 179/534] [SPARK-18073][DOCS][WIP] Migrate wiki to spark.apache.org web site ## What changes were proposed in this pull request? Updates links to the wiki to links to the new location of content on spark.apache.org. ## How was this patch tested? Doc builds Author: Sean Owen Closes #15967 from srowen/SPARK-18073.1. 
(cherry picked from commit 7e0cd1d9b168286386f15e9b55988733476ae2bb) Signed-off-by: Sean Owen --- .github/PULL_REQUEST_TEMPLATE | 2 +- CONTRIBUTING.md | 4 ++-- R/README.md | 2 +- R/pkg/DESCRIPTION | 2 +- README.md | 11 ++++++----- dev/checkstyle.xml | 2 +- docs/_layouts/global.html | 4 ++-- docs/building-spark.md | 4 ++-- docs/contributing-to-spark.md | 2 +- docs/index.md | 4 ++-- docs/sparkr.md | 2 +- docs/streaming-programming-guide.md | 2 +- .../spark/sql/execution/datasources/DataSource.scala | 5 ++--- 13 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 0e41cf1826453..5af45d6fa7988 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -7,4 +7,4 @@ (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) -Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. +Please review http://spark.apache.org/contributing.html before opening a pull request. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1a8206abe3838..8fdd5aa9e7dfb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,12 @@ ## Contributing to Spark *Before opening a pull request*, review the -[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +[Contributing to Spark guide](http://spark.apache.org/contributing.html). It lists steps that are required before creating a PR. In particular, consider: - Is the change important and ready enough to ask the community to spend time reviewing? - Have you searched for existing, related JIRAs and pull requests? -- Is this a new feature that can stand alone as a [third party project](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) ? +- Is this a new feature that can stand alone as a [third party project](http://spark.apache.org/third-party-projects.html) ? - Is the change being proposed clearly explained and motivated? When you contribute code, you affirm that the contribution is your original work and that you diff --git a/R/README.md b/R/README.md index 47f9a86dfde11..4c40c5963db70 100644 --- a/R/README.md +++ b/R/README.md @@ -51,7 +51,7 @@ sparkR.session() #### Making changes to SparkR -The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +The [instructions](http://spark.apache.org/contributing.html) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. 
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index fe41a9e7dabbd..981ae1246476b 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -11,7 +11,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "felixcheung@apache.org"), person(family = "The Apache Software Foundation", role = c("aut", "cph"))) URL: http://www.apache.org/ http://spark.apache.org/ -BugReports: https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingBugReports +BugReports: http://spark.apache.org/contributing.html Depends: R (>= 3.0), methods diff --git a/README.md b/README.md index dd7d0e22495b3..853f7f5ded3cb 100644 --- a/README.md +++ b/README.md @@ -29,8 +29,9 @@ To build Spark and its example programs, run: You can build Spark using more than one thread by using the -T option with Maven, see ["Parallel builds in Maven 3"](https://cwiki.apache.org/confluence/display/MAVEN/Parallel+builds+in+Maven+3). More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) -and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). + +For general development tips, including info on developing Spark using an IDE, see +[http://spark.apache.org/developer-tools.html](the Useful Developer Tools page). ## Interactive Scala Shell @@ -80,7 +81,7 @@ can be run using: ./dev/run-tests Please see the guidance on how to -[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). +[run tests for a module, or individual tests](http://spark.apache.org/developer-tools.html#individual-tests). ## A Note About Hadoop Versions @@ -100,5 +101,5 @@ in the online documentation for an overview on how to configure Spark. ## Contributing -Please review the [Contribution to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -wiki for information on how to get started contributing to the project. +Please review the [Contribution to Spark guide](http://spark.apache.org/contributing.html) +for information on how to get started contributing to the project. diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 92c5251c85037..fd73ca73ee7ef 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -28,7 +28,7 @@ with Spark-specific changes from: - https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide + http://spark.apache.org/contributing.html#code-style-guide Checkstyle is very configurable. Be sure to read the documentation at http://checkstyle.sf.net (or in your downloaded distribution). diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index ad5b5c9adfac8..c00d0db63cd10 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -113,8 +113,8 @@
  • Hardware Provisioning
  • Building Spark
  • -
  • Contributing to Spark
  • -
  • Third Party Projects
  • +
  • Contributing to Spark
  • +
  • Third Party Projects
  • diff --git a/docs/building-spark.md b/docs/building-spark.md index 88da0cc9c3bbf..65c2895b29b10 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -197,7 +197,7 @@ can be set to control the SBT build. For example: To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt in interactive mode by running `build/sbt`, and then run all build commands at the command prompt. For more recommendations on reducing build time, refer to the -[wiki page](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-ReducingBuildTimes). +[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). ## Encrypted Filesystems @@ -215,7 +215,7 @@ to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/ ## IntelliJ IDEA or Eclipse For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the -[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IDESetup). +[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). # Running Tests diff --git a/docs/contributing-to-spark.md b/docs/contributing-to-spark.md index ef1b3ad6da57a..9252545e4a129 100644 --- a/docs/contributing-to-spark.md +++ b/docs/contributing-to-spark.md @@ -5,4 +5,4 @@ title: Contributing to Spark The Spark team welcomes all forms of contributions, including bug reports, documentation or patches. For the newest information on how to contribute to the project, please read the -[wiki page on contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +[Contributing to Spark guide](http://spark.apache.org/contributing.html). diff --git a/docs/index.md b/docs/index.md index 39de11de854a7..c5d34cb5c4e73 100644 --- a/docs/index.md +++ b/docs/index.md @@ -125,8 +125,8 @@ options for deployment: * Integration with other storage systems: * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system -* [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -* [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects): related third party Spark projects +* [Contributing to Spark](http://spark.apache.org/contributing.html) +* [Third Party Projects](http://spark.apache.org/third-party-projects.html): related third party Spark projects **External Resources:** diff --git a/docs/sparkr.md b/docs/sparkr.md index f30bd4026fed3..d26949226b117 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -126,7 +126,7 @@ head(df) SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating SparkDataFrames from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active SparkSession will be used automatically. 
-SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects), you can find data source connectors for popular file formats like Avro. These packages can either be added by +SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](http://spark.apache.org/third-party-projects.html), you can find data source connectors for popular file formats like Avro. These packages can either be added by specifying `--packages` with `spark-submit` or `sparkR` commands, or if initializing SparkSession with `sparkPackages` parameter when in an interactive R shell or from RStudio.
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 18fc1cd934826..1fcd198685a51 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2382,7 +2382,7 @@ additional effort may be necessary to achieve exactly-once semantics. There are - [Kafka Integration Guide](streaming-kafka-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) -* Third-party DStream data sources can be found in [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) +* Third-party DStream data sources can be found in [Third Party Projects](http://spark.apache.org/third-party-projects.html) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index cfee7be1e3f07..84fde0bbf9268 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -505,12 +505,11 @@ object DataSource { provider1 == "com.databricks.spark.avro") { throw new AnalysisException( s"Failed to find data source: ${provider1.toLowerCase}. Please find an Avro " + - "package at " + - "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects") + "package at http://spark.apache.org/third-party-projects.html") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. Please find packages at " + - "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects", + "http://spark.apache.org/third-party-projects.html", error) } } From ebeb051405b84cb4abafbb6929ddcfadf59672db Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 23 Nov 2016 04:15:19 -0800 Subject: [PATCH 180/534] [SPARK-18053][SQL] compare unsafe and safe complex-type values correctly ## What changes were proposed in this pull request? In Spark SQL, some expression may output safe format values, e.g. `CreateArray`, `CreateStruct`, `Cast`, etc. When we compare 2 values, we should be able to compare safe and unsafe formats. The `GreaterThan`, `LessThan`, etc. in Spark SQL already handles it, but the `EqualTo` doesn't. This PR fixes it. ## How was this patch tested? new unit test and regression test Author: Wenchen Fan Closes #15929 from cloud-fan/type-aware. 
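As an illustration of the regression being fixed, a sketch assuming a `SparkSession` named `spark` with `spark.implicits._` and `org.apache.spark.sql.functions._` in scope (the SQL is the same query exercised by the new regression test):

```scala
// The column read back from the saved table is in the unsafe format (UnsafeArrayData), while
// ARRAY(1L) is produced by CreateArray in the "safe" GenericArrayData format. Before this patch
// EqualTo fell back to plain object equality between the two representations, so the filter did
// not behave correctly; with ordering-based equality the single matching row is returned.
spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl")
spark.sql("SELECT * FROM array_tbl WHERE arr = ARRAY(1L)").show()
```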
(cherry picked from commit 84284e8c82542d80dad94e458a0c0210bf803db3) Signed-off-by: Herman van Hovell --- .../sql/catalyst/expressions/UnsafeRow.java | 6 +--- .../expressions/codegen/CodeGenerator.scala | 20 ++++++++++-- .../sql/catalyst/expressions/predicates.scala | 32 +++---------------- .../catalyst/expressions/PredicateSuite.scala | 29 +++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 7 ++++ 5 files changed, 59 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index c3f0abac244cf..d205547698c5b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -578,12 +578,8 @@ public boolean equals(Object other) { return (sizeInBytes == o.sizeInBytes) && ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, sizeInBytes); - } else if (!(other instanceof InternalRow)) { - return false; - } else { - throw new IllegalArgumentException( - "Cannot compare UnsafeRow to " + other.getClass().getName()); } + return false; } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9c3c6d3b2a7f2..09007b7c89fe3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -481,8 +481,13 @@ class CodegenContext { case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" + case array: ArrayType => genComp(array, c1, c2) + " == 0" + case struct: StructType => genComp(struct, c1, c2) + " == 0" case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) - case other => s"$c1.equals($c2)" + case _ => + throw new IllegalArgumentException( + "cannot generate equality code for un-comparable type: " + dataType.simpleString) } /** @@ -512,6 +517,11 @@ class CodegenContext { val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { + // when comparing unsafe arrays, try equals first as it compares the binary directly + // which is very fast. + if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) { + return 0; + } int lengthA = a.numElements(); int lengthB = b.numElements(); int $minLength = (lengthA > lengthB) ? lengthB : lengthA; @@ -551,6 +561,11 @@ class CodegenContext { val funcCode: String = s""" public int $compareFunc(InternalRow a, InternalRow b) { + // when comparing unsafe rows, try equals first as it compares the binary directly + // which is very fast. 
+ if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) { + return 0; + } InternalRow i = null; $comparisons return 0; @@ -561,7 +576,8 @@ class CodegenContext { case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => - throw new IllegalArgumentException("cannot generate compare code for un-comparable type") + throw new IllegalArgumentException( + "cannot generate compare code for un-comparable type: " + dataType.simpleString) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2ad452b6a90ca..3fcbb05372d87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -388,6 +388,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } + + protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) } @@ -429,17 +431,7 @@ case class EqualTo(left: Expression, right: Expression) override def symbol: String = "=" - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (left.dataType == FloatType) { - Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 - } else if (left.dataType == DoubleType) { - Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 - } else if (left.dataType != BinaryType) { - input1 == input2 - } else { - java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) - } - } + protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) @@ -482,15 +474,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (input1 == null || input2 == null) { false } else { - if (left.dataType == FloatType) { - Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 - } else if (left.dataType == DoubleType) { - Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 - } else if (left.dataType != BinaryType) { - input1 == input2 - } else { - java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) - } + ordering.equiv(input1, input2) } } @@ -513,8 +497,6 @@ case class LessThan(left: Expression, right: Expression) override def symbol: String = "<" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } @@ -527,8 +509,6 @@ case class LessThanOrEqual(left: Expression, right: Expression) override def symbol: String = "<=" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } @@ -541,8 +521,6 @@ case class GreaterThan(left: Expression, right: Expression) override def symbol: String = ">" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, 
input2: Any): Any = ordering.gt(input1, input2) } @@ -555,7 +533,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) override def symbol: String = ">=" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 2a445b8cdb091..f9f6799e6e72f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -21,6 +21,8 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -293,4 +295,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(nullInt, normalInt), false) checkEvaluation(EqualNullSafe(nullInt, nullInt), true) } + + test("EqualTo on complex type") { + val array = new GenericArrayData(Array(1, 2, 3)) + val struct = create_row("a", 1L, array) + + val arrayType = ArrayType(IntegerType) + val structType = new StructType() + .add("1", StringType) + .add("2", LongType) + .add("3", ArrayType(IntegerType)) + + val projection = UnsafeProjection.create( + new StructType().add("array", arrayType).add("struct", structType)) + + val unsafeRow = projection(InternalRow(array, struct)) + + val unsafeArray = unsafeRow.getArray(0) + val unsafeStruct = unsafeRow.getStruct(1, 3) + + checkEvaluation(EqualTo( + Literal.create(array, arrayType), + Literal.create(unsafeArray, arrayType)), true) + + checkEvaluation(EqualTo( + Literal.create(struct, structType), + Literal.create(unsafeStruct, structType)), true) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6b517bc70f7d2..806381008aba6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2476,4 +2476,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-18053: ARRAY equality is broken") { + withTable("array_tbl") { + spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl") + assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1) + } + } } From 539c193af7e3e08e9b48df15e94eafcc3532105c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 23 Nov 2016 20:14:08 +0800 Subject: [PATCH 181/534] [SPARK-18545][SQL] Verify number of hive client RPCs in PartitionedTablePerfStatsSuite ## What changes were proposed in this pull request? This would help catch accidental O(n) calls to the hive client as in https://issues.apache.org/jira/browse/SPARK-18507 ## How was this patch tested? Checked that the test fails before https://issues.apache.org/jira/browse/SPARK-18507 was patched. cc cloud-fan Author: Eric Liang Closes #15985 from ericl/spark-18545. 
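The assertion pattern the new tests rely on, extracted as a standalone sketch (`HiveCatalogMetrics` lives in `org.apache.spark.metrics.source`; the bound of 10 mirrors the tests in the diff below and is not a hard API guarantee):

```scala
// Reset the shared counters, run the query under test, then bound the number of
// metastore RPCs it was allowed to make.
HiveCatalogMetrics.reset()
assert(spark.sql("select * from test where partCol1 = 1").count() == 1)
assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10)
```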
(cherry picked from commit 85235ed6c600270e3fa434738bd50dce3564440a) Signed-off-by: Wenchen Fan --- .../spark/metrics/source/StaticSources.scala | 7 +++ .../sql/hive/client/HiveClientImpl.scala | 1 + .../hive/PartitionedTablePerfStatsSuite.scala | 58 ++++++++++++++++++- 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index 3f7cfd9d2c11f..b433cd0a89ac9 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -85,6 +85,11 @@ object HiveCatalogMetrics extends Source { */ val METRIC_FILE_CACHE_HITS = metricRegistry.counter(MetricRegistry.name("fileCacheHits")) + /** + * Tracks the total number of Hive client calls (e.g. to lookup a table). + */ + val METRIC_HIVE_CLIENT_CALLS = metricRegistry.counter(MetricRegistry.name("hiveClientCalls")) + /** * Resets the values of all metrics to zero. This is useful in tests. */ @@ -92,10 +97,12 @@ object HiveCatalogMetrics extends Source { METRIC_PARTITIONS_FETCHED.dec(METRIC_PARTITIONS_FETCHED.getCount()) METRIC_FILES_DISCOVERED.dec(METRIC_FILES_DISCOVERED.getCount()) METRIC_FILE_CACHE_HITS.dec(METRIC_FILE_CACHE_HITS.getCount()) + METRIC_HIVE_CLIENT_CALLS.dec(METRIC_HIVE_CLIENT_CALLS.getCount()) } // clients can use these to avoid classloader issues with the codahale classes def incrementFetchedPartitions(n: Int): Unit = METRIC_PARTITIONS_FETCHED.inc(n) def incrementFilesDiscovered(n: Int): Unit = METRIC_FILES_DISCOVERED.inc(n) def incrementFileCacheHits(n: Int): Unit = METRIC_FILE_CACHE_HITS.inc(n) + def incrementHiveClientCalls(n: Int): Unit = METRIC_HIVE_CLIENT_CALLS.inc(n) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index daae8523c6366..68dcfd86731bd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -281,6 +281,7 @@ private[hive] class HiveClientImpl( shim.setCurrentSessionState(state) val ret = try f finally { Thread.currentThread().setContextClassLoader(original) + HiveCatalogMetrics.incrementHiveClientCalls(1) } ret } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index b41bc862e9bc5..9838b9a4eba3d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -57,7 +57,11 @@ class PartitionedTablePerfStatsSuite } private def setupPartitionedHiveTable(tableName: String, dir: File): Unit = { - spark.range(5).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write + setupPartitionedHiveTable(tableName, dir, 5) + } + + private def setupPartitionedHiveTable(tableName: String, dir: File, scale: Int): Unit = { + spark.range(scale).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write .partitionBy("partCol1", "partCol2") .mode("overwrite") .parquet(dir.getAbsolutePath) @@ -71,7 +75,11 @@ class PartitionedTablePerfStatsSuite } private def setupPartitionedDatasourceTable(tableName: String, dir: File): Unit = { - 
spark.range(5).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write + setupPartitionedDatasourceTable(tableName, dir, 5) + } + + private def setupPartitionedDatasourceTable(tableName: String, dir: File, scale: Int): Unit = { + spark.range(scale).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write .partitionBy("partCol1", "partCol2") .mode("overwrite") .parquet(dir.getAbsolutePath) @@ -242,6 +250,52 @@ class PartitionedTablePerfStatsSuite } } + test("hive table: num hive client calls does not scale with partition count") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + withTable("test") { + withTempDir { dir => + setupPartitionedHiveTable("test", dir, scale = 100) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 = 1").count() == 1) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() > 0) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 100) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + + HiveCatalogMetrics.reset() + assert(spark.sql("show partitions test").count() == 100) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + } + } + } + } + + test("datasource table: num hive client calls does not scale with partition count") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir, scale = 100) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 = 1").count() == 1) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() > 0) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 100) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + + HiveCatalogMetrics.reset() + assert(spark.sql("show partitions test").count() == 100) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + } + } + } + } + test("hive table: files read and cached when filesource partition management is off") { withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { withTable("test") { From e11d7c6874debfbbe44be4a2b0983d6b6763fff8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Nov 2016 04:22:26 -0800 Subject: [PATCH 182/534] [SPARK-18557] Downgrade confusing memory leak warning message ## What changes were proposed in this pull request? TaskMemoryManager has a memory leak detector that gets called at task completion callback and checks whether any memory has not been released. If they are not released by the time the callback is invoked, TaskMemoryManager releases them. The current error message says something like the following: ``` WARN [Executor task launch worker-0] org.apache.spark.memory.TaskMemoryManager - leak 16.3 MB memory from org.apache.spark.unsafe.map.BytesToBytesMap33fb6a15 In practice, there are multiple reasons why these can be triggered in the normal code path (e.g. limit, or task failures), and the fact that these messages are log means the "leak" is fixed by TaskMemoryManager. ``` To not confuse users, this patch downgrade the message from warning to debug level, and avoids using the word "leak" since it is not actually a leak. ## How was this patch tested? N/A - this is a simple logging improvement. 
Author: Reynold Xin

Closes #15989 from rxin/SPARK-18557.

(cherry picked from commit 9785ed40d7fe4e1fcd440e55706519c6e5f8d6b1)
Signed-off-by: Herman van Hovell

--- .../main/java/org/apache/spark/memory/TaskMemoryManager.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 1a700aa37554e..c40974b54cb47 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -378,14 +378,14 @@ public long cleanUpAllAllocatedMemory() { for (MemoryConsumer c: consumers) { if (c != null && c.getUsed() > 0) { // In case of failed task, it's normal to see leaked memory - logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c); } } consumers.clear(); for (MemoryBlock page : pageTable) { if (page != null) { - logger.warn("leak a page: " + page + " in task " + taskAttemptId); + logger.debug("unreleased page: " + page + " in task " + taskAttemptId); memoryManager.tungstenMemoryAllocator().free(page); } }

From 599dac1594ed52934dd483e12d2e39d514793dd9 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 23 Nov 2016 20:48:41 +0800
Subject: [PATCH 183/534] [SPARK-18522][SQL] Explicit contract for column stats serialization

## What changes were proposed in this pull request?

The current implementation of column stats uses the base64 encoding of the internal UnsafeRow format to persist statistics (in table properties in the Hive metastore). This is an internal format that is not stable across different versions of Spark and should NOT be used for persistence. In addition, it would be better if the statistics stored in the catalog were human readable.

This pull request introduces the following changes:

1. Created a single ColumnStat class for all data types. All data types track the same set of statistics.
2. Updated the implementation for stats collection to get rid of the dependency on internal data structures (e.g. InternalRow, or storing DateType as an int32). For example, previously dates were stored as a single integer, but are now stored as java.sql.Date. When we implement the next steps of CBO, we can add code to convert those back into internal types again.
3. Documented clearly what JVM data types are being used to store what data.
4. Defined a simple Map[String, String] interface for serializing and deserializing column stats into/from the catalog.
5. Rearranged the method/function structure so it is clearer what the supported data types are, and also moved how stats are generated into the ColumnStat class so they are easy to find.

## How was this patch tested?

Removed most of the original test cases created for column statistics, and added three very simple ones to cover all the cases. The three test cases validate:

1. Roundtrip serialization works.
2. Behavior when analyzing a non-existent column or a column of an unsupported data type.
3. Results of stats collection for all valid data types.

Also moved parser related tests into a parser test suite and added an explicit serialization test for the Hive external catalog.

Author: Reynold Xin

Closes #15959 from rxin/SPARK-18522.
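To make the new serialization contract concrete, here is a small round-trip sketch (illustrative only, not part of the patch; ColumnStat, toMap and fromMap are the names introduced in the diff below, while the table and column names are made up):

```
// Round-tripping a ColumnStat for an INT column through the Map[String, String]
// form that ends up in the metastore table properties.
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.types.{IntegerType, StructField}

// Integral types are upcast to Long for min/max, per the contract documented on ColumnStat.
val stat = ColumnStat(distinctCount = 2, min = Some(1L), max = Some(4L),
  nullCount = 1, avgLen = 4, maxLen = 4)

// Serialized form, roughly: Map("version" -> "1", "distinctCount" -> "2", "min" -> "1",
// "max" -> "4", "nullCount" -> "1", "avgLen" -> "4", "maxLen" -> "4")
val serialized: Map[String, String] = stat.toMap

// Deserialization needs the column's StructField so min/max can be parsed back by data type.
val restored = ColumnStat.fromMap("some_table", StructField("cint", IntegerType), serialized)
assert(restored == Some(stat))
```

In the catalog itself each key is additionally prefixed per column, e.g. spark.sql.statistics.colStats.cint.min (see columnStatKeyPropName in the HiveExternalCatalog change further down).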
(cherry picked from commit 70ad07a9d20586ae182c4e60ed97bdddbcbceff3) Signed-off-by: Wenchen Fan --- .../catalyst/plans/logical/Statistics.scala | 212 ++++++++--- .../command/AnalyzeColumnCommand.scala | 105 +----- .../spark/sql/StatisticsCollectionSuite.scala | 218 ++++++++++++ .../spark/sql/StatisticsColumnSuite.scala | 334 ------------------ .../apache/spark/sql/StatisticsSuite.scala | 92 ----- .../org/apache/spark/sql/StatisticsTest.scala | 130 ------- .../sql/execution/SparkSqlParserSuite.scala | 26 +- .../spark/sql/hive/HiveExternalCatalog.scala | 93 +++-- .../spark/sql/hive/StatisticsSuite.scala | 299 ++++++---------- 9 files changed, 591 insertions(+), 918 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index f3e2147b8f974..79865609cb647 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.commons.codec.binary.Base64 +import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.types._ + /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the * corresponding statistic produced by the children. To override this behavior, override @@ -58,60 +61,175 @@ case class Statistics( } } + /** - * Statistics for a column. + * Statistics collected for a column. + * + * 1. Supported data types are defined in `ColumnStat.supportsType`. + * 2. The JVM data type stored in min/max is the external data type (used in Row) for the + * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for + * TimestampType we store java.sql.Timestamp. + * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs. + * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * (sketches) might have been used, and the data collected can also be stale. + * + * @param distinctCount number of distinct values + * @param min minimum value + * @param max maximum value + * @param nullCount number of nulls + * @param avgLen average length of the values. For fixed-length types, this should be a constant. + * @param maxLen maximum length of the values. For fixed-length types, this should be a constant. 
*/ -case class ColumnStat(statRow: InternalRow) { +case class ColumnStat( + distinctCount: BigInt, + min: Option[Any], + max: Option[Any], + nullCount: BigInt, + avgLen: Long, + maxLen: Long) { - def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = { - NumericColumnStat(statRow, dataType) - } - def forString: StringColumnStat = StringColumnStat(statRow) - def forBinary: BinaryColumnStat = BinaryColumnStat(statRow) - def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow) + // We currently don't store min/max for binary/string type. This can change in the future and + // then we need to remove this require. + require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String])) + require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String])) - override def toString: String = { - // use Base64 for encoding - Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes) + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string + * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]]. + * + * As part of the protocol, the returned map always contains a key called "version". + * In the case min/max values are null (None), they won't appear in the map. + */ + def toMap: Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(ColumnStat.KEY_VERSION, "1") + map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) + map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) + map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) + map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) + min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) } + max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) } + map.toMap } } -object ColumnStat { - def apply(numFields: Int, str: String): ColumnStat = { - // use Base64 for decoding - val bytes = Base64.decodeBase64(str) - val unsafeRow = new UnsafeRow(numFields) - unsafeRow.pointTo(bytes, bytes.length) - ColumnStat(unsafeRow) + +object ColumnStat extends Logging { + + // List of string keys used to serialize ColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + + /** Returns true iff the we support gathering column statistics on column of the given type. */ + def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false } -} -case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) { - // The indices here must be consistent with `ColumnStatStruct.numericColumnStat`. - val numNulls: Long = statRow.getLong(0) - val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType] - val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType] - val ndv: Long = statRow.getLong(3) -} + /** + * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats + * from some external storage. 
The serialization side is defined in [[ColumnStat.toMap]]. + */ + def fromMap(table: String, field: StructField, map: Map[String, String]) + : Option[ColumnStat] = { + val str2val: (String => Any) = field.dataType match { + case _: IntegralType => _.toLong + case _: DecimalType => new java.math.BigDecimal(_) + case DoubleType | FloatType => _.toDouble + case BooleanType => _.toBoolean + case DateType => java.sql.Date.valueOf + case TimestampType => java.sql.Timestamp.valueOf + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => _ => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column ${field.name} of data type: ${field.dataType}.") + } -case class StringColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`. - val numNulls: Long = statRow.getLong(0) - val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getInt(2) - val ndv: Long = statRow.getLong(3) -} + try { + Some(ColumnStat( + distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), + // Note that flatMap(Option.apply) turns Option(null) into None. + min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply), + max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply), + nullCount = BigInt(map(KEY_NULL_COUNT).toLong), + avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, + maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e) + None + } + } -case class BinaryColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`. - val numNulls: Long = statRow.getLong(0) - val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getInt(2) -} + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + def statExprs(col: Attribute, relativeSD: Double): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + + def fixedLenTypeStruct(castType: DataType) = { + // For fixed width types, avg size should be the same as max size. 
+ val avgSize = Literal(col.dataType.defaultSize, LongType) + struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, avgSize, avgSize) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct(LongType) + case _: DecimalType => fixedLenTypeStruct(col.dataType) + case DoubleType | FloatType => fixedLenTypeStruct(DoubleType) + case BooleanType => fixedLenTypeStruct(col.dataType) + case DateType => fixedLenTypeStruct(col.dataType) + case TimestampType => fixedLenTypeStruct(col.dataType) + case BinaryType | StringType => + // For string and binary type, we don't store min/max. + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + Ceil(Average(Length(col))), Cast(Max(Length(col)), LongType)) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ + def rowToColumnStat(row: Row): ColumnStat = { + ColumnStat( + distinctCount = BigInt(row.getLong(0)), + min = Option(row.get(1)), // for string/binary min/max, get should return null + max = Option(row.get(2)), + nullCount = BigInt(row.getLong(3)), + avgLen = row.getLong(4), + maxLen = row.getLong(5) + ) + } -case class BooleanColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`. - val numNulls: Long = statRow.getLong(0) - val numTrues: Long = statRow.getLong(1) - val numFalses: Long = statRow.getLong(2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7fc57d09e9243..9dffe3614a87c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -24,9 +24,8 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.types._ /** @@ -62,7 +61,7 @@ case class AnalyzeColumnCommand( // Compute stats for each column val (rowCount, newColStats) = - AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames) + AnalyzeColumnCommand.computeColumnStats(sparkSession, tableIdent.table, relation, columnNames) // We also update table-level stats in order to keep them consistent with column-level stats. val statistics = Statistics( @@ -88,8 +87,9 @@ object AnalyzeColumnCommand extends Logging { * * This is visible for testing. 
*/ - def computeColStats( + def computeColumnStats( sparkSession: SparkSession, + tableName: String, relation: LogicalPlan, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { @@ -97,102 +97,33 @@ object AnalyzeColumnCommand extends Logging { val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = AttributeSet(columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) - exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) }).toSeq + // Make sure the column types are supported for stats gathering. + attributesToAnalyze.foreach { attr => + if (!ColumnStat.supportsType(attr.dataType)) { + throw new AnalysisException( + s"Column ${attr.name} in table $tableName is of type ${attr.dataType}, " + + "and Spark does not support statistics collection on this column type.") + } + } + // Collect statistics per column. // The first element in the result will be the overall row count, the following elements // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr)) + attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) - val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) - .queryExecution.toRdd.collect().head + val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() - // unwrap the result - // TODO: Get rid of numFields by using the public Dataset API. val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType) - (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) + (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1))) }.toMap (rowCount, columnStats) } - - private val zero = Literal(0, LongType) - private val one = Literal(1, LongType) - - private def numNulls(e: Expression): Expression = { - if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero - } - private def max(e: Expression): Expression = Max(e) - private def min(e: Expression): Expression = Min(e) - private def ndv(e: Expression, relativeSD: Double): Expression = { - // the approximate ndv should never be larger than the number of rows - Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) - } - private def avgLength(e: Expression): Expression = Average(Length(e)) - private def maxLength(e: Expression): Expression = Max(Length(e)) - private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) - private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - - /** - * Creates a struct that groups the sequence of expressions together. This is used to create - * one top level struct per column. 
- */ - private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = { - CreateStruct(exprs.map { expr: Expression => - expr.transformUp { - case af: AggregateFunction => af.toAggregateExpression() - } - }) - } - - private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { - Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) - } - - private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { - Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) - } - - private def binaryColumnStat(e: Expression): Seq[Expression] = { - Seq(numNulls(e), avgLength(e), maxLength(e)) - } - - private def booleanColumnStat(e: Expression): Seq[Expression] = { - Seq(numNulls(e), numTrues(e), numFalses(e)) - } - - // TODO(rxin): Get rid of this function. - def numStatFields(dataType: DataType): Int = { - dataType match { - case BinaryType | BooleanType => 3 - case _ => 4 - } - } - - /** - * Creates a struct expression that contains the statistics to collect for a column. - * - * @param attr column to collect statistics - * @param relativeSD relative error for approximate number of distinct values. - */ - def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = { - attr.dataType match { - case _: NumericType | TimestampType | DateType => - createStruct(numericColumnStat(attr, relativeSD)) - case StringType => - createStruct(stringColumnStat(attr, relativeSD)) - case BinaryType => - createStruct(binaryColumnStat(attr)) - case BooleanType => - createStruct(booleanColumnStat(attr)) - case otherType => - throw new AnalysisException("Analyzing columns is not supported for column " + - s"${attr.name} of data type: ${attr.dataType}.") - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala new file mode 100644 index 0000000000000..1fcccd061079e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.{lang => jl} +import java.sql.{Date, Timestamp} + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.test.SQLTestData.ArrayData +import org.apache.spark.sql.types._ + + +/** + * End-to-end suite testing statistics collection and use on both entire table and columns. 
+ */ +class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext { + import testImplicits._ + + private def checkTableStats(tableName: String, expectedRowCount: Option[Int]) + : Option[Statistics] = { + val df = spark.table(tableName) + val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } + + test("estimates the size of a limit 0 on outer join") { + withTempView("test") { + Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") + .createOrReplaceTempView("test") + val df1 = spark.table("test") + val df2 = spark.table("test").limit(0) + val df = df1.join(df2, Seq("k"), "left") + + val sizes = df.queryExecution.analyzed.collect { case g: Join => + g.statistics.sizeInBytes + } + + assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") + assert(sizes.head === BigInt(96), + s"expected exact size 96 for table 'test', got: ${sizes.head}") + } + } + + test("analyze column command - unsupported types and invalid columns") { + val tableName = "column_stats_test1" + withTable(tableName) { + Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) + + // Test unsupported data types + val err1 = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") + } + assert(err1.message.contains("does not support statistics collection")) + + // Test invalid columns + val err2 = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column") + } + assert(err2.message.contains("does not exist")) + } + } + + test("test table-level statistics for data source table") { + val tableName = "tbl" + withTable(tableName) { + sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) + + // noscan won't count the number of rows + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + checkTableStats(tableName, expectedRowCount = None) + + // without noscan, we count the number of rows + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + checkTableStats(tableName, expectedRowCount = Some(2)) + } + } + + test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { + val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) + val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) + assert(df.queryExecution.analyzed.statistics.sizeInBytes > + spark.sessionState.conf.autoBroadcastJoinThreshold) + assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > + spark.sessionState.conf.autoBroadcastJoinThreshold) + } + + test("estimates the size of limit") { + withTempView("test") { + Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") + .createOrReplaceTempView("test") + Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => + val df = sql(s"""SELECT * FROM test limit $limit""") + + val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => + g.statistics.sizeInBytes + } + assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizesGlobalLimit.head === BigInt(expected), + s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") + + val sizesLocalLimit = df.queryExecution.analyzed.collect { 
case l: LocalLimit => + l.statistics.sizeInBytes + } + assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizesLocalLimit.head === BigInt(expected), + s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") + } + } + } + +} + + +/** + * The base for test cases that we want to include in both the hive module (for verifying behavior + * when using the Hive external catalog) as well as in the sql/core module. + */ +abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { + import testImplicits._ + + private val dec1 = new java.math.BigDecimal("1.000000000000000000") + private val dec2 = new java.math.BigDecimal("8.000000000000000000") + private val d1 = Date.valueOf("2016-05-08") + private val d2 = Date.valueOf("2016-05-09") + private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") + private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + + /** + * Define a very simple 3 row table used for testing column serialization. + * Note: last column is seq[int] which doesn't support stats collection. + */ + protected val data = Seq[ + (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, + jl.Double, jl.Float, java.math.BigDecimal, + String, Array[Byte], Date, Timestamp, + Seq[Int])]( + (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), + (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), + (null, null, null, null, null, null, null, null, null, null, null, null, null) + ) + + /** A mapping from column to the stats collected. */ + protected val stats = mutable.LinkedHashMap( + "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), + "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4), + "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), + "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), + "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16), + "cstring" -> ColumnStat(2, None, None, 1, 3, 3), + "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), + "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8) + ) + + test("column stats round trip serialization") { + // Make sure we serialize and then deserialize and we will get the result data + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + stats.zip(df.schema).foreach { case ((k, v), field) => + withClue(s"column $k with type ${field.dataType}") { + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) + assert(roundtrip == Some(v)) + } + } + } + + test("analyze column command - result verification") { + val tableName = "column_stats_test2" + // (data.head.productArity - 1) because the last column does not support stats collection. 
+ assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Collect statistics + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + + // Validate statistics + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == stats.size) + + stats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala deleted file mode 100644 index e866ac2cb3b34..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ /dev/null @@ -1,334 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.command.AnalyzeColumnCommand -import org.apache.spark.sql.test.SQLTestData.ArrayData -import org.apache.spark.sql.types._ - -class StatisticsColumnSuite extends StatisticsTest { - import testImplicits._ - - test("parse analyze column commands") { - val tableName = "tbl" - - // we need to specify column names - intercept[ParseException] { - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS") - } - - val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value" - val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql) - val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value")) - comparePlans(parsed, expected) - } - - test("analyzing columns of non-atomic types is not supported") { - val tableName = "tbl" - withTable(tableName) { - Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) - val err = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") - } - assert(err.message.contains("Analyzing columns is not supported")) - } - } - - test("check correctness of columns") { - val table = "tbl" - val colName1 = "abc" - val colName2 = "x.yz" - withTable(table) { - sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET") - - val invalidColError = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") - } - 
assert(invalidColError.message == "Invalid column name: key.") - - withSQLConf("spark.sql.caseSensitive" -> "true") { - val invalidErr = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}") - } - assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.") - } - - withSQLConf("spark.sql.caseSensitive" -> "false") { - val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) - val tableIdent = TableIdentifier(table, Some("default")) - val relation = spark.sessionState.catalog.lookupRelation(tableIdent) - val (_, columnStats) = - AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze) - assert(columnStats.contains(colName1)) - assert(columnStats.contains(colName2)) - // check deduplication - assert(columnStats.size == 2) - assert(!columnStats.contains(colName2.toUpperCase)) - } - } - } - - private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = { - values.filter(_.isDefined).map(_.get) - } - - test("column-level statistics for integral type columns") { - val values = (0 to 5).map { i => - if (i % 2 == 0) None else Some(i) - } - val data = values.map { i => - (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong)) - } - - val df = data.toDF("c1", "c2", "c3", "c4") - val nonNullValues = getNonNullValues[Int](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.max, - nonNullValues.min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for fractional type columns") { - val values: Seq[Option[Decimal]] = (0 to 5).map { i => - if (i == 0) None else Some(Decimal(i + i * 0.01)) - } - val data = values.map { i => - (i.map(_.toFloat), i.map(_.toDouble), i) - } - - val df = data.toDF("c1", "c2", "c3") - val nonNullValues = getNonNullValues[Decimal](values) - val numNulls = values.count(_.isEmpty).toLong - val ndv = nonNullValues.distinct.length.toLong - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case floatType: FloatType => - ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat, - ndv)) - case doubleType: DoubleType => - ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble, - ndv)) - case decimalType: DecimalType => - ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) - } - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for string column") { - val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[String](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toInt, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for binary column") { - val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Array[Byte]](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - 
nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toInt)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for boolean column") { - val values = Seq(None, Some(true), Some(false), Some(true)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Boolean](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.count(_.equals(true)).toLong, - nonNullValues.count(_.equals(false)).toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for date column") { - val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Date](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - // Internally, DateType is represented as the number of days from 1970-01-01. - nonNullValues.map(DateTimeUtils.fromJavaDate).max, - nonNullValues.map(DateTimeUtils.fromJavaDate).min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for timestamp column") { - val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i => - i.map(Timestamp.valueOf) - } - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Timestamp](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - // Internally, TimestampType is represented as the number of days from 1970-01-01 - nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, - nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for null columns") { - val values = Seq(None, None) - val data = values.map { i => - (i.map(_.toString), i.map(_.toString.toInt)) - } - val df = data.toDF("c1", "c2") - val expectedColStatsSeq = df.schema.map { f => - (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L))) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for columns with different types") { - val intSeq = Seq(1, 2) - val doubleSeq = Seq(1.01d, 2.02d) - val stringSeq = Seq("a", "bb") - val binarySeq = Seq("a", "bb").map(_.getBytes) - val booleanSeq = Seq(true, false) - val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf) - val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf) - val longSeq = Seq(5L, 4L) - - val data = intSeq.indices.map { i => - (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i), - timestampSeq(i), longSeq(i)) - } - val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case IntegerType => - ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) - case DoubleType => - ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min, - doubleSeq.distinct.length.toLong)) - case StringType => - ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) - case BinaryType => - 
ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, - binarySeq.map(_.length).max.toInt)) - case BooleanType => - ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, - booleanSeq.count(_.equals(false)).toLong)) - case DateType => - ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max, - dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong)) - case TimestampType => - ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, - timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min, - timestampSeq.distinct.length.toLong)) - case LongType => - ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) - } - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("update table-level stats while collecting column-level stats") { - val table = "tbl" - withTable(table) { - sql(s"CREATE TABLE $table (c1 int) USING PARQUET") - sql(s"INSERT INTO $table SELECT 1") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - checkTableStats(tableName = table, expectedRowCount = Some(1)) - - // update table-level stats between analyze table and analyze column commands - sql(s"INSERT INTO $table SELECT 1") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2)) - - val colStat = fetchedStats.get.colStats("c1") - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = colStat, - expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - } - } - - test("analyze column stats independently") { - val table = "tbl" - withTable(table) { - sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) - assert(fetchedStats1.get.colStats.size == 1) - val expected1 = ColumnStat(InternalRow(0L, null, null, 0L)) - val rsd = spark.sessionState.conf.ndvMaxError - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = fetchedStats1.get.colStats("c1"), - expectedColStat = expected1, - rsd = rsd) - - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") - val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) - // column c1 is kept in the stats - assert(fetchedStats2.get.colStats.size == 2) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = fetchedStats2.get.colStats("c1"), - expectedColStat = expected1, - rsd = rsd) - val expected2 = ColumnStat(InternalRow(0L, null, null, 0L)) - StatisticsTest.checkColStat( - dataType = LongType, - colStat = fetchedStats2.get.colStats("c2"), - expectedColStat = expected2, - rsd = rsd) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala deleted file mode 100644 index 8cf42e9248c2a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} -import org.apache.spark.sql.types._ - -class StatisticsSuite extends StatisticsTest { - import testImplicits._ - - test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { - val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) - val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.statistics.sizeInBytes > - spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > - spark.sessionState.conf.autoBroadcastJoinThreshold) - } - - test("estimates the size of limit") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => - val df = sql(s"""SELECT * FROM test limit $limit""") - - val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.statistics.sizeInBytes - } - assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesGlobalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") - - val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.statistics.sizeInBytes - } - assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesLocalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") - } - } - } - - test("estimates the size of a limit 0 on outer join") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - val df1 = spark.table("test") - val df2 = spark.table("test").limit(0) - val df = df1.join(df2, Seq("k"), "left") - - val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.statistics.sizeInBytes - } - - assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), - s"expected exact size 96 for table 'test', got: ${sizes.head}") - } - } - - test("test table-level statistics for data source table created in InMemoryCatalog") { - val tableName = "tbl" - withTable(tableName) { - sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") - Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) - - // noscan won't count the number of rows - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - checkTableStats(tableName, expectedRowCount = None) - - // without noscan, we count the number of rows - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") - checkTableStats(tableName, expectedRowCount = Some(2)) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala deleted file mode 100644 index 915ee0d31bca2..0000000000000 
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.AnalyzeColumnCommand -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ - - -trait StatisticsTest extends QueryTest with SharedSQLContext { - - def checkColStats( - df: DataFrame, - expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { - val table = "tbl" - withTable(table) { - df.write.format("json").saveAsTable(table) - val columns = expectedColStatsSeq.map(_._1) - val tableIdent = TableIdentifier(table, Some("default")) - val relation = spark.sessionState.catalog.lookupRelation(tableIdent) - val (_, columnStats) = - AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name)) - expectedColStatsSeq.foreach { case (field, expectedColStat) => - assert(columnStats.contains(field.name)) - val colStat = columnStats(field.name) - StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = colStat, - expectedColStat = expectedColStat, - rsd = spark.sessionState.conf.ndvMaxError) - - // check if we get the same colStat after encoding and decoding - val encodedCS = colStat.toString - val numFields = AnalyzeColumnCommand.numStatFields(field.dataType) - val decodedCS = ColumnStat(numFields, encodedCS) - StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = decodedCS, - expectedColStat = expectedColStat, - rsd = spark.sessionState.conf.ndvMaxError) - } - } - } - - def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } -} - -object StatisticsTest { - def checkColStat( - dataType: DataType, - colStat: ColumnStat, - expectedColStat: ColumnStat, - rsd: Double): Unit = { - dataType match { - case StringType => - val cs = colStat.forString - val expectedCS = expectedColStat.forString - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.avgColLen == expectedCS.avgColLen) - assert(cs.maxColLen == expectedCS.maxColLen) - checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) - case BinaryType => - val cs = colStat.forBinary - val expectedCS = expectedColStat.forBinary - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.avgColLen == 
expectedCS.avgColLen) - assert(cs.maxColLen == expectedCS.maxColLen) - case BooleanType => - val cs = colStat.forBoolean - val expectedCS = expectedColStat.forBoolean - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.numTrues == expectedCS.numTrues) - assert(cs.numFalses == expectedCS.numFalses) - case atomicType: AtomicType => - checkNumericColStats( - dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd) - } - } - - private def checkNumericColStats( - dataType: AtomicType, - colStat: ColumnStat, - expectedColStat: ColumnStat, - rsd: Double): Unit = { - val cs = colStat.forNumeric(dataType) - val expectedCS = expectedColStat.forNumeric(dataType) - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.max == expectedCS.max) - assert(cs.min == expectedCS.min) - checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) - } - - private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = { - // ndv is an approximate value, so we make sure we have the value, and it should be - // within 3*SD's of the given rsd. - if (expectedNdv == 0) { - assert(ndv == 0) - } else if (expectedNdv > 0) { - assert(ndv > 0) - val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) - assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 797fe9ffa8be1..b070138be05d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DescribeFunctionCommand, - DescribeTableCommand, ShowFunctionsCommand} -import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing} +import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -221,12 +220,22 @@ class SparkSqlParserSuite extends PlanTest { intercept("explain describe tables x", "Unsupported SQL statement") } - test("SPARK-18106 analyze table") { + test("analyze table statistics") { assertEqual("analyze table t compute statistics", AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) assertEqual("analyze table t compute statistics noscan", AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) - assertEqual("analyze table t partition (a) compute statistics noscan", + assertEqual("analyze table t partition (a) compute statistics nOscAn", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + + // Partitions specified - we currently parse them but don't do anything with it + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", + AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE 
STATISTICS", + AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan", AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) intercept("analyze table t compute statistics xxxx", @@ -234,4 +243,11 @@ class SparkSqlParserSuite extends PlanTest { intercept("analyze table t partition (a) compute statistics xxxx", "Expected `NOSCAN` instead of `xxxx`") } + + test("analyze table column statistics") { + intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS", "") + + assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value", + AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ff0923f04893d..fd9dc32063872 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils} +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -514,7 +514,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } stats.colStats.foreach { case (colName, colStat) => - statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString + colStat.toMap.foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) } else { @@ -605,48 +607,65 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * It reads table schema, provider, partition column names and bucket specification from table * properties, and filter out these special entries from table properties. */ - private def restoreTableMetadata(table: CatalogTable): CatalogTable = { + private def restoreTableMetadata(inputTable: CatalogTable): CatalogTable = { if (conf.get(DEBUG_MODE)) { - return table + return inputTable } - val tableWithSchema = if (table.tableType == VIEW) { - table - } else { - getProviderFromTableProperties(table) match { + var table = inputTable + + if (table.tableType != VIEW) { + table.properties.get(DATASOURCE_PROVIDER) match { // No provider in table properties, which means this table is created by Spark prior to 2.1, // or is created at Hive side. case None => - table.copy(provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true) + table = table.copy( + provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true) // This is a Hive serde table created by Spark 2.1 or higher versions. - case Some(DDLUtils.HIVE_PROVIDER) => restoreHiveSerdeTable(table) + case Some(DDLUtils.HIVE_PROVIDER) => + table = restoreHiveSerdeTable(table) // This is a regular data source table. 
- case Some(provider) => restoreDataSourceTable(table, provider) + case Some(provider) => + table = restoreDataSourceTable(table, provider) } } // construct Spark's statistics from information in Hive metastore - val statsProps = tableWithSchema.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) - val tableWithStats = if (statsProps.nonEmpty) { - val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) - .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } - val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { - case f if colStatsProps.contains(f.name) => - val numFields = AnalyzeColumnCommand.numStatFields(f.dataType) - (f.name, ColumnStat(numFields, colStatsProps(f.name))) - }.toMap - tableWithSchema.copy( + val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + + if (statsProps.nonEmpty) { + val colStats = new scala.collection.mutable.HashMap[String, ColumnStat] + + // For each column, recover its column stats. Note that this is currently a O(n^2) operation, + // but given the number of columns it usually not enormous, this is probably OK as a start. + // If we want to map this a linear operation, we'd need a stronger contract between the + // naming convention used for serialization. + table.schema.foreach { field => + if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { + // If "version" field is defined, then the column stat is defined. + val keyPrefix = columnStatKeyPropName(field.name, "") + val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => + (k.drop(keyPrefix.length), v) + } + + ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach { + colStat => colStats += field.name -> colStat + } + } + } + + table = table.copy( stats = Some(Statistics( - sizeInBytes = BigInt(tableWithSchema.properties(STATISTICS_TOTAL_SIZE)), - rowCount = tableWithSchema.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), - colStats = colStats))) - } else { - tableWithSchema + sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)), + rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + colStats = colStats.toMap))) } - tableWithStats.copy(properties = getOriginalTableProperties(table)) + // Get the original table properties as defined by the user. + table.copy( + properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }) } private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = { @@ -1020,17 +1039,17 @@ object HiveExternalCatalog { val TABLE_PARTITION_PROVIDER_CATALOG = "catalog" val TABLE_PARTITION_PROVIDER_FILESYSTEM = "filesystem" - - def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = { - metadata.properties.get(DATASOURCE_PROVIDER) - } - - def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = { - metadata.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) } + /** + * Returns the fully qualified name used in table properties for a particular column stat. + * For example, for column "mycol", and "min" stat, this should return + * "spark.sql.statistics.colStats.mycol.min". + */ + private def columnStatKeyPropName(columnName: String, statKey: String): String = { + STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey } // A persisted data source table always store its schema in the catalog. 
- def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { + private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { val errorMessage = "Could not read schema from the hive metastore because it is corrupted." val props = metadata.properties val schema = props.get(DATASOURCE_SCHEMA) @@ -1078,11 +1097,11 @@ object HiveExternalCatalog { ) } - def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = { + private def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = { getColumnNamesByType(metadata.properties, "part", "partitioning columns") } - def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = { + private def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = { metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets => BucketSpec( numBuckets.toInt, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 4f5ebc3d838b9..5ae202fdc98da 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,56 +22,16 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { - - test("parse analyze commands") { - def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) - val operators = parsed.collect { - case a: AnalyzeTableCommand => a - case o => o - } - - assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail( - s"""$analyzeCommand expected command: $c, but got ${operators(0)} - |parsed command: - |$parsed - """.stripMargin) - } - } - - assertAnalyzeCommand( - "ANALYZE TABLE Table1 COMPUTE STATISTICS", - classOf[AnalyzeTableCommand]) - assertAnalyzeCommand( - "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", - classOf[AnalyzeTableCommand]) - assertAnalyzeCommand( - "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", - classOf[AnalyzeTableCommand]) - assertAnalyzeCommand( - "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS", - classOf[AnalyzeTableCommand]) - assertAnalyzeCommand( - "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan", - classOf[AnalyzeTableCommand]) - - assertAnalyzeCommand( - "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn", - classOf[AnalyzeTableCommand]) - } +class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { test("MetastoreRelations fallback to HDFS for size estimation") { val 
enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled @@ -310,6 +270,110 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } + test("verify serialized column stats after analyzing columns") { + import testImplicits._ + + val tableName = "column_stats_test2" + // (data.head.productArity - 1) because the last column does not support stats collection. + assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Collect statistics + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + + // Validate statistics + val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val table = hiveClient.getTable("default", tableName) + + val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats")) + assert(props == Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + 
"spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + )) + } + } + private def testUpdatingTableStats(tableDescription: String, createTableCmd: String): Unit = { test("test table-level statistics for " + tableDescription) { val parquetTable = "parquetTable" @@ -319,7 +383,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils TableIdentifier(parquetTable)) assert(DDLUtils.isDatasourceTable(catalogTable)) - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") + // Add a filter to avoid creating too many partitions + sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10") checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) @@ -328,7 +393,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val fetchedStats1 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") + sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10") sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") val fetchedStats2 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) @@ -340,7 +405,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils parquetTable, isDataSourceTable = true, hasSizeInBytes = true, 
- expectedRowCounts = Some(1000)) + expectedRowCounts = Some(20)) assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } } @@ -369,6 +434,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } + /** Used to test refreshing cached metadata once table stats are updated. */ private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = { val tableName = "tbl" var statsBeforeUpdate: Statistics = null @@ -411,145 +477,6 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils assert(statsAfterUpdate.rowCount == Some(2)) } - test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") { - val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true) - - assert(statsBeforeUpdate.sizeInBytes > 0) - assert(statsBeforeUpdate.rowCount == Some(1)) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = statsBeforeUpdate.colStats("key"), - expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - - assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes) - assert(statsAfterUpdate.rowCount == Some(2)) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = statsAfterUpdate.colStats("key"), - expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)), - rsd = spark.sessionState.conf.ndvMaxError) - } - - private lazy val (testDataFrame, expectedColStatsSeq) = { - import testImplicits._ - - val intSeq = Seq(1, 2) - val stringSeq = Seq("a", "bb") - val binarySeq = Seq("a", "bb").map(_.getBytes) - val booleanSeq = Seq(true, false) - val data = intSeq.indices.map { i => - (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i)) - } - val df: DataFrame = data.toDF("c1", "c2", "c3", "c4") - val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f => - val colStat = f.dataType match { - case IntegerType => - ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) - case StringType => - ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) - case BinaryType => - ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, - binarySeq.map(_.length).max.toInt)) - case BooleanType => - ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, - booleanSeq.count(_.equals(false)).toLong)) - } - (f, colStat) - } - (df, expectedColStatsSeq) - } - - private def checkColStats( - tableName: String, - isDataSourceTable: Boolean, - expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { - val readback = spark.table(tableName) - val stats = readback.queryExecution.analyzed.collect { - case rel: MetastoreRelation => - assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") - rel.catalogTable.stats.get - case rel: LogicalRelation => - assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") - rel.catalogTable.get.stats.get - } - assert(stats.length == 1) - val columnStats = stats.head.colStats - assert(columnStats.size == expectedColStatsSeq.length) - expectedColStatsSeq.foreach { case (field, expectedColStat) => - StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = columnStats(field.name), - expectedColStat = expectedColStat, - rsd = spark.sessionState.conf.ndvMaxError) - } - } - - test("generate and load 
column-level stats for data source table") { - val dsTable = "dsTable" - withTable(dsTable) { - testDataFrame.write.format("parquet").saveAsTable(dsTable) - sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") - checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq) - } - } - - test("generate and load column-level stats for hive serde table") { - val hTable = "hTable" - val tmp = "tmp" - withTable(hTable, tmp) { - testDataFrame.write.format("parquet").saveAsTable(tmp) - sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE") - sql(s"INSERT INTO $hTable SELECT * FROM $tmp") - sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") - checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq) - } - } - - // When caseSensitive is on, for columns with only case difference, they are different columns - // and we should generate column stats for all of them. - private def checkCaseSensitiveColStats(columnName: String): Unit = { - val tableName = "tbl" - withTable(tableName) { - val column1 = columnName.toLowerCase - val column2 = columnName.toUpperCase - withSQLConf("spark.sql.caseSensitive" -> "true") { - sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET") - sql(s"INSERT INTO $tableName SELECT 1, 3.0") - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`") - val readback = spark.table(tableName) - val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => - val columnStats = rel.catalogTable.get.stats.get.colStats - assert(columnStats.size == 2) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = columnStats(column1), - expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - StatisticsTest.checkColStat( - dataType = DoubleType, - colStat = columnStats(column2), - expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - rel - } - assert(relations.size == 1) - } - } - } - - test("check column statistics for case sensitive column names") { - checkCaseSensitiveColStats(columnName = "c1") - } - - test("check column statistics for case sensitive non-ascii column names") { - // scalastyle:off - // non ascii characters are not allowed in the source code, so we disable the scalastyle. - checkCaseSensitiveColStats(columnName = "列c") - // scalastyle:on - } - test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => From 835f03f344f2dea2134409d09e06b34feaae09f9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 23 Nov 2016 12:54:18 -0500 Subject: [PATCH 184/534] [SPARK-18050][SQL] do not create default database if it already exists ## What changes were proposed in this pull request? When we try to create the default database, we ask hive to do nothing if it already exists. However, Hive will log an error message instead of doing nothing, and the error message is quite annoying and confusing. In this PR, we only create default database if it doesn't exist. ## How was this patch tested? N/A Author: Wenchen Fan Closes #15993 from cloud-fan/default-db. 
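
For reference, the guarded initialization boils down to the sketch below (the standalone helper and its name are illustrative; the catalog classes are Spark-internal APIs):

```scala
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, ExternalCatalog, SessionCatalog}

// Probe the metastore first so Hive never logs the spurious "database already exists"
// error, but keep ignoreIfExists = true because another Spark application may be
// creating the default database at the same time.
def ensureDefaultDatabase(externalCatalog: ExternalCatalog, warehousePath: String): Unit = {
  val defaultDbDefinition = CatalogDatabase(
    SessionCatalog.DEFAULT_DATABASE, "default database", warehousePath, Map())
  if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) {
    externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true)
  }
}
```
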
(cherry picked from commit f129ebcd302168b628f47705f4a7d6b7e7b057b0) Signed-off-by: Andrew Or --- .../scala/org/apache/spark/sql/internal/SharedState.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 6232c18b1cea8..8de95fe64e663 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -92,8 +92,12 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { { val defaultDbDefinition = CatalogDatabase( SessionCatalog.DEFAULT_DATABASE, "default database", warehousePath, Map()) - // Initialize default database if it doesn't already exist - externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) + // Initialize default database if it doesn't exist + if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) { + // There may be another Spark application creating default database at the same time, here we + // set `ignoreIfExists = true` to avoid `DatabaseAlreadyExists` exception. + externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) + } } /** From 15d2cf26427084c0398f8d9303c218f360c52bb7 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 23 Nov 2016 11:48:59 -0800 Subject: [PATCH 185/534] [SPARK-18510] Fix data corruption from inferred partition column dataTypes ## What changes were proposed in this pull request? ### The Issue If I specify my schema when doing ```scala spark.read .schema(someSchemaWherePartitionColumnsAreStrings) ``` but if the partition inference can infer it as IntegerType or I assume LongType or DoubleType (basically fixed size types), then once UnsafeRows are generated, your data will be corrupted. ### Proposed solution The partition handling code path is kind of a mess. In my fix I'm probably adding to the mess, but at least trying to standardize the code path. The real issue is that a user that uses the `spark.read` code path can never clearly specify what the partition columns are. If you try to specify the fields in `schema`, we practically ignore what the user provides, and fall back to our inferred data types. What happens in the end is data corruption. My solution tries to fix this by always trying to infer partition columns the first time you specify the table. Once we find what the partition columns are, we try to find them in the user specified schema and use the dataType provided there, or fall back to the smallest common data type. We will ALWAYS append partition columns to the user's schema, even if they didn't ask for it. We will only use the data type they provided if they specified it. While this is confusing, this has been the behavior since Spark 1.6, and I didn't want to change this behavior in the QA period of Spark 2.1. We may revisit this decision later. A side effect of this PR is that we won't need https://github.com/apache/spark/pull/15942 if this PR goes in. ## How was this patch tested? Regression tests Author: Burak Yavuz Closes #15951 from brkyvz/partition-corruption. 
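
To make the scenario concrete, a minimal end-to-end sketch (assuming a `spark` session as in `spark-shell`; the path and column names are made up for illustration):

```scala
import org.apache.spark.sql.types.{StringType, StructType}
import spark.implicits._

val path = "/tmp/spark-18510-demo"  // illustrative scratch location

// `part` holds numeric-looking values, so partition inference would type it as an integer.
Seq((1L, "a", 0), (2L, "b", 1)).toDF("id", "value", "part")
  .write.mode("overwrite").partitionBy("part").parquet(path)

// Declare the partition column as a string in the user-specified schema. Previously the
// inferred fixed-size type silently overrode this and corrupted the generated UnsafeRows;
// with this change the string type from the user's schema is respected for `part`.
val schema = new StructType()
  .add("id", "long")
  .add("value", "string")
  .add("part", StringType)

spark.read.schema(schema).parquet(path).printSchema()
```
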
(cherry picked from commit 0d1bf2b6c8ac4d4141d7cef0552c22e586843c57) Signed-off-by: Tathagata Das --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- .../execution/datasources/DataSource.scala | 159 ++++++++++++------ .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/streaming/FileStreamSourceSuite.scala | 2 +- .../test/DataStreamReaderWriterSuite.scala | 45 ++++- .../sql/test/DataFrameReaderWriterSuite.scala | 38 ++++- 6 files changed, 190 insertions(+), 58 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ee48baa59c7af..c669c2e2e26ef 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2684,7 +2684,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", + paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 84fde0bbf9268..dbc3e712332f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -61,8 +61,12 @@ import org.apache.spark.util.Utils * qualified. This option only works when reading from a [[FileFormat]]. * @param userSpecifiedSchema An optional specification of the schema of the data. When present * we skip attempting to infer the schema. - * @param partitionColumns A list of column names that the relation is partitioned by. When this - * list is empty, the relation is unpartitioned. + * @param partitionColumns A list of column names that the relation is partitioned by. This list is + * generally empty during the read path, unless this DataSource is managed + * by Hive. In these cases, during `resolveRelation`, we will call + * `getOrInferFileFormatSchema` for file based DataSources to infer the + * partitioning. In other cases, if this list is empty, then this table + * is unpartitioned. * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. * @param catalogTable Optional catalog table reference that can be used to push down operations * over the datasource to the catalog service. @@ -84,30 +88,106 @@ case class DataSource( private val caseInsensitiveOptions = new CaseInsensitiveMap(options) /** - * Infer the schema of the given FileFormat, returns a pair of schema and partition column names. + * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer + * it. In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510. 
+ * This method will try to skip file scanning whether `userSpecifiedSchema` and + * `partitionColumns` are provided. Here are some code paths that use this method: + * 1. `spark.read` (no schema): Most amount of work. Infer both schema and partitioning columns + * 2. `spark.read.schema(userSpecifiedSchema)`: Parse partitioning columns, cast them to the + * dataTypes provided in `userSpecifiedSchema` if they exist or fallback to inferred + * dataType if they don't. + * 3. `spark.readStream.schema(userSpecifiedSchema)`: For streaming use cases, users have to + * provide the schema. Here, we also perform partition inference like 2, and try to use + * dataTypes in `userSpecifiedSchema`. All subsequent triggers for this stream will re-use + * this information, therefore calls to this method should be very cheap, i.e. there won't + * be any further inference in any triggers. + * 4. `df.saveAsTable(tableThatExisted)`: In this case, we call this method to resolve the + * existing table's partitioning scheme. This is achieved by not providing + * `userSpecifiedSchema`. For this case, we add the boolean `justPartitioning` for an early + * exit, if we don't care about the schema of the original table. + * + * @param format the file format object for this DataSource + * @param justPartitioning Whether to exit early and provide just the schema partitioning. + * @return A pair of the data schema (excluding partition columns) and the schema of the partition + * columns. If `justPartitioning` is `true`, then the dataSchema will be provided as + * `null`. */ - private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = { - userSpecifiedSchema.map(_ -> partitionColumns).orElse { - val allPaths = caseInsensitiveOptions.get("path") + private def getOrInferFileFormatSchema( + format: FileFormat, + justPartitioning: Boolean = false): (StructType, StructType) = { + // the operations below are expensive therefore try not to do them if we don't need to + lazy val tempFileCatalog = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() val globbedPaths = allPaths.toSeq.flatMap { path => val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, None) - val partitionSchema = fileCatalog.partitionSpec().partitionColumns - val inferred = format.inferSchema( + new InMemoryFileIndex(sparkSession, globbedPaths, options, None) + } + val partitionSchema = if (partitionColumns.isEmpty && catalogTable.isEmpty) { + // Try to infer partitioning, because no DataSource in the read path provides the partitioning + // columns properly unless it is a Hive DataSource + val resolved = tempFileCatalog.partitionSchema.map { partitionField => + val equality = sparkSession.sessionState.conf.resolver + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } else { + // in streaming mode, we have already inferred and registered partition columns, we will + // never have to materialize the lazy val below + lazy val inferredPartitions = tempFileCatalog.partitionSchema + // 
maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred + // partitioning + if (userSpecifiedSchema.isEmpty) { + inferredPartitions + } else { + val partitionFields = partitionColumns.map { partitionColumn => + userSpecifiedSchema.flatMap(_.find(_.name == partitionColumn)).orElse { + val inferredOpt = inferredPartitions.find(_.name == partitionColumn) + if (inferredOpt.isDefined) { + logDebug( + s"""Type of partition column: $partitionColumn not found in specified schema + |for $format. + |User Specified Schema + |===================== + |${userSpecifiedSchema.orNull} + | + |Falling back to inferred dataType if it exists. + """.stripMargin) + } + inferredPartitions.find(_.name == partitionColumn) + }.getOrElse { + throw new AnalysisException(s"Failed to resolve the schema for $format for " + + s"the partition column: $partitionColumn. It must be specified manually.") + } + } + StructType(partitionFields) + } + } + if (justPartitioning) { + return (null, partitionSchema) + } + val dataSchema = userSpecifiedSchema.map { schema => + val equality = sparkSession.sessionState.conf.resolver + StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name)))) + }.orElse { + format.inferSchema( sparkSession, caseInsensitiveOptions, - fileCatalog.allFiles()) - - inferred.map { inferredSchema => - StructType(inferredSchema ++ partitionSchema) -> partitionSchema.map(_.name) - } + tempFileCatalog.allFiles()) }.getOrElse { - throw new AnalysisException("Unable to infer schema. It must be specified manually.") + throw new AnalysisException( + s"Unable to infer schema for $format. It must be specified manually.") } + (dataSchema, partitionSchema) } /** Returns the name and schema of the source that can be used to continually read data. */ @@ -144,8 +224,8 @@ case class DataSource( "you may be able to create a static DataFrame on that directory with " + "'spark.read.load(directory)' and infer schema from it.") } - val (schema, partCols) = inferFileFormatSchema(format) - SourceInfo(s"FileSource[$path]", schema, partCols) + val (schema, partCols) = getOrInferFileFormatSchema(format) + SourceInfo(s"FileSource[$path]", StructType(schema ++ partCols), partCols.fieldNames) case _ => throw new UnsupportedOperationException( @@ -272,7 +352,7 @@ case class DataSource( HadoopFsRelation( fileCatalog, - partitionSchema = fileCatalog.partitionSpec().partitionColumns, + partitionSchema = fileCatalog.partitionSchema, dataSchema = dataSchema, bucketSpec = None, format, @@ -281,9 +361,10 @@ case class DataSource( // This is a non-streaming file based datasource. 
case (format: FileFormat, _) => val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() val globbedPaths = allPaths.flatMap { path => val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified) @@ -291,23 +372,14 @@ case class DataSource( throw new AnalysisException(s"Path does not exist: $qualified") } // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode if (checkFilesExist && !fs.exists(globPath.head)) { throw new AnalysisException(s"Path does not exist: ${globPath.head}") } globPath }.toArray - // If they gave a schema, then we try and figure out the types of the partition columns - // from that schema. - val partitionSchema = userSpecifiedSchema.map { schema => - StructType( - partitionColumns.map { c => - // TODO: Case sensitivity. - schema - .find(_.name.toLowerCase() == c.toLowerCase()) - .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) - }) - } + val (dataSchema, inferredPartitionSchema) = getOrInferFileFormatSchema(format) val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { @@ -316,27 +388,12 @@ case class DataSource( catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, partitionSchema) - } - - val dataSchema = userSpecifiedSchema.map { schema => - val equality = sparkSession.sessionState.conf.resolver - StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) - }.orElse { - format.inferSchema( - sparkSession, - caseInsensitiveOptions, - fileCatalog.asInstanceOf[InMemoryFileIndex].allFiles()) - }.getOrElse { - throw new AnalysisException( - s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. " + - "It must be specified manually") + new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(inferredPartitionSchema)) } HadoopFsRelation( fileCatalog, - partitionSchema = fileCatalog.partitionSchema, + partitionSchema = inferredPartitionSchema, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, @@ -384,11 +441,7 @@ case class DataSource( // up. If we fail to load the table for whatever reason, ignore the check. if (mode == SaveMode.Append) { val existingPartitionColumns = Try { - resolveRelation() - .asInstanceOf[HadoopFsRelation] - .partitionSchema - .fieldNames - .toSeq + getOrInferFileFormatSchema(format, justPartitioning = true)._2.fieldNames.toList }.getOrElse(Seq.empty[String]) // TODO: Case sensitivity. 
val sameColumns = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 02d9d15684904..10843e9ba5753 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -274,7 +274,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { pathToPartitionedTable, userSpecifiedSchema = Option("num int, str string"), userSpecifiedPartitionCols = partitionCols, - expectedSchema = new StructType().add("num", IntegerType).add("str", StringType), + expectedSchema = new StructType().add("str", StringType).add("num", IntegerType), expectedPartitionCols = partitionCols.map(Seq(_)).getOrElse(Seq.empty[String])) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index a099153d2e58e..bad6642ea4058 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -282,7 +282,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { createFileStreamSourceAndGetSchema( format = Some("json"), path = Some(src.getCanonicalPath), schema = None) } - assert("Unable to infer schema. It must be specified manually.;" === e.getMessage) + assert("Unable to infer schema for JSON. It must be specified manually.;" === e.getMessage) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 5630464f40803..0eb95a02432fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, StreamingQuery, StreamTest} -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils object LastOptions { @@ -532,4 +532,47 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { assert(e.getMessage.contains("does not support recovering")) assert(e.getMessage.contains("checkpoint location")) } + + test("SPARK-18510: use user specified types for partition columns in file sources") { + import org.apache.spark.sql.functions.udf + import testImplicits._ + withTempDir { src => + val createArray = udf { (length: Long) => + for (i <- 1 to length.toInt) yield i.toString + } + spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write + .partitionBy("part", "id") + .mode("overwrite") + .parquet(src.toString) + // Specify a random ordering of the schema, partition column in the middle, etc. + // Also let's say that the partition columns are Strings instead of Longs. 
+ // partition columns should go to the end + val schema = new StructType() + .add("id", StringType) + .add("ex", ArrayType(StringType)) + + val sdf = spark.readStream + .schema(schema) + .format("parquet") + .load(src.toString) + + assert(sdf.schema.toList === List( + StructField("ex", ArrayType(StringType)), + StructField("part", IntegerType), // inferred partitionColumn dataType + StructField("id", StringType))) // used user provided partitionColumn dataType + + val sq = sdf.writeStream + .queryName("corruption_test") + .format("memory") + .start() + sq.processAllAvailable() + checkAnswer( + spark.table("corruption_test"), + // notice how `part` is ordered before `id` + Row(Array("1"), 0, "0") :: Row(Array("1", "2"), 1, "1") :: + Row(Array("1", "2", "3"), 2, "2") :: Row(Array("1", "2", "3", "4"), 3, "3") :: Nil + ) + sq.stop() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index a7fda01098560..e0887e0f1c7de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -573,4 +573,40 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } } + + test("SPARK-18510: use user specified types for partition columns in file sources") { + import org.apache.spark.sql.functions.udf + import testImplicits._ + withTempDir { src => + val createArray = udf { (length: Long) => + for (i <- 1 to length.toInt) yield i.toString + } + spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write + .partitionBy("part", "id") + .mode("overwrite") + .parquet(src.toString) + // Specify a random ordering of the schema, partition column in the middle, etc. + // Also let's say that the partition columns are Strings instead of Longs. + // partition columns should go to the end + val schema = new StructType() + .add("id", StringType) + .add("ex", ArrayType(StringType)) + val df = spark.read + .schema(schema) + .format("parquet") + .load(src.toString) + + assert(df.schema.toList === List( + StructField("ex", ArrayType(StringType)), + StructField("part", IntegerType), // inferred partitionColumn dataType + StructField("id", StringType))) // used user provided partitionColumn dataType + + checkAnswer( + df, + // notice how `part` is ordered before `id` + Row(Array("1"), 0, "0") :: Row(Array("1", "2"), 1, "1") :: + Row(Array("1", "2", "3"), 2, "2") :: Row(Array("1", "2", "3", "4"), 3, "3") :: Nil + ) + } + } } From 27d81d0007f4358480148fa6f3f6b079a5431a81 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 23 Nov 2016 16:15:35 -0800 Subject: [PATCH 186/534] [SPARK-18510][SQL] Follow up to address comments in #15951 ## What changes were proposed in this pull request? This PR addressed the rest comments in #15951. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15997 from zsxwing/SPARK-18510-follow-up. 
(cherry picked from commit 223fa218e1f637f0d62332785a3bee225b65b990) Signed-off-by: Tathagata Das --- .../execution/datasources/DataSource.scala | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index dbc3e712332f7..ccfc759c8fa7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -118,8 +118,10 @@ case class DataSource( private def getOrInferFileFormatSchema( format: FileFormat, justPartitioning: Boolean = false): (StructType, StructType) = { - // the operations below are expensive therefore try not to do them if we don't need to - lazy val tempFileCatalog = { + // the operations below are expensive therefore try not to do them if we don't need to, e.g., + // in streaming mode, we have already inferred and registered partition columns, we will + // never have to materialize the lazy val below + lazy val tempFileIndex = { val allPaths = caseInsensitiveOptions.get("path") ++ paths val hadoopConf = sparkSession.sessionState.newHadoopConf() val globbedPaths = allPaths.toSeq.flatMap { path => @@ -133,7 +135,7 @@ case class DataSource( val partitionSchema = if (partitionColumns.isEmpty && catalogTable.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - val resolved = tempFileCatalog.partitionSchema.map { partitionField => + val resolved = tempFileIndex.partitionSchema.map { partitionField => val equality = sparkSession.sessionState.conf.resolver // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( @@ -141,17 +143,17 @@ case class DataSource( } StructType(resolved) } else { - // in streaming mode, we have already inferred and registered partition columns, we will - // never have to materialize the lazy val below - lazy val inferredPartitions = tempFileCatalog.partitionSchema // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred // partitioning if (userSpecifiedSchema.isEmpty) { + val inferredPartitions = tempFileIndex.partitionSchema inferredPartitions } else { val partitionFields = partitionColumns.map { partitionColumn => - userSpecifiedSchema.flatMap(_.find(_.name == partitionColumn)).orElse { - val inferredOpt = inferredPartitions.find(_.name == partitionColumn) + val equality = sparkSession.sessionState.conf.resolver + userSpecifiedSchema.flatMap(_.find(c => equality(c.name, partitionColumn))).orElse { + val inferredPartitions = tempFileIndex.partitionSchema + val inferredOpt = inferredPartitions.find(p => equality(p.name, partitionColumn)) if (inferredOpt.isDefined) { logDebug( s"""Type of partition column: $partitionColumn not found in specified schema @@ -163,7 +165,7 @@ case class DataSource( |Falling back to inferred dataType if it exists. """.stripMargin) } - inferredPartitions.find(_.name == partitionColumn) + inferredOpt }.getOrElse { throw new AnalysisException(s"Failed to resolve the schema for $format for " + s"the partition column: $partitionColumn. 
It must be specified manually.") @@ -182,7 +184,7 @@ case class DataSource( format.inferSchema( sparkSession, caseInsensitiveOptions, - tempFileCatalog.allFiles()) + tempFileIndex.allFiles()) }.getOrElse { throw new AnalysisException( s"Unable to infer schema for $format. It must be specified manually.") @@ -224,8 +226,11 @@ case class DataSource( "you may be able to create a static DataFrame on that directory with " + "'spark.read.load(directory)' and infer schema from it.") } - val (schema, partCols) = getOrInferFileFormatSchema(format) - SourceInfo(s"FileSource[$path]", StructType(schema ++ partCols), partCols.fieldNames) + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) + SourceInfo( + s"FileSource[$path]", + StructType(dataSchema ++ partitionSchema), + partitionSchema.fieldNames) case _ => throw new UnsupportedOperationException( @@ -379,7 +384,7 @@ case class DataSource( globPath }.toArray - val (dataSchema, inferredPartitionSchema) = getOrInferFileFormatSchema(format) + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { @@ -388,12 +393,12 @@ case class DataSource( catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) } else { - new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(inferredPartitionSchema)) + new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(partitionSchema)) } HadoopFsRelation( fileCatalog, - partitionSchema = inferredPartitionSchema, + partitionSchema = partitionSchema, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, From 04ec74f1274a164b2f72b31e2c147e042bf41bd9 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 24 Nov 2016 05:46:05 -0800 Subject: [PATCH 187/534] [SPARK-18520][ML] Add missing setXXXCol methods for BisectingKMeansModel and GaussianMixtureModel ## What changes were proposed in this pull request? add `setFeaturesCol` and `setPredictionCol` for BiKModel and GMModel add `setProbabilityCol` for GMModel ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15957 from zhengruifeng/bikm_set. 
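
A small usage sketch for the new model setters (assuming a `spark` session as in `spark-shell`; the toy data and column names are only for illustration):

```scala
import org.apache.spark.ml.clustering.{BisectingKMeans, GaussianMixture}
import org.apache.spark.ml.linalg.Vectors
import spark.implicits._

val data = Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)
).map(Tuple1.apply).toDF("feats")

// BisectingKMeansModel can now re-point its input/output columns after fitting.
val bkmModel = new BisectingKMeans().setK(2).setFeaturesCol("feats").fit(data)
  .setFeaturesCol("feats")
  .setPredictionCol("bkmPrediction")
bkmModel.transform(data).show()

// GaussianMixtureModel additionally gains setProbabilityCol.
val gmModel = new GaussianMixture().setK(2).setFeaturesCol("feats").fit(data)
  .setFeaturesCol("feats")
  .setPredictionCol("gmPrediction")
  .setProbabilityCol("gmProbability")
gmModel.transform(data).show()
```
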
(cherry picked from commit 2dfabec38c24174e7f747c27c7144f7738483ec1) Signed-off-by: Yanbo Liang --- .../apache/spark/ml/clustering/BisectingKMeans.scala | 8 ++++++++ .../apache/spark/ml/clustering/GaussianMixture.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index e6ca3aedffd9d..cf11ba37abb58 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -98,6 +98,14 @@ class BisectingKMeansModel private[ml] ( copied.setSummary(trainingSummary).setParent(this.parent) } + /** @group setParam */ + @Since("2.1.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.1.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 92d0b7d085f12..19998ca44b115 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -87,6 +87,18 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") val gaussians: Array[MultivariateGaussian]) extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + /** @group setParam */ + @Since("2.1.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.1.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.1.0") + def setProbabilityCol(value: String): this.type = set(probabilityCol, value) + @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) From a7f414561325a7140557562d45fecc5ccbc8d7ff Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Thu, 24 Nov 2016 12:07:55 -0800 Subject: [PATCH 188/534] [SPARK-18578][SQL] Full outer join in correlated subquery returns incorrect results ## What changes were proposed in this pull request? - Raise Analysis exception when correlated predicates exist in the descendant operators of either operand of a Full outer join in a subquery as well as in a FOJ operator itself - Raise Analysis exception when correlated predicates exists in a Window operator (a side effect inadvertently introduced by SPARK-17348) ## How was this patch tested? Run sql/test catalyst/test and new test cases, added to SubquerySuite, showing the reported incorrect results. Author: Nattavut Sutyanyong Closes #16005 from nsyca/FOJ-incorrect.1. 
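
As a concrete illustration of the scalar-subquery case (assuming a `spark` session as in `spark-shell`; the tiny views mirror the new test cases added below):

```scala
import org.apache.spark.sql.AnalysisException
import spark.implicits._

Seq(1).toDF("c1").createOrReplaceTempView("t1")
Seq(2).toDF("c1").createOrReplaceTempView("t2")
Seq(1).toDF("c1").createOrReplaceTempView("t3")

try {
  // The correlated predicate t1.c1 = 2 sits below a FULL OUTER JOIN inside the
  // scalar subquery, so analysis now rejects the query instead of returning
  // incorrect results.
  spark.sql(
    """select (select max(1)
      |        from (select c1 from t2 where t1.c1 = 2 and t1.c1 = t2.c1) t2
      |             full join t3
      |             on t2.c1 = t3.c1)
      |from t1""".stripMargin).collect()
} catch {
  case e: AnalysisException => println(s"Rejected during analysis: ${e.getMessage}")
}
```
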
(cherry picked from commit a367d5ff005884322fb8bb43a1cfa4d4bf54b31a) Signed-off-by: Herman van Hovell --- .../sql/catalyst/analysis/Analyzer.scala | 10 +++++ .../org/apache/spark/sql/SubquerySuite.scala | 45 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2918e9d158829..2d272762b384f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1017,6 +1017,10 @@ class Analyzer( // Simplify the predicates before pulling them out. val transformed = BooleanSimplification(sub) transformUp { + // WARNING: + // Only Filter can host correlated expressions at this time + // Anyone adding a new "case" below needs to add the call to + // "failOnOuterReference" to disallow correlated expressions in it. case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) @@ -1057,12 +1061,18 @@ class Analyzer( a } case w : Window => + failOnOuterReference(w) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w) w case j @ Join(left, _, RightOuter, _) => failOnOuterReference(j) failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") j + // SPARK-18578: Do not allow any correlated predicate + // in a Full (Outer) Join operator and its descendants + case j @ Join(_, _, FullOuter, _) => + failOnOuterReferenceInSubTree(j, "a FULL OUTER JOIN") + j case j @ Join(_, right, jt, _) if !jt.isInstanceOf[InnerLike] => failOnOuterReference(j) failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index f1dd1c620e660..73a53944964fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -744,4 +744,49 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } } } + // This restriction applies to + // the permutation of { LOJ, ROJ, FOJ } x { EXISTS, IN, scalar subquery } + // where correlated predicates appears in right operand of LOJ, + // or in left operand of ROJ, or in either operand of FOJ. 
+ // The test cases below cover the representatives of the patterns + test("Correlated subqueries in outer joins") { + withTempView("t1", "t2", "t3") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + Seq(1).toDF("c1").createOrReplaceTempView("t3") + + // Left outer join (LOJ) in IN subquery context + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where 1 IN (select 1 + | from t3 left outer join + | (select c1 from t2 where t1.c1 = 2) t2 + | on t2.c1 = t3.c1)""".stripMargin).collect() + } + // Right outer join (ROJ) in EXISTS subquery context + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where exists (select 1 + | from (select c1 from t2 where t1.c1 = 2) t2 + | right outer join t3 + | on t2.c1 = t3.c1)""".stripMargin).collect() + } + // SPARK-18578: Full outer join (FOJ) in scalar subquery context + intercept[AnalysisException] { + sql( + """ + | select (select max(1) + | from (select c1 from t2 where t1.c1 = 2 and t1.c1=t2.c1) t2 + | full join t3 + | on t2.c1=t3.c1) + | from t1""".stripMargin).collect() + } + } + } } From 57dbc682dfafc87076dcaafd29c637cb16ace91a Mon Sep 17 00:00:00 2001 From: uncleGen Date: Fri, 25 Nov 2016 09:10:17 +0000 Subject: [PATCH 189/534] [SPARK-18575][WEB] Keep same style: adjust the position of driver log links ## What changes were proposed in this pull request? NOT BUG, just adjust the position of driver log link to keep the same style with other executors log link. ![image](https://cloud.githubusercontent.com/assets/7402327/20590092/f8bddbb8-b25b-11e6-9aaf-3b5b3073df10.png) ## How was this patch tested? no Author: uncleGen Closes #16001 from uncleGen/SPARK-18575. (cherry picked from commit f58a8aa20106ea36386db79a8a66f529a8da75c9) Signed-off-by: Sean Owen --- .../spark/scheduler/cluster/YarnClusterSchedulerBackend.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index ced597bed36d9..4f3d5ebf403e0 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -55,8 +55,8 @@ private[spark] class YarnClusterSchedulerBackend( val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" logDebug(s"Base URL for logs: $baseUrl") driverLogs = Some(Map( - "stderr" -> s"$baseUrl/stderr?start=-4096", - "stdout" -> s"$baseUrl/stdout?start=-4096")) + "stdout" -> s"$baseUrl/stdout?start=-4096", + "stderr" -> s"$baseUrl/stderr?start=-4096")) } catch { case e: Exception => logInfo("Error while building AM log links, so AM" + From a49dfa93e160d63e806f35cb6b6953367916f44b Mon Sep 17 00:00:00 2001 From: "n.fraison" Date: Fri, 25 Nov 2016 09:45:51 +0000 Subject: [PATCH 190/534] [SPARK-18119][SPARK-CORE] Namenode safemode check is only performed on one namenode which can stuck the startup of SparkHistory server ## What changes were proposed in this pull request? Instead of using the setSafeMode method that check the first namenode used the one which permitts to check only for active NNs ## How was this patch tested? manual tests Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. This commit is contributed by Criteo SA under the Apache v2 licence. 
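
The change reduces to the overload shown in the sketch below, assuming the caller already holds a `DistributedFileSystem` handle (the helper mirrors the private method touched by the diff):

```scala
import org.apache.hadoop.hdfs.DistributedFileSystem
import org.apache.hadoop.hdfs.protocol.HdfsConstants

// Passing `true` as the second argument asks HDFS to consult only active namenodes,
// so the probe no longer goes to whichever namenode is listed first (possibly a
// standby), which could hang the history server startup.
def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = {
  dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET, true)
}
```
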
Author: n.fraison Closes #15648 from ashangit/SPARK-18119. (cherry picked from commit f42db0c0c1434bfcccaa70d0db55e16c4396af04) Signed-off-by: Sean Owen --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index ca38a47639422..8ef69b142cd15 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -663,9 +663,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) false } - // For testing. private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) + /* true to check only for Active NNs status */ + dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET, true) } /** From 69856f28361022812d2af83128d8591694bcef4b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 25 Nov 2016 11:27:07 +0000 Subject: [PATCH 191/534] [SPARK-3359][BUILD][DOCS] More changes to resolve javadoc 8 errors that will help unidoc/genjavadoc compatibility ## What changes were proposed in this pull request? This PR only tries to fix things that looks pretty straightforward and were fixed in other previous PRs before. This PR roughly fixes several things as below: - Fix unrecognisable class and method links in javadoc by changing it from `[[..]]` to `` `...` `` ``` [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/DataStreamReader.java:226: error: reference not found [error] * Loads text files and returns a {link DataFrame} whose schema starts with a string column named ``` - Fix an exception annotation and remove code backticks in `throws` annotation Currently, sbt unidoc with Java 8 complains as below: ``` [error] .../java/org/apache/spark/sql/streaming/StreamingQuery.java:72: error: unexpected text [error] * throws StreamingQueryException, if this query has terminated with an exception. ``` `throws` should specify the correct class name from `StreamingQueryException,` to `StreamingQueryException` without backticks. (see [JDK-8007644](https://bugs.openjdk.java.net/browse/JDK-8007644)). - Fix `[[http..]]` to ``. ```diff - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. + * + * Oracle blog page. ``` `[[http...]]` link markdown in scaladoc is unrecognisable in javadoc. - It seems class can't have `return` annotation. So, two cases of this were removed. ``` [error] .../java/org/apache/spark/mllib/regression/IsotonicRegression.java:27: error: invalid use of return [error] * return New instance of IsotonicRegression. ``` - Fix < to `<` and > to `>` according to HTML rules. - Fix `
</p>
    ` complaint - Exclude unrecognisable in javadoc, `constructor`, `todo` and `groupname`. ## How was this patch tested? Manually tested by `jekyll build` with Java 7 and 8 ``` java version "1.7.0_80" Java(TM) SE Runtime Environment (build 1.7.0_80-b15) Java HotSpot(TM) 64-Bit Server VM (build 24.80-b11, mixed mode) ``` ``` java version "1.8.0_45" Java(TM) SE Runtime Environment (build 1.8.0_45-b14) Java HotSpot(TM) 64-Bit Server VM (build 25.45-b02, mixed mode) ``` Note: this does not yet make sbt unidoc suceed with Java 8 yet but it reduces the number of errors with Java 8. Author: hyukjinkwon Closes #15999 from HyukjinKwon/SPARK-3359-errors. (cherry picked from commit 51b1c1551d3a7147403b9e821fcc7c8f57b4824c) Signed-off-by: Sean Owen --- .../scala/org/apache/spark/SSLOptions.scala | 4 +- .../apache/spark/api/java/JavaPairRDD.scala | 6 +- .../org/apache/spark/api/java/JavaRDD.scala | 10 +-- .../spark/api/java/JavaSparkContext.scala | 14 ++-- .../apache/spark/io/CompressionCodec.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 18 ++--- .../spark/security/CryptoStreamUtils.scala | 4 +- .../spark/serializer/KryoSerializer.scala | 3 +- .../storage/BlockReplicationPolicy.scala | 7 +- .../scala/org/apache/spark/ui/UIUtils.scala | 4 +- .../org/apache/spark/util/AccumulatorV2.scala | 2 +- .../org/apache/spark/util/RpcUtils.scala | 2 +- .../org/apache/spark/util/StatCounter.scala | 4 +- .../org/apache/spark/util/ThreadUtils.scala | 6 +- .../scala/org/apache/spark/util/Utils.scala | 10 +-- .../spark/util/io/ChunkedByteBuffer.scala | 2 +- .../scala/org/apache/spark/graphx/Graph.scala | 4 +- .../org/apache/spark/graphx/GraphLoader.scala | 2 +- .../spark/graphx/impl/EdgeRDDImpl.scala | 2 +- .../apache/spark/graphx/lib/PageRank.scala | 4 +- .../apache/spark/graphx/lib/SVDPlusPlus.scala | 3 +- .../spark/graphx/lib/TriangleCount.scala | 2 +- .../distribution/MultivariateGaussian.scala | 3 +- .../scala/org/apache/spark/ml/Predictor.scala | 2 +- .../spark/ml/attribute/AttributeGroup.scala | 2 +- .../spark/ml/attribute/attributes.scala | 4 +- .../classification/LogisticRegression.scala | 74 +++++++++---------- .../MultilayerPerceptronClassifier.scala | 1 - .../spark/ml/classification/NaiveBayes.scala | 8 +- .../RandomForestClassifier.scala | 6 +- .../spark/ml/clustering/BisectingKMeans.scala | 14 ++-- .../ml/clustering/ClusteringSummary.scala | 2 +- .../spark/ml/clustering/GaussianMixture.scala | 6 +- .../apache/spark/ml/clustering/KMeans.scala | 8 +- .../org/apache/spark/ml/clustering/LDA.scala | 42 +++++------ .../org/apache/spark/ml/feature/DCT.scala | 3 +- .../org/apache/spark/ml/feature/MinHash.scala | 5 +- .../spark/ml/feature/MinMaxScaler.scala | 4 +- .../ml/feature/PolynomialExpansion.scala | 14 ++-- .../spark/ml/feature/RandomProjection.scala | 4 +- .../spark/ml/feature/StandardScaler.scala | 4 +- .../spark/ml/feature/StopWordsRemover.scala | 5 +- .../org/apache/spark/ml/feature/package.scala | 3 +- .../IterativelyReweightedLeastSquares.scala | 7 +- .../spark/ml/param/shared/sharedParams.scala | 12 +-- .../ml/regression/AFTSurvivalRegression.scala | 27 +++---- .../ml/regression/DecisionTreeRegressor.scala | 4 +- .../spark/ml/regression/GBTRegressor.scala | 4 +- .../GeneralizedLinearRegression.scala | 12 +-- .../ml/regression/LinearRegression.scala | 38 +++++----- .../ml/regression/RandomForestRegressor.scala | 5 +- .../ml/source/libsvm/LibSVMDataSource.scala | 13 ++-- .../ml/tree/impl/GradientBoostedTrees.scala | 10 +-- .../spark/ml/tree/impl/RandomForest.scala | 2 +- 
.../org/apache/spark/ml/tree/treeParams.scala | 6 +- .../spark/ml/tuning/CrossValidator.scala | 4 +- .../org/apache/spark/ml/util/ReadWrite.scala | 10 +-- .../mllib/classification/NaiveBayes.scala | 28 +++---- .../mllib/clustering/BisectingKMeans.scala | 21 +++--- .../clustering/BisectingKMeansModel.scala | 4 +- .../mllib/clustering/GaussianMixture.scala | 6 +- .../clustering/GaussianMixtureModel.scala | 2 +- .../apache/spark/mllib/clustering/LDA.scala | 24 +++--- .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 2 +- .../clustering/PowerIterationClustering.scala | 13 ++-- .../mllib/clustering/StreamingKMeans.scala | 4 +- .../mllib/evaluation/RegressionMetrics.scala | 10 ++- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 12 +-- .../apache/spark/mllib/fpm/PrefixSpan.scala | 7 +- .../linalg/distributed/BlockMatrix.scala | 20 ++--- .../linalg/distributed/CoordinateMatrix.scala | 4 +- .../linalg/distributed/IndexedRowMatrix.scala | 4 +- .../mllib/linalg/distributed/RowMatrix.scala | 2 +- .../spark/mllib/optimization/Gradient.scala | 24 +++--- .../mllib/optimization/GradientDescent.scala | 4 +- .../spark/mllib/optimization/LBFGS.scala | 7 +- .../spark/mllib/optimization/NNLS.scala | 2 +- .../spark/mllib/optimization/Updater.scala | 6 +- .../org/apache/spark/mllib/package.scala | 4 +- .../apache/spark/mllib/rdd/RDDFunctions.scala | 2 +- .../spark/mllib/recommendation/ALS.scala | 7 +- .../MatrixFactorizationModel.scala | 6 +- .../mllib/regression/IsotonicRegression.scala | 9 +-- .../stat/MultivariateOnlineSummarizer.scala | 7 +- .../apache/spark/mllib/stat/Statistics.scala | 11 +-- .../distribution/MultivariateGaussian.scala | 3 +- .../mllib/tree/GradientBoostedTrees.scala | 2 +- .../spark/mllib/tree/RandomForest.scala | 8 +- .../apache/spark/mllib/tree/model/Split.scala | 2 +- .../org/apache/spark/mllib/util/MLUtils.scala | 10 +-- .../spark/mllib/util/modelSaveLoad.scala | 2 +- pom.xml | 12 +++ project/SparkBuild.scala | 5 +- .../main/scala/org/apache/spark/sql/Row.scala | 2 +- .../aggregate/CentralMomentAgg.scala | 4 +- .../apache/spark/sql/types/BinaryType.scala | 2 +- .../apache/spark/sql/types/BooleanType.scala | 2 +- .../org/apache/spark/sql/types/ByteType.scala | 2 +- .../sql/types/CalendarIntervalType.scala | 2 +- .../org/apache/spark/sql/types/DateType.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 4 +- .../apache/spark/sql/types/DoubleType.scala | 2 +- .../apache/spark/sql/types/FloatType.scala | 2 +- .../apache/spark/sql/types/IntegerType.scala | 2 +- .../org/apache/spark/sql/types/LongType.scala | 2 +- .../org/apache/spark/sql/types/MapType.scala | 2 +- .../org/apache/spark/sql/types/NullType.scala | 2 +- .../apache/spark/sql/types/ShortType.scala | 2 +- .../apache/spark/sql/types/StringType.scala | 2 +- .../spark/sql/types/TimestampType.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 17 +++-- .../spark/sql/DataFrameStatFunctions.scala | 16 ++-- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 62 ++++++++-------- .../sql/execution/stat/FrequentItems.scala | 3 +- .../sql/execution/stat/StatFunctions.scala | 4 +- .../spark/sql/expressions/Aggregator.scala | 8 +- .../sql/expressions/UserDefinedFunction.scala | 2 +- .../apache/spark/sql/expressions/Window.scala | 16 ++-- .../spark/sql/expressions/WindowSpec.scala | 16 ++-- .../sql/expressions/scalalang/typed.scala | 2 +- .../apache/spark/sql/expressions/udaf.scala | 24 +++--- .../apache/spark/sql/jdbc/JdbcDialects.scala | 6 
+- .../sql/streaming/DataStreamReader.scala | 20 ++--- .../sql/streaming/DataStreamWriter.scala | 8 +- .../spark/sql/streaming/StreamingQuery.scala | 10 ++- .../sql/streaming/StreamingQueryManager.scala | 8 +- .../sql/util/QueryExecutionListener.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 4 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 4 +- .../spark/sql/hive/orc/OrcFileOperator.scala | 2 +- 132 files changed, 558 insertions(+), 499 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index be19179b00a49..5f14102c3c366 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -150,8 +150,8 @@ private[spark] object SSLOptions extends Logging { * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers * * For a list of protocols and ciphers supported by particular Java versions, you may go to - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. + * + * Oracle blog page. * * You can optionally specify the default configuration. If you do, for each setting which is * missing in SparkConf, the corresponding setting is used from the default configuration. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index bff5a29bb60f1..d7e3a1b1be48c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -405,7 +405,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * partitioning of the resulting key-value pair RDD by passing a Partitioner. * * @note If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JIterable[V]] = @@ -416,7 +416,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. * * @note If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(numPartitions: Int): JavaPairRDD[K, JIterable[V]] = @@ -546,7 +546,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * resulting RDD with the existing partitioner/parallelism level. * * @note If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. 
*/ def groupByKey(): JavaPairRDD[K, JIterable[V]] = diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index ccd94f876e0b8..a20d264be5afd 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -103,10 +103,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be >= 0 * * @note This is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) @@ -117,11 +117,11 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator * * @note This is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -167,7 +167,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 38d347aeab8c6..9481156bc93a5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -238,7 +238,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}} * * then `rdd` contains * {{{ @@ -270,7 +272,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}}, * * then `rdd` contains * {{{ @@ -749,7 +753,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get a local property set in this thread, or null if it is missing. See - * [[org.apache.spark.api.java.JavaSparkContext.setLocalProperty]]. + * `org.apache.spark.api.java.JavaSparkContext.setLocalProperty`. 
*/ def getLocalProperty(key: String): String = sc.getLocalProperty(key) @@ -769,7 +773,7 @@ class JavaSparkContext(val sc: SparkContext) * Application programmers can use this method to group all those jobs together and give a * group description. Once set, the Spark web UI will associate such jobs with this group. * - * The application can also use [[org.apache.spark.api.java.JavaSparkContext.cancelJobGroup]] + * The application can also use `org.apache.spark.api.java.JavaSparkContext.cancelJobGroup` * to cancel all running jobs in this group. For example, * {{{ * // In the main thread: @@ -802,7 +806,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Cancel active jobs for the specified group. See - * [[org.apache.spark.api.java.JavaSparkContext.setJobGroup]] for more information. + * `org.apache.spark.api.java.JavaSparkContext.setJobGroup` for more information. */ def cancelJobGroup(groupId: String): Unit = sc.cancelJobGroup(groupId) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 6ba79e506a648..2e991ce394c42 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -172,7 +172,7 @@ private final object SnappyCompressionCodec { } /** - * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * Wrapper over `SnappyOutputStream` which guards against write-after-close and double-close * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index bff2b8f1d06c9..8e673447581cf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -70,8 +70,8 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://people.csail.mit.edu/matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details - * on RDD internals. + * Spark paper + * for more details on RDD internals. */ abstract class RDD[T: ClassTag]( @transient private var _sc: SparkContext, @@ -469,7 +469,7 @@ abstract class RDD[T: ClassTag]( * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator * * @note This is NOT guaranteed to provide exactly the fraction of the count @@ -675,8 +675,8 @@ abstract class RDD[T: ClassTag]( * may even differ each time the resulting RDD is evaluated. * * @note This operation may be very expensive. 
If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy[K](f, defaultPartitioner(this)) @@ -688,8 +688,8 @@ abstract class RDD[T: ClassTag]( * may even differ each time the resulting RDD is evaluated. * * @note This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K]( f: T => K, @@ -703,8 +703,8 @@ abstract class RDD[T: ClassTag]( * may even differ each time the resulting RDD is evaluated. * * @note This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) : RDD[(K, Iterable[T])] = withScope { diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 8f15f50bee814..f41fc38be2080 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -46,7 +46,7 @@ private[spark] object CryptoStreamUtils extends Logging { val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." /** - * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption. + * Helper method to wrap `OutputStream` with `CryptoOutputStream` for encryption. */ def createCryptoOutputStream( os: OutputStream, @@ -62,7 +62,7 @@ private[spark] object CryptoStreamUtils extends Logging { } /** - * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption. + * Helper method to wrap `InputStream` with `CryptoInputStream` for decryption. */ def createCryptoInputStream( is: InputStream, diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 19e020c968a9a..7eb2da1c2748c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -43,7 +43,8 @@ import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, S import org.apache.spark.util.collection.CompactBuffer /** - * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. + * A Spark serializer that uses the + * Kryo serialization library. 
* * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bf087af16a5b1..bb8a684b4c7a8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -89,17 +89,18 @@ class RandomBlockReplicationPolicy prioritizedPeers } + // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage - * [[http://math.stackexchange.com/questions/178690/ - * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]] + * minimizing space usage. Please see + * here. * * @param n total number of indices * @param m number of samples needed * @param r random number generator * @return list of m random unique indices */ + // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 57f6f2f0a9be5..dbeb970c81dfe 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -422,8 +422,8 @@ private[spark] object UIUtils extends Logging { * the whole string will rendered as a simple escaped text. * * Note: In terms of security, only anchor tags with root relative links are supported. So any - * attempts to embed links outside Spark UI, or other tags like