diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 0e41cf1826453..5af45d6fa7988 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -7,4 +7,4 @@ (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) -Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. +Please review http://spark.apache.org/contributing.html before opening a pull request. diff --git a/.gitignore b/.gitignore index 39d17e1793f77..5634a434db0c0 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,8 @@ project/plugins/project/build.properties project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip +python/deps +python/pyspark/python reports/ scalastyle-on-compile.generated.xml scalastyle-output.xml diff --git a/.travis.yml b/.travis.yml index 8739849a20798..b9ae28a421309 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,11 +25,22 @@ sudo: required dist: trusty -# 2. Choose language and target JDKs for parallel builds. +# 2. Choose language, target JDK and env's for parallel builds. language: java jdk: - - oraclejdk7 - oraclejdk8 +env: # Used by the install section below. + # Configure the unit test build for spark core and kubernetes modules, + # while excluding some flaky unit tests using a regex pattern. + - PHASE=test \ + PROFILES="-Pmesos -Pyarn -Phadoop-2.7 -Pkubernetes" \ + MODULES="-pl core,resource-managers/kubernetes/core -am" \ + ARGS="-Dtest=none -Dsuffixes='^org\.apache\.spark\.(?!ExternalShuffleServiceSuite|SortShuffleSuite$|rdd\.LocalCheckpointSuite$|deploy\.SparkSubmitSuite$|deploy\.StandaloneDynamicAllocationSuite$).*'" + # Configure the full build. + - PHASE=install \ + PROFILES="-Pmesos -Pyarn -Phadoop-2.7 -Pkubernetes -Pkinesis-asl -Phive -Phive-thriftserver" \ + MODULES="" \ + ARGS="-T 4 -q -DskipTests" # 3. Setup cache directory for SBT and Maven. cache: @@ -41,11 +52,12 @@ cache: notifications: email: false -# 5. Run maven install before running lint-java. +# 5. Run maven build before running lints. install: - export MAVEN_SKIP_RC=1 - - build/mvn -T 4 -q -DskipTests -Pmesos -Pyarn -Phadoop-2.3 -Pkinesis-asl -Phive -Phive-thriftserver install + - build/mvn ${PHASE} ${PROFILES} ${MODULES} ${ARGS} -# 6. Run lint-java. +# 6. Run lints. script: - dev/lint-java + - dev/lint-scala diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1a8206abe3838..8fdd5aa9e7dfb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,12 @@ ## Contributing to Spark *Before opening a pull request*, review the -[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +[Contributing to Spark guide](http://spark.apache.org/contributing.html). It lists steps that are required before creating a PR. In particular, consider: - Is the change important and ready enough to ask the community to spend time reviewing? - Have you searched for existing, related JIRAs and pull requests? -- Is this a new feature that can stand alone as a [third party project](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) ? +- Is this a new feature that can stand alone as a [third party project](http://spark.apache.org/third-party-projects.html) ? - Is the change being proposed clearly explained and motivated? 
When you contribute code, you affirm that the contribution is your original work and that you diff --git a/NOTICE b/NOTICE index 69b513ea3ba3c..f4b64b5c3f470 100644 --- a/NOTICE +++ b/NOTICE @@ -421,9 +421,6 @@ Copyright (c) 2011, Terrence Parr. This product includes/uses ASM (http://asm.ow2.org/), Copyright (c) 2000-2007 INRIA, France Telecom. -This product includes/uses org.json (http://www.json.org/java/index.html), -Copyright (c) 2002 JSON.org - This product includes/uses JLine (http://jline.sourceforge.net/), Copyright (c) 2002-2006, Marc Prud'hommeaux . diff --git a/R/CRAN_RELEASE.md b/R/CRAN_RELEASE.md new file mode 100644 index 0000000000000..d6084c7a7cc90 --- /dev/null +++ b/R/CRAN_RELEASE.md @@ -0,0 +1,91 @@ +# SparkR CRAN Release + +To release SparkR as a package to CRAN, we use the `devtools` package. Please work with the +`dev@spark.apache.org` community and the R package maintainer on this. + +### Release + +First, check that the `Version:` field in the `pkg/DESCRIPTION` file is updated. Also, check for stale files not under source control. + +Note that while `run-tests.sh` runs `check-cran.sh` (which runs `R CMD check`), it does so with `--no-manual --no-vignettes`, which skips a few vignette and PDF checks - it is therefore preferable to run `R CMD check` on the manually built source package before uploading a release. Also note that for the CRAN check of PDF vignettes to succeed, the `qpdf` tool must be installed (to install it, e.g. `yum -q -y install qpdf`). + +To upload a release, we need to update `cran-comments.md`. This should generally contain the results from running the `check-cran.sh` script along with comments on the status of any `WARNING` (there should not be any) or `NOTE` items. As part of `check-cran.sh` and the release process, the vignettes are built - make sure `SPARK_HOME` is set and the Spark jars are accessible. + +Once everything is in place, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::release(); .libPaths(paths) +``` + +For more information, please refer to http://r-pkgs.had.co.nz/release.html#release-check + +### Testing: build package manually + +To build the package manually, for example to inspect the resulting `.tar.gz` file content, we also use the `devtools` package. + +The source package is what gets released to CRAN; CRAN then builds platform-specific binary packages from it. + +#### Build source package + +To build the source package locally without releasing to CRAN, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg"); .libPaths(paths) +``` + +(http://r-pkgs.had.co.nz/vignettes.html#vignette-workflow-2) + +Similarly, the source package is also created by `check-cran.sh` with `R CMD build pkg`. + +For example, this should be the content of the source package: + +```sh +DESCRIPTION R inst tests +NAMESPACE build man vignettes + +inst/doc/ +sparkr-vignettes.html +sparkr-vignettes.Rmd +sparkr-vignettes.Rman + +build/ +vignette.rds + +man/ + *.Rd files... + +vignettes/ +sparkr-vignettes.Rmd +``` + +#### Test source package + +To install, run this: + +```sh +R CMD INSTALL SparkR_2.1.0.tar.gz +``` + +Replace "2.1.0" with the actual SparkR version. + +This command installs SparkR into the default libPaths.
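(Illustrative sketch, not part of the patch: one way to smoke-test the built tarball from a throwaway library path so the default library is left untouched. The tarball name, scratch directory, and a built Spark tree under `SPARK_HOME` are assumptions.)

```R
# Install the freshly built source package into a scratch library and load it from there.
# Assumes SparkR_2.1.0.tar.gz is in the current directory and SPARK_HOME points to a built Spark tree.
scratchLib <- file.path(tempdir(), "sparkr-test-lib")
dir.create(scratchLib, recursive = TRUE, showWarnings = FALSE)
install.packages("SparkR_2.1.0.tar.gz", repos = NULL, type = "source", lib = scratchLib)
library(SparkR, lib.loc = scratchLib)
sparkR.session()       # start a local Spark session to confirm the package works end to end
sparkR.session.stop()
```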
Once that is done, you should be able to start R and run: + +```R +library(SparkR) +vignette("sparkr-vignettes", package="SparkR") +``` + +#### Build binary package + +To build binary package locally, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg", binary = TRUE); .libPaths(paths) +``` + +For example, this should be the content of the binary package: + +```sh +DESCRIPTION Meta R html tests +INDEX NAMESPACE help profile worker +``` diff --git a/R/README.md b/R/README.md index 932d5272d0b4f..4c40c5963db70 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` the full path of the base directory where R is installed, before running install-dev.sh script. -Example: +Example: ```bash # where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript export R_HOME=/home/username/R @@ -46,19 +46,19 @@ Sys.setenv(SPARK_HOME="/Users/username/spark") # This line loads SparkR from the installed directory .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) library(SparkR) -sc <- sparkR.init(master="local") +sparkR.session() ``` #### Making changes to SparkR -The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +The [instructions](http://spark.apache.org/contributing.html) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. - + #### Generating documentation The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. Also, you may need to install these [prerequisites](https://github.com/apache/spark/tree/master/docs#prerequisites). See also, `R/DOCUMENTATION.md` - + ### Examples, Unit tests SparkR comes with several sample programs in the `examples/src/main/r` directory. diff --git a/R/check-cran.sh b/R/check-cran.sh index bb331466ae931..1288e7fc9fb4c 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -34,13 +34,30 @@ if [ ! -z "$R_HOME" ] fi R_SCRIPT_PATH="$(dirname $(which R))" fi -echo "USING R_HOME = $R_HOME" +echo "Using R_SCRIPT_PATH = ${R_SCRIPT_PATH}" -# Build the latest docs +# Install the package (this is required for code in vignettes to run when building it later) +# Build the latest docs, but not vignettes, which is built with the package next $FWDIR/create-docs.sh -# Build a zip file containing the source package -"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg +# Build source package with vignettes +SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" +. 
"${SPARK_HOME}"/bin/load-spark-env.sh +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ -d "$SPARK_JARS_DIR" ]; then + # Build a zip file containing the source package with vignettes + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Error Spark JARs not found in $SPARK_HOME" + exit 1 +fi # Run check as-cran. VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` @@ -54,11 +71,32 @@ fi if [ -n "$NO_MANUAL" ] then - CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual" + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual --no-vignettes" fi echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" -"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] +then + "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +else + # This will run tests and/or build vignettes, and require SPARK_HOME + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +fi + +# Install source package to get it to generate vignettes rds files, etc. +if [ -n "$CLEAN_INSTALL" ] +then + echo "Removing lib path and installing from source package" + LIB_DIR="$FWDIR/lib" + rm -rf $LIB_DIR + mkdir -p $LIB_DIR + "$R_SCRIPT_PATH/"R CMD INSTALL SparkR_"$VERSION".tar.gz --library=$LIB_DIR + + # Zip the SparkR package so that it can be distributed to worker nodes on YARN + pushd $LIB_DIR > /dev/null + jar cfM "$LIB_DIR/sparkr.zip" SparkR + popd > /dev/null +fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 69ffc5f678c36..84e6aa928cb0f 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -20,7 +20,7 @@ # Script to create API docs and vignettes for SparkR # This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. -# After running this script the html docs can be found in +# After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html # The vignettes can be found in # $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html @@ -52,21 +52,4 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd -# Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then - SPARK_JARS_DIR="${SPARK_HOME}/jars" -else - SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -# Only create vignettes if Spark JARs exist -if [ -d "$SPARK_JARS_DIR" ]; then - # render creates SparkR vignettes - Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' - - find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete -else - echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" -fi - popd diff --git a/R/install-dev.sh b/R/install-dev.sh index ada6303a722b7..0f881208bcadb 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -46,7 +46,7 @@ if [ ! 
-z "$R_HOME" ] fi R_SCRIPT_PATH="$(dirname $(which R))" fi -echo "USING R_HOME = $R_HOME" +echo "Using R_SCRIPT_PATH = ${R_SCRIPT_PATH}" # Generate Rd files if devtools is installed "$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore index 544d203a6dce6..f12f8c275a989 100644 --- a/R/pkg/.Rbuildignore +++ b/R/pkg/.Rbuildignore @@ -1,5 +1,8 @@ ^.*\.Rproj$ ^\.Rproj\.user$ ^\.lintr$ +^cran-comments\.md$ +^NEWS\.md$ +^README\.Rmd$ ^src-native$ ^html$ diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 5a83883089e0e..0cb3a80a6e892 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package +Version: 2.1.0 Title: R Frontend for Apache Spark -Version: 2.0.0 -Date: 2016-08-27 +Description: The SparkR package provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", @@ -10,17 +10,18 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), person("Felix", "Cheung", role = "aut", email = "felixcheung@apache.org"), person(family = "The Apache Software Foundation", role = c("aut", "cph"))) +License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ -BugReports: https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingBugReports +BugReports: http://spark.apache.org/contributing.html Depends: R (>= 3.0), methods Suggests: + knitr, + rmarkdown, testthat, e1071, survival -Description: The SparkR package provides an R frontend for Apache Spark. -License: Apache License (== 2.0) Collate: 'schema.R' 'generics.R' @@ -48,3 +49,4 @@ Collate: 'utils.R' 'window.R' RoxygenNote: 5.0.1 +VignetteBuilder: knitr diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9cd6269f9a8f7..377f9429ae5c1 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -3,7 +3,7 @@ importFrom("methods", "setGeneric", "setMethod", "setOldClass") importFrom("methods", "is", "new", "signature", "show") importFrom("stats", "gaussian", "setNames") -importFrom("utils", "download.file", "object.size", "packageVersion", "untar") +importFrom("utils", "download.file", "object.size", "packageVersion", "tail", "untar") # Disable native libraries till we figure out how to package it # See SPARKR-7839 @@ -45,7 +45,8 @@ exportMethods("glm", "spark.als", "spark.kstest", "spark.logit", - "spark.randomForest") + "spark.randomForest", + "spark.gbt") # Job group lifecycle management methods export("setJobGroup", @@ -353,7 +354,9 @@ export("as.DataFrame", "read.ml", "print.summary.KSTest", "print.summary.RandomForestRegressionModel", - "print.summary.RandomForestClassificationModel") + "print.summary.RandomForestClassificationModel", + "print.summary.GBTRegressionModel", + "print.summary.GBTClassificationModel") export("structField", "structField.jobj", @@ -380,6 +383,8 @@ S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) S3method(print, summary.RandomForestRegressionModel) S3method(print, summary.RandomForestClassificationModel) +S3method(print, summary.GBTRegressionModel) +S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1cf9b38ea6483..9a51d530f120a 100644 
--- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -936,7 +936,9 @@ setMethod("unique", #' Sample #' -#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Note: this is not guaranteed to provide exactly the fraction specified +#' of the total count of the given SparkDataFrame. #' #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not @@ -2539,7 +2541,8 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. -#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame @@ -2582,7 +2585,8 @@ setMethod("unionAll", #' Union two or more SparkDataFrames #' #' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. -#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. #' #' @param x a SparkDataFrame. #' @param ... additional SparkDataFrame(s). diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 38d83c6e5c52b..6f48cd66396ea 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -634,7 +634,7 @@ tableNames <- function(x, ...) { cacheTable.default <- function(tableName) { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "cacheTable", tableName) + invisible(callJMethod(catalog, "cacheTable", tableName)) } cacheTable <- function(x, ...) { @@ -663,7 +663,7 @@ cacheTable <- function(x, ...) { uncacheTable.default <- function(tableName) { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "uncacheTable", tableName) + invisible(callJMethod(catalog, "uncacheTable", tableName)) } uncacheTable <- function(x, ...) { @@ -686,7 +686,7 @@ uncacheTable <- function(x, ...) { clearCache.default <- function() { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "clearCache") + invisible(callJMethod(catalog, "clearCache")) } clearCache <- function() { @@ -730,6 +730,7 @@ dropTempTable <- function(x, ...) { #' If the view has been cached before, then it will also be uncached. #' #' @param viewName the name of the view to be dropped. +#' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView #' @export diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 438d77a388f0e..1138caf98ed8a 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -87,8 +87,8 @@ objectFile <- function(sc, path, minPartitions = NULL) { #' in the list are split into \code{numSlices} slices and distributed to nodes #' in the cluster. #' -#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function -#' will write it to disk and send the file name to JVM. Also to make sure each slice is not +#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function +#' will write it to disk and send the file name to JVM. Also to make sure each slice is not #' larger than that limit, number of slices may be increased.
#' #' @param sc SparkContext to use @@ -379,5 +379,5 @@ spark.lapply <- function(list, func) { #' @note setLogLevel since 2.0.0 setLogLevel <- function(level) { sc <- getSparkContext() - callJMethod(sc, "setLogLevel", level) + invisible(callJMethod(sc, "setLogLevel", level)) } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4d94b4cd05d44..bf5c96373c632 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1485,7 +1485,7 @@ setMethod("soundex", #' Return the partition ID as a column #' -#' Return the partition ID of the Spark task as a SparkDataFrame column. +#' Return the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. #' @@ -2296,7 +2296,7 @@ setMethod("n", signature(x = "Column"), #' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All #' pattern letters of \code{java.text.SimpleDateFormat} can be used. #' -#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a +#' Note: Use specialized functions like \code{year} whenever possible. These benefit from a #' specialized implementation. #' #' @param y Column to compute on. @@ -2317,7 +2317,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_utc_timestamp #' -#' Assumes given timestamp is UTC and converts to given timezone. +#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp +#' that corresponds to the same time of day in the given timezone. #' #' @param y Column to compute on. #' @param x time zone to use. #' @@ -2340,7 +2341,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' Locate the position of the first occurrence of substr column in the given string. #' Returns null if either of the arguments are null. #' -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' Note: The position is not zero based, but a 1-based index. Returns 0 if substr #' could not be found in str. #' #' @param y column to check @@ -2391,7 +2392,8 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' to_utc_timestamp #' -#' Assumes given timestamp is in given timezone and converts to UTC. +#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns +#' another timestamp that corresponds to the same time of day in UTC. #' #' @param y Column to compute on #' @param x timezone to use @@ -2539,7 +2541,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' shiftRight #' -#' Shift the given value numBits right. If the given value is a long value, it will return +#' (Signed) shift the given value numBits right. If the given value is a long value, it will return #' a long value else it will return an integer value. #' #' @param y column to compute on. @@ -2777,7 +2779,8 @@ setMethod("window", signature(x = "Column"), #' locate #' #' Locate the position of the first occurrence of substr. -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' +#' Note: The position is not zero based, but a 1-based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. @@ -2823,7 +2826,8 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' rand #' -#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. 
+#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' from U[0.0, 1.0]. #' #' @param seed a random seed. Can be missing. #' @family normal_funcs @@ -2852,7 +2856,8 @@ setMethod("rand", signature(seed = "numeric"), #' randn #' -#' Generate a column with i.i.d. samples from the standard normal distribution. +#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' the standard normal distribution. #' #' @param seed a random seed. Can be missing. #' @family normal_funcs @@ -3442,8 +3447,8 @@ setMethod("size", #' sort_array #' -#' Sorts the input array for the given column in ascending order, -#' according to the natural ordering of the array elements. +#' Sorts the input array in ascending or descending order according +#' to the natural ordering of the array elements. #' #' @param x A Column to sort #' @param asc A logical flag indicating the sorting order. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0271b26a10a90..499c7b279ea9d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1343,6 +1343,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) +#' @rdname spark.gbt +#' @export +setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) + #' @rdname spark.glm #' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) @@ -1369,7 +1373,7 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark. #' @rdname spark.mlp #' @export -setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") }) +setGeneric("spark.mlp", function(data, formula, ...) 
{ standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes #' @export diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 69b0a523b84e4..097b7ad4bea08 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -79,19 +79,28 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, dir.create(localDir, recursive = TRUE) } - packageLocalDir <- file.path(localDir, packageName) - if (overwrite) { message(paste0("Overwrite = TRUE: download and overwrite the tar file", "and Spark package directory if they exist.")) } + releaseUrl <- Sys.getenv("SPARKR_RELEASE_DOWNLOAD_URL") + if (releaseUrl != "") { + packageName <- basenameSansExtFromUrl(releaseUrl) + } + + packageLocalDir <- file.path(localDir, packageName) + # can use dir.exists(packageLocalDir) under R 3.2.0 or later if (!is.na(file.info(packageLocalDir)$isdir) && !overwrite) { - fmt <- "%s for Hadoop %s found, with SPARK_HOME set to %s" - msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), - packageLocalDir) - message(msg) + if (releaseUrl != "") { + message(paste(packageName, "found, setting SPARK_HOME to", packageLocalDir)) + } else { + fmt <- "%s for Hadoop %s found, setting SPARK_HOME to %s" + msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), + packageLocalDir) + message(msg) + } Sys.setenv(SPARK_HOME = packageLocalDir) return(invisible(packageLocalDir)) } else { @@ -104,7 +113,12 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, if (tarExists && !overwrite) { message("tar file found.") } else { - robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + if (releaseUrl != "") { + message("Downloading from alternate URL:\n- ", releaseUrl) + downloadUrl(releaseUrl, packageLocalPath, paste0("Fetch failed from ", releaseUrl)) + } else { + robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + } } message(sprintf("Installing to %s", localDir)) @@ -182,16 +196,18 @@ getPreferredMirror <- function(version, packageName) { } directDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { - packageRemotePath <- paste0( - file.path(mirrorUrl, version, packageName), ".tgz") + packageRemotePath <- paste0(file.path(mirrorUrl, version, packageName), ".tgz") fmt <- "Downloading %s for Hadoop %s from:\n- %s" msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), packageRemotePath) message(msg) + downloadUrl(packageRemotePath, packageLocalPath, paste0("Fetch failed from ", mirrorUrl)) +} - isFail <- tryCatch(download.file(packageRemotePath, packageLocalPath), +downloadUrl <- function(remotePath, localPath, errorMessage) { + isFail <- tryCatch(download.file(remotePath, localPath), error = function(e) { - message(sprintf("Fetch failed from %s", mirrorUrl)) + message(errorMessage) print(e) TRUE }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7a220b8d53a2f..d736bbb5e9113 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -116,6 +116,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) +#' S4 class that represents a GBTRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala GBTRegressionModel +#' @export +#' @note GBTRegressionModel since 2.1.0 +setClass("GBTRegressionModel", representation(jobj = "jobj")) + +#' S4 class 
that represents a GBTClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala GBTClassificationModel +#' @export +#' @note GBTClassificationModel since 2.1.0 +setClass("GBTClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. For more information, see the specific @@ -124,7 +138,8 @@ setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} @@ -138,7 +153,8 @@ NULL #' @name predict #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL @@ -175,7 +191,7 @@ predict_internal <- function(object, newData) { #' @param regParam regularization parameter for L2 regularization. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method -#' @return \code{spark.glm} returns a fitted generalized linear model +#' @return \code{spark.glm} returns a fitted generalized linear model. #' @rdname spark.glm #' @name spark.glm #' @export @@ -261,10 +277,12 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). #' @param object a fitted generalized linear model. -#' @return \code{summary} returns a summary object of the fitted model, a list of components -#' including at least the coefficients, null/residual deviance, null/residual degrees -#' of freedom, AIC and number of iterations IRLS takes. -#' +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes +#' coefficients, standard error of coefficients, t value and p value), +#' \code{null.deviance}/\code{deviance} (null/residual deviance), \code{df.null}/\code{df.residual} (null/residual degrees of freedom), \code{aic} (AIC) +#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, +#' the coefficients matrix only provides coefficients. #' @rdname spark.glm #' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 @@ -287,9 +305,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), } else { dataFrame(callJMethod(jobj, "rDevianceResiduals")) } - coefficients <- matrix(coefficients, ncol = 4) - colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") - rownames(coefficients) <- unlist(features) + # If the underlying WeightedLeastSquares uses the "normal" solver, we can provide + # coefficients, standard error of coefficients, t value and p value. Otherwise, + # it will be fitted by the local "l-bfgs" solver and we can only provide coefficients. 
+ if (length(features) == length(coefficients)) { + coefficients <- matrix(coefficients, ncol = 1) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + } else { + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + } ans <- list(deviance.resid = deviance.resid, coefficients = coefficients, dispersion = dispersion, null.deviance = null.deviance, deviance = deviance, df.null = df.null, df.residual = df.residual, @@ -301,7 +328,7 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), # Prints the summary of GeneralizedLinearRegressionModel #' @rdname spark.glm -#' @param x summary object of fitted generalized linear model returned by \code{summary} function +#' @param x summary object of fitted generalized linear model returned by \code{summary} function. #' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { @@ -334,7 +361,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named -#' "prediction" +#' "prediction". #' @rdname spark.glm #' @export #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 @@ -348,7 +375,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction" +#' "prediction". #' @rdname spark.naiveBayes #' @export #' @note predict(NaiveBayesModel) since 2.0.0 @@ -360,8 +387,9 @@ setMethod("predict", signature(object = "NaiveBayesModel"), # Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} #' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. -#' @return \code{summary} returns a list containing \code{apriori}, the label distribution, and -#' \code{tables}, conditional probabilities given the target label. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{apriori} (the label distribution) and +#' \code{tables} (conditional probabilities given the target label). #' @rdname spark.naiveBayes #' @export #' @note summary(NaiveBayesModel) since 2.0.0 @@ -382,9 +410,9 @@ setMethod("summary", signature(object = "NaiveBayesModel"), # Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda() -#' @param newData A SparkDataFrame for testing +#' @param newData A SparkDataFrame for testing. #' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities -#' vectors named "topicDistribution" +#' vectors named "topicDistribution". #' @rdname spark.lda #' @aliases spark.posterior,LDAModel,SparkDataFrame-method #' @export @@ -398,7 +426,8 @@ setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkData #' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}. #' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10. -#' @return \code{summary} returns a list containing +#' @return \code{summary} returns summary information of the fitted model, which is a list. 
+#' The list includes #' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for #' the prior placed on documents distributions over topics \code{theta}} #' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or @@ -449,7 +478,7 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr # Saves the Latent Dirichlet Allocation model to the input path. -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -468,16 +497,16 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg(). #' Users can print, make predictions on the produced model and save the model to the input path. #' -#' @param data SparkDataFrame for training +#' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or -#' antitonic/decreasing (FALSE) +#' antitonic/decreasing (FALSE). #' @param featureIndex The index of the feature if \code{featuresCol} is a vector column -#' (default: 0), no effect otherwise +#' (default: 0), no effect otherwise. #' @param weightCol The weight column name. #' @param ... additional arguments passed to the method. -#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model +#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model. #' @rdname spark.isoreg #' @aliases spark.isoreg,SparkDataFrame,formula-method #' @name spark.isoreg @@ -509,7 +538,7 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' @note spark.isoreg since 2.1.0 setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { - formula <- paste0(deparse(formula), collapse = "") + formula <- paste(deparse(formula), collapse = "") if (is.null(weightCol)) { weightCol <- "" @@ -523,9 +552,9 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" # Predicted values based on an isotonicRegression model -#' @param object a fitted IsotonicRegressionModel -#' @param newData SparkDataFrame for testing -#' @return \code{predict} returns a SparkDataFrame containing predicted values +#' @param object a fitted IsotonicRegressionModel. +#' @param newData SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.isoreg #' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method #' @export @@ -537,7 +566,9 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"), # Get the summary of an IsotonicRegressionModel model -#' @return \code{summary} returns the model's boundaries and prediction as lists +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes model's \code{boundaries} (boundaries in increasing order) +#' and \code{predictions} (predictions associated with the boundaries at the same index). 
#' @rdname spark.isoreg #' @aliases summary,IsotonicRegressionModel-method #' @export @@ -634,7 +665,11 @@ setMethod("fitted", signature(object = "KMeansModel"), # Get the summary of a k-means model #' @param object a fitted k-means model. -#' @return \code{summary} returns the model's coefficients, size and cluster. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{k} (number of cluster centers), +#' \code{coefficients} (model cluster centers), +#' \code{size} (number of data points in each cluster), and \code{cluster} +#' (cluster centers of the transformed data). #' @rdname spark.kmeans #' @export #' @note summary(KMeansModel) since 2.0.0 @@ -654,7 +689,7 @@ setMethod("summary", signature(object = "KMeansModel"), } else { dataFrame(callJMethod(jobj, "cluster")) } - list(coefficients = coefficients, size = size, + list(k = k, coefficients = coefficients, size = size, cluster = cluster, is.loaded = is.loaded) }) @@ -676,18 +711,17 @@ setMethod("predict", signature(object = "KMeansModel"), #' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. #' Users can print, make predictions on the produced model and save the model to the input path. #' -#' @param data SparkDataFrame for training +#' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam the regularization parameter. Default is 0.0. +#' @param regParam the regularization parameter. #' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. #' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination #' of L1 and L2. Default is 0.0 which is an L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param fitIntercept whether to fit an intercept term. Default is TRUE. #' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: Default is "auto". +#' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". @@ -705,13 +739,10 @@ setMethod("predict", signature(object = "KMeansModel"), #' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of #' predicting each class. Array must have length equal to the number of classes, with values > 0, #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. Default is 0.5. +#' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. -#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. Default is 2. -#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". #' @param ... additional arguments passed to the method. 
-#' @return \code{spark.logit} returns a fitted logistic regression model +#' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit #' @aliases spark.logit,SparkDataFrame,formula-method #' @name spark.logit @@ -720,46 +751,36 @@ setMethod("predict", signature(object = "KMeansModel"), #' \dontrun{ #' sparkR.session() #' # binary logistic regression -#' label <- c(1.0, 1.0, 1.0, 0.0, 0.0) -#' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) -#' binary_data <- as.data.frame(cbind(label, feature)) -#' binary_df <- createDataFrame(binary_data) -#' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) -#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) -#' -#' # summary of binary logistic regression -#' blr_summary <- summary(blr_model) -#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) +#' df <- createDataFrame(iris) +#' training <- df[df$Species %in% c("versicolor", "virginica"), ] +#' model <- spark.logit(training, Species ~ ., regParam = 0.5) +#' summary <- summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, training) +#' #' # save fitted model to input path #' path <- "path/to/model" -#' write.ml(blr_model, path) +#' write.ml(model, path) #' #' # can also read back the saved model and predict #' # Note that summary deos not work on loaded model #' savedModel <- read.ml(path) -#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) +#' summary(savedModel) #' #' # multinomial logistic regression #' -#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0) -#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) -#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) -#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) -#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) -#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) -#' df <- createDataFrame(data) +#' df <- createDataFrame(iris) +#' model <- spark.logit(df, Species ~ ., regParam = 0.5) +#' summary <- summary(model) #' -#' # Note that summary of multinomial logistic regression is not implemented yet -#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1)) -#' predict1 <- collect(select(predict(model, df), "prediction")) #' } #' @note spark.logit since 2.1.0 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, - tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, - probabilityCol = "probability") { - formula <- paste0(deparse(formula), collapse = "") + tol = 1E-6, family = "auto", standardization = TRUE, + thresholds = 0.5, weightCol = NULL) { + formula <- paste(deparse(formula), collapse = "") if (is.null(weightCol)) { weightCol <- "" @@ -768,10 +789,9 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", data@sdf, formula, as.numeric(regParam), as.numeric(elasticNetParam), as.integer(maxIter), - as.numeric(tol), as.logical(fitIntercept), - as.character(family), as.logical(standardization), - as.array(thresholds), as.character(weightCol), - as.integer(aggregationDepth), as.character(probabilityCol)) + as.numeric(tol), as.character(family), + 
as.logical(standardization), as.array(thresholds), + as.character(weightCol)) new("LogisticRegressionModel", jobj = jobj) }) @@ -790,9 +810,9 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), # Get the summary of an LogisticRegressionModel -#' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that -#' Multinomial logistic regression summary is not available now. +#' @param object an LogisticRegressionModel fitted by \code{spark.logit}. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{coefficients} (coefficients matrix of the fitted model). #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method #' @export @@ -800,33 +820,21 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - - if (is.loaded) { - stop("Loaded model doesn't have training summary.") + features <- callJMethod(jobj, "rFeatures") + labels <- callJMethod(jobj, "labels") + coefficients <- callJMethod(jobj, "rCoefficients") + nCol <- length(coefficients) / length(features) + coefficients <- matrix(coefficients, ncol = nCol) + # If nCol == 1, means this is a binomial logistic regression model with pivoting. + # Otherwise, it's a multinomial logistic regression model without pivoting. + if (nCol == 1) { + colnames(coefficients) <- c("Estimate") + } else { + colnames(coefficients) <- unlist(labels) } + rownames(coefficients) <- unlist(features) - roc <- dataFrame(callJMethod(jobj, "roc")) - - areaUnderROC <- callJMethod(jobj, "areaUnderROC") - - pr <- dataFrame(callJMethod(jobj, "pr")) - - fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold")) - - precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold")) - - recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold")) - - totalIterations <- callJMethod(jobj, "totalIterations") - - objectiveHistory <- callJMethod(jobj, "objectiveHistory") - - list(roc = roc, areaUnderROC = areaUnderROC, pr = pr, - fMeasureByThreshold = fMeasureByThreshold, - precisionByThreshold = precisionByThreshold, - recallByThreshold = recallByThreshold, - totalIterations = totalIterations, objectiveHistory = objectiveHistory) + list(coefficients = coefficients) }) #' Multilayer Perceptron Classification Model @@ -840,8 +848,10 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' Multilayer Perceptron} #' #' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. #' @param blockSize blockSize parameter. -#' @param layers integer vector containing the number of nodes for each layer +#' @param layers integer vector containing the number of nodes for each layer. #' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs". #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. @@ -852,7 +862,7 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @param ... additional arguments passed to the method. 
#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. #' @rdname spark.mlp -#' @aliases spark.mlp,SparkDataFrame-method +#' @aliases spark.mlp,SparkDataFrame,formula-method #' @name spark.mlp #' @seealso \link{read.ml} #' @export @@ -861,7 +871,7 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") #' #' # fit a Multilayer Perceptron Classification Model -#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", +#' model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", #' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, #' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) #' @@ -878,9 +888,10 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' summary(savedModel) #' } #' @note spark.mlp since 2.1.0 -setMethod("spark.mlp", signature(data = "SparkDataFrame"), - function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, +setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { + formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { stop ("layers must be a integer vector with length > 1.") } @@ -895,7 +906,7 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame"), initialWeights <- as.array(as.numeric(na.omit(initialWeights))) } jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", - "fit", data@sdf, as.integer(blockSize), as.array(layers), + "fit", data@sdf, formula, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), as.numeric(stepSize), seed, initialWeights) new("MultilayerPerceptronClassificationModel", jobj = jobj) @@ -918,9 +929,12 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel # Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp} #' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp} -#' @return \code{summary} returns a list containing \code{labelCount}, \code{layers}, and -#' \code{weights}. For \code{weights}, it is a numeric vector with length equal to -#' the expected given the architecture (i.e., for 8-10-2 network, 100 connection weights). +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{numOfInputs} (number of inputs), \code{numOfOutputs} +#' (number of outputs), \code{layers} (array of layer sizes including input +#' and output layers), and \code{weights} (the weights of layers). +#' For \code{weights}, it is a numeric vector with length equal to the expected +#' given the architecture (i.e., for 8-10-2 network, 112 connection weights). 
#' @rdname spark.mlp #' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method @@ -928,10 +942,12 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), function(object) { jobj <- object@jobj - labelCount <- callJMethod(jobj, "labelCount") layers <- unlist(callJMethod(jobj, "layers")) + numOfInputs <- head(layers, n = 1) + numOfOutputs <- tail(layers, n = 1) weights <- callJMethod(jobj, "weights") - list(labelCount = labelCount, layers = layers, weights = weights) + list(numOfInputs = numOfInputs, numOfOutputs = numOfOutputs, + layers = layers, weights = weights) }) #' Naive Bayes Models @@ -983,7 +999,7 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form # Saves the Bernoulli naive Bayes model to the input path. -#' @param path the directory where the model is saved +#' @param path the directory where the model is saved. #' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1057,7 +1073,7 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode # Save fitted IsotonicRegressionModel to the input path -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1072,7 +1088,7 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char # Save fitted LogisticRegressionModel to the input path -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1141,6 +1157,10 @@ read.ml <- function(path) { new("RandomForestRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { + new("GBTRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { + new("GBTClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1195,14 +1215,14 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new #' data and \code{write.ml}/\code{read.ml} to save/load fitted models. #' -#' @param data A SparkDataFrame for training -#' @param features Features column name, default "features". Either libSVM-format column or -#' character-format column is valid. -#' @param k Number of topics, default 10 -#' @param maxIter Maximum iterations, default 20 -#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online" +#' @param data A SparkDataFrame for training. +#' @param features Features column name. Either libSVM-format column or character-format column is +#' valid. +#' @param k Number of topics. +#' @param maxIter Maximum iterations. +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". 
#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in -#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05 +#' each iteration of mini-batch gradient descent, in range (0, 1]. #' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for #' the prior placed on topic distributions over terms, default -1 to set automatically on the #' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size @@ -1215,7 +1235,7 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' parameter if libSVM-format column is used as the features column. #' @param maxVocabSize maximum vocabulary size, default 1 << 18 #' @param ... additional argument(s) passed to the method. -#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model +#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model. #' @rdname spark.lda #' @aliases spark.lda,SparkDataFrame-method #' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} @@ -1263,8 +1283,9 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), # similarly to R's summary(). #' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns a list containing the model's coefficients, -#' intercept and log(scale) +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{coefficients} (features, coefficients, +#' intercept and log(scale)). #' @rdname spark.survreg #' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 @@ -1284,7 +1305,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted values -#' on the original scale of the data (mean predicted value at scale = 1.0). +#' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg #' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 @@ -1351,7 +1372,9 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = # Get the summary of a multivariate gaussian mixture model #' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns the model's lambda, mu, sigma and posterior. +#' @return \code{summary} returns summary of the fitted model, which is a list. +#' The list includes the model's \code{lambda} (lambda), \code{mu} (mu), +#' \code{sigma} (sigma), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture #' @export @@ -1415,7 +1438,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' @param userCol column name for user ids. Ids must be (or can be coerced into) integers. #' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers. #' @param rank rank of the matrix factorization (> 0). -#' @param reg regularization parameter (>= 0). +#' @param regParam regularization parameter (>= 0). #' @param maxIter maximum number of iterations (>= 0). #' @param nonnegative logical value indicating whether to apply nonnegativity constraints. #' @param implicitPrefs logical value indicating whether to use implicit preference. 
@@ -1425,7 +1448,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' @param numItemBlocks number of item blocks used to parallelize computation (> 0). #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). #' @param ... additional argument(s) passed to the method. -#' @return \code{spark.als} returns a fitted ALS model +#' @return \code{spark.als} returns a fitted ALS model. #' @rdname spark.als #' @aliases spark.als,SparkDataFrame-method #' @name spark.als @@ -1454,21 +1477,21 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' #' # set other arguments #' modelS <- spark.als(df, "rating", "user", "item", rank = 20, -#' reg = 0.1, nonnegative = TRUE) +#' regParam = 0.1, nonnegative = TRUE) #' statsS <- summary(modelS) #' } #' @note spark.als since 2.1.0 setMethod("spark.als", signature(data = "SparkDataFrame"), function(data, ratingCol = "rating", userCol = "user", itemCol = "item", - rank = 10, reg = 0.1, maxIter = 10, nonnegative = FALSE, + rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE, implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10, checkpointInterval = 10, seed = 0) { if (!is.numeric(rank) || rank <= 0) { stop("rank should be a positive number.") } - if (!is.numeric(reg) || reg < 0) { - stop("reg should be a nonnegative number.") + if (!is.numeric(regParam) || regParam < 0) { + stop("regParam should be a nonnegative number.") } if (!is.numeric(maxIter) || maxIter <= 0) { stop("maxIter should be a positive number.") @@ -1476,7 +1499,7 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank), - reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative, + regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative, as.integer(numUserBlocks), as.integer(numItemBlocks), as.integer(checkpointInterval), as.integer(seed)) new("ALSModel", jobj = jobj) @@ -1485,9 +1508,11 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), # Returns a summary of the ALS model produced by spark.als. #' @param object a fitted ALS model. -#' @return \code{summary} returns a list containing the names of the user column, -#' the item column and the rating column, the estimated user and item factors, -#' rank, regularization parameter and maximum number of iterations used in training. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{user} (the names of the user column), +#' \code{item} (the item column), \code{rating} (the rating column), \code{userFactors} +#' (the estimated user factors), \code{itemFactors} (the estimated item factors), +#' and \code{rank} (rank of the matrix factorization model). 
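As a quick illustration of the renamed regParam argument and the summary components listed above, a sketch using toy ratings similar to the package tests:

```R
# Fit ALS with regParam (formerly `reg`) and read back the summary list.
ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0),
                list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0))
df <- createDataFrame(ratings, c("user", "item", "score"))
model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
                   rank = 10, regParam = 0.1, maxIter = 5, seed = 0)
stats <- summary(model)
stats$rank          # 10
stats$user          # "user", the name of the user column
stats$userFactors   # the estimated user factors
```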
#' @rdname spark.als #' @aliases summary,ALSModel-method #' @export @@ -1570,14 +1595,14 @@ setMethod("write.ml", signature(object = "ALSModel", path = "character"), #' \dontrun{ #' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) #' df <- createDataFrame(data) -#' test <- spark.ktest(df, "test", "norm", c(0, 1)) +#' test <- spark.kstest(df, "test", "norm", c(0, 1)) #' #' # get a summary of the test result #' testSummary <- summary(test) #' testSummary #' #' # print out the summary in an organized way -#' print.summary.KSTest(test) +#' print.summary.KSTest(testSummary) #' } #' @note spark.kstest since 2.1.0 setMethod("spark.kstest", signature(data = "SparkDataFrame"), @@ -1600,9 +1625,10 @@ setMethod("spark.kstest", signature(data = "SparkDataFrame"), # Get the summary of Kolmogorov-Smirnov (KS) Test. #' @param object test result object of KSTest by \code{spark.kstest}. -#' @return \code{summary} returns a list containing the p-value, test statistic computed for the -#' test, the null hypothesis with its parameters tested against -#' and degrees of freedom of the test. +#' @return \code{summary} returns summary information of KSTest object, which is a list. +#' The list includes the \code{p.value} (p-value), \code{statistic} (test statistic +#' computed for the test), \code{nullHypothesis} (the null hypothesis with its +#' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). #' @rdname spark.kstest #' @aliases summary,KSTest-method #' @export @@ -1644,33 +1670,36 @@ print.summary.KSTest <- function(x, ...) { #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest} +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ +#' Random Forest Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ +#' Random Forest Classification} #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' @param type type of model, one of "regression" or "classification", to fit -#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) +#' @param maxDepth Maximum depth of the tree (>= 0). #' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing #' how to split on features at each node. More bins give higher granularity. Must be -#' >= 2 and >= number of categories in any categorical feature. (default = 32) +#' >= 2 and >= number of categories in any categorical feature. #' @param numTrees Number of trees to train (>= 1). #' @param impurity Criterion used for information gain calculation. #' For regression, must be "variance". For classification, must be one of -#' "entropy" and "gini". (default = gini) -#' @param minInstancesPerNode Minimum number of instances each child must have after split. -#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. -#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' "entropy" and "gini", default is "gini". #' @param featureSubsetStrategy The number of features to consider for splits at each tree node. 
#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. #' @param seed integer seed for random number generation. #' @param subsamplingRate Fraction of the training data used for learning each decision tree, in -#' range (0, 1]. (default = 1.0) -#' @param probabilityCol column name for predicted class conditional probabilities, only for -#' classification. (default = "probability") +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with -#' nodes. +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -1703,9 +1732,9 @@ print.summary.KSTest <- function(x, ...) { setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, type = c("regression", "classification"), maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, - minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, - probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) { + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -1734,7 +1763,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo impurity, as.integer(minInstancesPerNode), as.numeric(minInfoGain), as.integer(checkpointInterval), as.character(featureSubsetStrategy), seed, - as.numeric(subsamplingRate), as.character(probabilityCol), + as.numeric(subsamplingRate), as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) new("RandomForestClassificationModel", jobj = jobj) } @@ -1745,11 +1774,11 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction" +#' "prediction". 
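For illustration, a sketch (assuming an active SparkR session) showing that classification predictions come back as label values in the "prediction" column, as exercised in the package tests:

```R
# Predictions from a classification forest are returned as label strings.
df <- suppressWarnings(createDataFrame(iris))
model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width,
                            "classification", maxDepth = 5, maxBins = 16)
head(collect(select(predict(model, df), "prediction")))   # e.g. "setosa", ...
```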
#' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method #' @export -#' @note predict(randomForestRegressionModel) since 2.1.0 +#' @note predict(RandomForestRegressionModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestRegressionModel"), function(object, newData) { predict_internal(object, newData) @@ -1758,7 +1787,7 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"), #' @rdname spark.randomForest #' @aliases predict,RandomForestClassificationModel-method #' @export -#' @note predict(randomForestClassificationModel) since 2.1.0 +#' @note predict(RandomForestClassificationModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestClassificationModel"), function(object, newData) { predict_internal(object, newData) @@ -1766,8 +1795,8 @@ setMethod("predict", signature(object = "RandomForestClassificationModel"), # Save the Random Forest Regression or Classification model to the input path. -#' @param object A fitted Random Forest regression model or classification model -#' @param path The directory where the model is saved +#' @param object A fitted Random Forest regression model or classification model. +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1789,8 +1818,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path write_internal(object, path, overwrite) }) -# Get the summary of an RandomForestRegressionModel model -summary.randomForest <- function(model) { +# Create the summary of a tree ensemble model (eg. Random Forest, GBT) +summary.treeEnsemble <- function(model) { jobj <- model@jobj formula <- callJMethod(jobj, "formula") numFeatures <- callJMethod(jobj, "numFeatures") @@ -1807,20 +1836,25 @@ summary.randomForest <- function(model) { jobj = jobj) } -#' @return \code{summary} returns the model's features as lists, depth and number of nodes -#' or number of classes. +# Get the summary of a Random Forest Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), +#' and \code{treeWeights} (tree weights). 
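A short sketch of the shared tree-ensemble summary, using the longley dataset as in the tests:

```R
# Components of the summary list for a random forest regression model.
df <- createDataFrame(longley)
model <- spark.randomForest(df, Employed ~ ., "regression", numTrees = 20, seed = 123)
s <- summary(model)
s$formula              # "Employed ~ ."
s$numFeatures          # 6
s$numTrees             # 20
length(s$treeWeights)  # 20, one weight per tree
```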
#' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { - ans <- summary.randomForest(object) + ans <- summary.treeEnsemble(object) class(ans) <- "summary.RandomForestRegressionModel" ans }) -# Get the summary of an RandomForestClassificationModel model +# Get the summary of a Random Forest Classification Model #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method @@ -1828,13 +1862,13 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { - ans <- summary.randomForest(object) + ans <- summary.treeEnsemble(object) class(ans) <- "summary.RandomForestClassificationModel" ans }) -# Prints the summary of Random Forest Regression Model -print.summary.randomForest <- function(x) { +# Prints the summary of tree ensemble models (eg. Random Forest, GBT) +print.summary.treeEnsemble <- function(x) { jobj <- x$jobj cat("Formula: ", x$formula) cat("\nNumber of features: ", x$numFeatures) @@ -1848,13 +1882,15 @@ print.summary.randomForest <- function(x) { invisible(x) } +# Prints the summary of Random Forest Regression Model + #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest #' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { - print.summary.randomForest(x) + print.summary.treeEnsemble(x) } # Prints the summary of Random Forest Classification Model @@ -1863,5 +1899,216 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { - print.summary.randomForest(x) + print.summary.treeEnsemble(x) +} + +#' Gradient Boosted Tree Model for Regression and Classification +#' +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and +#' \code{write.ml}/\code{read.ml} to save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ +#' GBT Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ +#' GBT Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param maxIter Param for maximum number of iterations (>= 0). 
+#' @param stepSize Param for Step size to be used for each iteration of optimization. +#' @param lossType Loss function which GBT tries to minimize. +#' For classification, must be "logistic". For regression, must be one of +#' "squared" (L2) and "absolute" (L1), default is "squared". +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a +#' split causes the left or right child to have fewer than +#' minInstancesPerNode, the split will be discarded as invalid. Should be +#' >= 1. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gbt,SparkDataFrame,formula-method +#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. +#' @rdname spark.gbt +#' @name spark.gbt +#' @export +#' @examples +#' \dontrun{ +#' # fit a Gradient Boosted Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Gradient Boosted Tree Classification Model +#' # label must be binary - Only binary classification is supported for GBT. 
+#' df <- createDataFrame(iris[iris$Species != "virginica", ]) +#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' +#' # numeric label is also supported +#' iris2 <- iris[iris$Species != "virginica", ] +#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) +#' df <- createDataFrame(iris2) +#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' } +#' @note spark.gbt since 2.1.0 +setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, + seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(lossType)) lossType <- "squared" + lossType <- match.arg(lossType, c("squared", "absolute")) + jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(lossType)) lossType <- "logistic" + lossType <- match.arg(lossType, "logistic") + jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Gradient Boosted Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction". +#' @rdname spark.gbt +#' @aliases predict,GBTRegressionModel-method +#' @export +#' @note predict(GBTRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "GBTRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.gbt +#' @aliases predict,GBTClassificationModel-method +#' @export +#' @note predict(GBTClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "GBTClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Gradient Boosted Tree Regression or Classification model to the input path. + +#' @param object A fitted Gradient Boosted Tree regression model or classification model. +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
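To show how a fitted GBT model round-trips through write.ml/read.ml, a sketch with a placeholder temporary path:

```R
# Save and reload a GBT regression model.
df <- createDataFrame(longley)
model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2)
path <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
write.ml(model, path)
saved <- read.ml(path)       # returns a GBTRegressionModel
summary(saved)$numTrees
unlink(path)
```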
+#' @aliases write.ml,GBTRegressionModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,GBTClassificationModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of a Gradient Boosted Tree Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), +#' and \code{treeWeights} (tree weights). +#' @rdname spark.gbt +#' @aliases summary,GBTRegressionModel-method +#' @export +#' @note summary(GBTRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "GBTRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTRegressionModel" + ans + }) + +# Get the summary of a Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @aliases summary,GBTClassificationModel-method +#' @export +#' @note summary(GBTClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "GBTClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTClassificationModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Regression Model + +#' @param x summary object of Gradient Boosted Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTRegressionModel since 2.1.0 +print.summary.GBTRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Prints the summary of Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTClassificationModel since 2.1.0 +print.summary.GBTClassificationModel <- function(x, ...) { + print.summary.treeEnsemble(x) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 6b4a2f2fdc85c..c57cc8f285613 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -322,6 +322,9 @@ sparkRHive.init <- function(jsc = NULL) { #' SparkSession or initializes a new SparkSession. #' Additional Spark properties can be set in \code{...}, and these named parameters take priority #' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. +#' When called in an interactive session, this checks for the Spark installation, and, if not +#' found, it will be downloaded and cached automatically. Alternatively, \code{install.spark} can +#' be called manually. #' #' For details on how to initialize and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. 
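As a sketch of the behaviour described above (the application name is illustrative): in an interactive session Spark is located or downloaded automatically, and `install.spark` can be run up front instead:

```R
# Optional: install Spark manually before starting the session.
install.spark()
# Otherwise, sparkR.session() checks for (and, interactively, downloads) Spark.
sparkR.session(master = "local[*]", appName = "SparkR-example")
```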
@@ -373,8 +376,13 @@ sparkR.session <- function( overrideEnvs(sparkConfigMap, paramMap) } + deployMode <- "" + if (exists("spark.submit.deployMode", envir = sparkConfigMap)) { + deployMode <- sparkConfigMap[["spark.submit.deployMode"]] + } + if (!exists(".sparkRjsc", envir = .sparkREnv)) { - retHome <- sparkCheckInstall(sparkHome, master) + retHome <- sparkCheckInstall(sparkHome, master, deployMode) if (!is.null(retHome)) sparkHome <- retHome sparkExecutorEnvMap <- new.env() sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorEnvMap, @@ -419,7 +427,7 @@ sparkR.session <- function( #' @method setJobGroup default setJobGroup.default <- function(groupId, description, interruptOnCancel) { sc <- getSparkContext() - callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) + invisible(callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel)) } setJobGroup <- function(sc, groupId, description, interruptOnCancel) { @@ -449,7 +457,7 @@ setJobGroup <- function(sc, groupId, description, interruptOnCancel) { #' @method clearJobGroup default clearJobGroup.default <- function() { sc <- getSparkContext() - callJMethod(sc, "clearJobGroup") + invisible(callJMethod(sc, "clearJobGroup")) } clearJobGroup <- function(sc) { @@ -476,7 +484,7 @@ clearJobGroup <- function(sc) { #' @method cancelJobGroup default cancelJobGroup.default <- function(groupId) { sc <- getSparkContext() - callJMethod(sc, "cancelJobGroup", groupId) + invisible(callJMethod(sc, "cancelJobGroup", groupId)) } cancelJobGroup <- function(sc, groupId) { @@ -550,24 +558,27 @@ processSparkPackages <- function(packages) { # # @param sparkHome directory to find Spark package. # @param master the Spark master URL, used to check local or remote mode. +# @param deployMode whether to deploy your driver on the worker nodes (cluster) +# or locally as an external client (client). # @return NULL if no need to update sparkHome, and new sparkHome otherwise. 
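A sketch of how the new deployMode argument affects the internal sparkCheckInstall helper (mirroring test_sparkR.R; assumes a non-interactive, spark-submit style run with SPARK_HOME unset):

```R
sparkCheckInstall(sparkHome = "", master = "", deployMode = "")
# -> NULL: cluster deploy mode, Spark is expected on the cluster nodes
sparkCheckInstall(sparkHome = "", master = "yarn-client", deployMode = "")
# -> error with installation instructions (client mode needs a local Spark)
sparkCheckInstall(sparkHome = "", master = "", deployMode = "client")
# -> error with installation instructions
```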
-sparkCheckInstall <- function(sparkHome, master) { +sparkCheckInstall <- function(sparkHome, master, deployMode) { if (!isSparkRShell()) { if (!is.na(file.info(sparkHome)$isdir)) { msg <- paste0("Spark package found in SPARK_HOME: ", sparkHome) message(msg) NULL } else { - if (!nzchar(master) || isMasterLocal(master)) { - msg <- paste0("Spark not found in SPARK_HOME: ", - sparkHome) + if (interactive() || isMasterLocal(master)) { + msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome) message(msg) packageLocalDir <- install.spark() packageLocalDir - } else { + } else if (isClientMode(master) || deployMode == "client") { msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome, "\n", installInstruction("remote")) stop(msg) + } else { + NULL } } } else { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 20004549cc037..1283449f3592a 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -777,6 +777,10 @@ isMasterLocal <- function(master) { grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE) } +isClientMode <- function(master) { + grepl("([a-z]+)-client$", master, perl = TRUE) +} + isSparkRShell <- function() { grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) } @@ -837,7 +841,7 @@ captureJVMException <- function(e, method) { # # @param inputData a list of rows, with each row a list # @return data.frame with raw columns as lists -rbindRaws <- function(inputData){ +rbindRaws <- function(inputData) { row1 <- inputData[[1]] rawcolumns <- ("raw" == sapply(row1, class)) @@ -847,3 +851,15 @@ rbindRaws <- function(inputData){ out[!rawcolumns] <- lapply(out[!rawcolumns], unlist) out } + +# Get basename without extension from URL +basenameSansExtFromUrl <- function(url) { + # split by '/' + splits <- unlist(strsplit(url, "^.+/")) + last <- tail(splits, 1) + # this is from file_path_sans_ext + # first, remove any compression extension + filename <- sub("[.](gz|bz2|xz)$", "", last) + # then, strip extension by the last '.' + sub("([^.]+)\\.[[:alnum:]]+$", "\\1", filename) +} diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index db98d0e45547e..40c0446740277 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -159,6 +159,15 @@ test_that("spark.glm summary", { df <- suppressWarnings(createDataFrame(data)) regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result + + # Test spark.glm works on collinear data + A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) + b <- c(1, 2, 3, 4) + data <- as.data.frame(cbind(A, b)) + df <- createDataFrame(data) + stats <- summary(spark.glm(df, b ~ . 
- 1)) + coefs <- unlist(stats$coefficients) + expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4)) }) test_that("spark.glm save/load", { @@ -341,6 +350,8 @@ test_that("spark.kmeans", { # Test summary works on KMeans summary.model <- summary(model) cluster <- summary.model$cluster + k <- summary.model$k + expect_equal(k, 2) expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) # Test model save/load @@ -361,12 +372,13 @@ test_that("spark.kmeans", { test_that("spark.mlp", { df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") - model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", maxIter = 100, - tol = 0.5, stepSize = 1, seed = 1) + model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), + solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1) # Test summary method summary <- summary(model) - expect_equal(summary$labelCount, 3) + expect_equal(summary$numOfInputs, 4) + expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 5, 4, 3)) expect_equal(length(summary$weights), 64) expect_equal(head(summary$weights, 5), list(-0.878743, 0.2154151, -1.16304, -0.6583214, 1.009825), @@ -375,7 +387,7 @@ test_that("spark.mlp", { # Test predict method mlpTestDF <- df mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 6), c(0, 1, 1, 1, 1, 1)) + expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") @@ -385,46 +397,68 @@ test_that("spark.mlp", { model2 <- read.ml(modelPath) summary2 <- summary(model2) - expect_equal(summary2$labelCount, 3) + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) expect_equal(summary2$layers, c(4, 5, 4, 3)) expect_equal(length(summary2$weights), 64) unlink(modelPath) # Test default parameter - model <- spark.mlp(df, layers = c(4, 5, 4, 3)) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 0)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test illegal parameter - expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, layers = c(3)), "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = NULL), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c()), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c(3)), + "layers must be a integer vector with length > 1.") # Test random seed # default seed - model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 2, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", 
"1.0", "0.0")) # seed equals 10 - model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2) + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + + # Test formula works well + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.mlp(df, Species ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, + layers = c(4, 3)) + summary <- summary(model) + expect_equal(summary$numOfInputs, 4) + expect_equal(summary$numOfOutputs, 3) + expect_equal(summary$layers, c(4, 3)) + expect_equal(length(summary$weights), 15) + expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, + -10.2376130), tolerance = 1e-6) }) test_that("spark.naiveBayes", { @@ -603,58 +637,141 @@ test_that("spark.isotonicRegression", { }) test_that("spark.logit", { - # test binary logistic regression - label <- c(1.0, 1.0, 1.0, 0.0, 0.0) - feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) - binary_data <- as.data.frame(cbind(label, feature)) - binary_df <- createDataFrame(binary_data) - - blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) - blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) - expect_equal(blr_predict$prediction, c(0, 0, 0, 0, 0)) - blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0) - blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction")) - expect_equal(blr_predict1$prediction, c(1, 1, 1, 1, 1)) - - # test summary of binary logistic regression - blr_summary <- summary(blr_model) - blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) - expect_equal(blr_fmeasure$threshold, c(0.8221347, 0.7884005, 0.6674709, 0.3785437, 0.3434487), - 
tolerance = 1e-4) - expect_equal(blr_fmeasure$"F-Measure", c(0.5000000, 0.8000000, 0.6666667, 0.8571429, 0.7500000), - tolerance = 1e-4) - blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision")) - expect_equal(blr_precision$precision, c(1.0000000, 1.0000000, 0.6666667, 0.7500000, 0.6000000), - tolerance = 1e-4) - blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall")) - expect_equal(blr_recall$recall, c(0.3333333, 0.6666667, 0.6666667, 1.0000000, 1.0000000), - tolerance = 1e-4) + # R code to reproduce the result. + # nolint start + #' library(glmnet) + #' iris.x = as.matrix(iris[, 1:4]) + #' iris.y = as.factor(as.character(iris[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $setosa + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.0981324 + # Sepal.Length -0.2909860 + # Sepal.Width 0.5510907 + # Petal.Length -0.1915217 + # Petal.Width -0.4211946 + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.520061e+00 + # Sepal.Length 2.524501e-02 + # Sepal.Width -5.310313e-01 + # Petal.Length 3.656543e-02 + # Petal.Width -3.144464e-05 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -2.61819385 + # Sepal.Length 0.26574097 + # Sepal.Width -0.02005932 + # Petal.Length 0.15495629 + # Petal.Width 0.42122607 + # nolint end - # test model save and read - modelPath <- tempfile(pattern = "spark-logisticRegression", fileext = ".tmp") - write.ml(blr_model, modelPath) - expect_error(write.ml(blr_model, modelPath)) - write.ml(blr_model, modelPath, overwrite = TRUE) - blr_model2 <- read.ml(modelPath) - blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction")) - expect_equal(blr_predict$prediction, blr_predict2$prediction) - expect_error(summary(blr_model2)) + # Test multinomial logistic regression againt three classes + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.logit(df, Species ~ ., regParam = 0.5) + summary <- summary(model) + versicolorCoefsR <- c(1.52, 0.03, -0.53, 0.04, 0.00) + virginicaCoefsR <- c(-2.62, 0.27, -0.02, 0.16, 0.42) + setosaCoefsR <- c(1.10, -0.29, 0.55, -0.19, -0.42) + versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) + virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) + setosaCoefs <- unlist(summary$coefficients[, "setosa"]) + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) + + # Test model save and load + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) unlink(modelPath) - # test multinomial logistic regression - label <- c(0.0, 1.0, 2.0, 0.0, 0.0) - feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) - feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) - feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) - feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) - data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) - df <- createDataFrame(data) + # R code to reproduce the result. 
+ # nolint start + #' library(glmnet) + #' iris2 <- iris[iris$Species %in% c("versicolor", "virginica"), ] + #' iris.x = as.matrix(iris2[, 1:4]) + #' iris.y = as.factor(as.character(iris2[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 3.93844796 + # Sepal.Length -0.13538675 + # Sepal.Width -0.02386443 + # Petal.Length -0.35076451 + # Petal.Width -0.77971954 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -3.93844796 + # Sepal.Length 0.13538675 + # Sepal.Width 0.02386443 + # Petal.Length 0.35076451 + # Petal.Width 0.77971954 + # + #' logit = glmnet(iris.x, iris.y, family="binomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # (Intercept) -6.0824412 + # Sepal.Length 0.2458260 + # Sepal.Width 0.1642093 + # Petal.Length 0.4759487 + # Petal.Width 1.0383948 + # + # nolint end + + # Test multinomial logistic regression againt two classes + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") + summary <- summary(model) + versicolorCoefsR <- c(3.94, -0.16, -0.02, -0.35, -0.78) + virginicaCoefsR <- c(-3.94, 0.16, -0.02, 0.35, 0.78) + versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) + virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + + # Test binomial logistic regression againt two classes + model <- spark.logit(training, Species ~ ., regParam = 0.5) + summary <- summary(model) + coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) + coefs <- unlist(summary$coefficients[, "Estimate"]) + expect_true(all(abs(coefsR - coefs) < 0.1)) - model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1)) - predict1 <- collect(select(predict(model, df), "prediction")) - expect_equal(predict1$prediction, c(0, 0, 0, 0, 0)) - # Summary of multinomial logistic regression is not implemented yet - expect_error(summary(model)) + # Test prediction with string label + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") + expected <- c("versicolor", "versicolor", "virginica", "versicolor", "versicolor", + "versicolor", "versicolor", "versicolor", "versicolor", "versicolor") + expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) + + # Test prediction with numeric label + label <- c(0.0, 0.0, 0.0, 1.0, 1.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + data <- as.data.frame(cbind(label, feature)) + df <- createDataFrame(data) + model <- spark.logit(df, label ~ feature) + prediction <- collect(select(predict(model, df), "prediction")) + expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0")) }) test_that("spark.gaussianMixture", { @@ -811,10 +928,10 @@ test_that("spark.posterior and spark.perplexity", { test_that("spark.als", { data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), - list(2, 1, 1.0), list(2, 2, 5.0)) + list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(data, c("user", "item", "score")) model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item", - rank = 10, maxIter = 5, seed = 0, reg = 0.1) + 
rank = 10, maxIter = 5, seed = 0, regParam = 0.1) stats <- summary(model) expect_equal(stats$rank, 10) test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item")) @@ -869,9 +986,16 @@ test_that("spark.kstest", { expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") + + # Test print.summary.KSTest + printStats <- capture.output(print.summary.KSTest(stats)) + expect_match(printStats[1], "Kolmogorov-Smirnov test summary:") + expect_match(printStats[5], + "Low presumption against null hypothesis: Sample follows theoretical distribution. ") }) -test_that("spark.randomForest Regression", { +test_that("spark.randomForest", { + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 1) @@ -891,10 +1015,11 @@ test_that("spark.randomForest Regression", { model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 20, seed = 123) predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, - 63.736, 64.296, 64.868, 64.300, - 66.709, 67.697, 67.966, 67.252, - 68.866, 69.593, 69.195, 69.658), + expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070, + 63.53160, 64.05470, 65.12710, 64.30450, + 66.70910, 67.86125, 68.08700, 67.21865, + 68.89275, 69.53180, 69.39640, 69.68250), + tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) @@ -913,9 +1038,8 @@ test_that("spark.randomForest Regression", { expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) -}) -test_that("spark.randomForest Classification", { + # classification data <- suppressWarnings(createDataFrame(iris)) model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", maxDepth = 5, maxBins = 16) @@ -925,6 +1049,10 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numTrees, 20) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) @@ -937,6 +1065,106 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numClasses, stats2$numClasses) unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # spark.randomForest classification can work on libsvm data + data <- 
read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) +}) + +test_that("spark.gbt", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$formula, "Employed ~ .") + expect_equal(stats$numFeatures, 6) + expect_equal(length(stats$treeWeights), 20) + + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + # label must be binary - GBTClassifier currently only supports binary classification. + iris2 <- iris[iris$Species != "virginica", ] + data <- suppressWarnings(createDataFrame(iris2)) + model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification") + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + predictions <- collect(predict(model, data))$prediction + # test string prediction values + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + + iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) + df <- suppressWarnings(createDataFrame(iris2)) + m <- spark.gbt(df, NumericSpecies ~ ., type = "classification") + s <- summary(m) + # test numeric prediction values + expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) + expect_equal(s$numFeatures, 5) + expect_equal(s$numTrees, 20) + + # spark.gbt classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R new file mode 100644 index 0000000000000..f73fc6baeccef --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# 
contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in sparkR.R") + +test_that("sparkCheckInstall", { + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- paste0(tempdir(), "/", "sparkHome") + dir.create(sparkHome) + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + unlink(sparkHome, recursive = TRUE) + + # "yarn-cluster, mesos-cluster" mode, SPARK_HOME was not set, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- "" + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + + # "yarn-client, mesos-client" mode, SPARK_HOME was not set + sparkHome <- "" + master <- "yarn-client" + deployMode <- "" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) + sparkHome <- "" + master <- "" + deployMode <- "client" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) +}) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 806019d7524ff..e8ccff81222d0 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -576,7 +576,7 @@ test_that("test tableNames and tables", { tables <- tables() expect_equal(count(tables), 2) suppressWarnings(dropTempTable("table1")) - dropTempView("table2") + expect_true(dropTempView("table2")) tables <- tables() expect_equal(count(tables), 0) @@ -589,7 +589,7 @@ test_that( newdf <- sql("SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "SparkDataFrame") expect_equal(count(newdf), 1) - dropTempView("table1") + expect_true(dropTempView("table1")) createOrReplaceTempView(df, "dfView") sqlCast <- collect(sql("select cast('2' as decimal) as x from dfView limit 1")) @@ -600,7 +600,7 @@ test_that( expect_equal(ncol(sqlCast), 1) expect_equal(out[1], " x") expect_equal(out[2], "1 2") - dropTempView("dfView") + expect_true(dropTempView("dfView")) }) test_that("test cache, uncache and clearCache", { @@ -609,7 +609,7 @@ test_that("test cache, uncache and clearCache", { cacheTable("table1") uncacheTable("table1") clearCache() - dropTempView("table1") + expect_true(dropTempView("table1")) }) test_that("insertInto() on a registered table", { @@ -630,13 +630,13 @@ test_that("insertInto() on a registered table", { insertInto(dfParquet2, "table1") expect_equal(count(sql("select * from table1")), 5) expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - dropTempView("table1") + expect_true(dropTempView("table1")) createOrReplaceTempView(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) expect_equal(count(sql("select * from table1")), 2) expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - 
dropTempView("table1") + expect_true(dropTempView("table1")) unlink(jsonPath2) unlink(parquetPath2) @@ -650,7 +650,7 @@ test_that("tableToDF() returns a new DataFrame", { expect_equal(count(tabledf), 3) tabledf2 <- tableToDF("table1") expect_equal(count(tabledf2), 3) - dropTempView("table1") + expect_true(dropTempView("table1")) }) test_that("toRDD() returns an RRDD", { @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() @@ -2659,7 +2659,7 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume # It makes sure that we can omit path argument in write.df API and then it calls # DataFrameWriter.save() without path. expect_error(write.df(df, source = "csv"), - "Error in save : illegal argument - 'path' is not specified") + "Error in save : illegal argument - Expected exactly one path to be specified") expect_error(write.json(df, jsonPath), "Error in json : analysis error - path file:.*already exists") expect_error(write.text(df, jsonPath), @@ -2667,7 +2667,7 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume expect_error(write.orc(df, jsonPath), "Error in orc : analysis error - path file:.*already exists") expect_error(write.parquet(df, jsonPath), - "Error in parquet : analysis error - path file:.*already exists") + "Error in parquet : analysis error - path file:.*already exists") # Arguments checking in R side. expect_error(write.df(df, "data.tmp", source = c(1, 2)), @@ -2684,7 +2684,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. 
expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", + paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 607c407f04f97..c87524842876e 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -228,4 +228,15 @@ test_that("varargsToStrEnv", { expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.") }) +test_that("basenameSansExtFromUrl", { + x <- paste0("http://people.apache.org/~pwendell/spark-nightly/spark-branch-2.1-bin/spark-2.1.1-", + "SNAPSHOT-2016_12_09_11_08-eb2d9bf-bin/spark-2.1.1-SNAPSHOT-bin-hadoop2.7.tgz") + y <- paste0("http://people.apache.org/~pwendell/spark-releases/spark-2.1.0-rc2-bin/spark-2.1.0-", + "bin-hadoop2.4-without-hive.tgz") + expect_equal(basenameSansExtFromUrl(x), "spark-2.1.1-SNAPSHOT-bin-hadoop2.7") + expect_equal(basenameSansExtFromUrl(y), "spark-2.1.0-bin-hadoop2.4-without-hive") + z <- "http://people.apache.org/~pwendell/spark-releases/spark-2.1.0--hive.tar.gz" + expect_equal(basenameSansExtFromUrl(z), "spark-2.1.0--hive") +}) + sparkR.session.stop() diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 80e876027bddb..fa2656c008660 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1,12 +1,13 @@ --- title: "SparkR - Practical Guide" output: - html_document: - theme: united + rmarkdown::html_vignette: toc: true toc_depth: 4 - toc_float: true - highlight: textmate +vignette: > + %\VignetteIndexEntry{SparkR - Practical Guide} + %\VignetteEngine{knitr::rmarkdown} + \usepackage[utf8]{inputenc} --- ## Overview @@ -93,13 +94,13 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). Alternatively, we provide an easy-to-use function `install.spark` to complete this process. You don't have to call it explicitly. We will check the installation when `sparkR.session` is called and `install.spark` function will be triggered automatically if no installation is found. +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start SparkR, and it will check for the Spark installation. If you are working with SparkR from an interactive shell (e.g. R, RStudio), Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this step manually. If you don't have Spark installed on the computer, you may download it from the [Apache Spark Website](http://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() ``` -If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the Spark installation is.
+If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the existing Spark installation is. ```{r, eval=FALSE} sparkR.session(sparkHome = "/HOME/spark") @@ -446,25 +447,31 @@ head(teenagers) SparkR supports the following machine learning models and algorithms. +* Accelerated Failure Time (AFT) Survival Model + +* Collaborative Filtering with Alternating Least Squares (ALS) + +* Gaussian Mixture Model (GMM) + * Generalized Linear Model (GLM) -* Naive Bayes Model +* Gradient-Boosted Trees (GBT) + +* Isotonic Regression Model * $k$-means Clustering -* Accelerated Failure Time (AFT) Survival Model - -* Gaussian Mixture Model (GMM) +* Kolmogorov-Smirnov Test * Latent Dirichlet Allocation (LDA) -* Multilayer Perceptron Model +* Logistic Regression Model -* Collaborative Filtering with Alternating Least Squares (ALS) +* Multilayer Perceptron Model -* Isotonic Regression Model +* Naive Bayes Model -More will be added in the future. +* Random Forest ### R Formula @@ -525,6 +532,34 @@ gaussianFitted <- predict(gaussianGLM, carsDF) head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) ``` +#### Random Forest + +`spark.randomForest` fits a [random forest](https://en.wikipedia.org/wiki/Random_forest) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. + +In the following example, we use the `longley` dataset to train a random forest and make predictions: + +```{r, warning=FALSE} +df <- createDataFrame(longley) +rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2) +summary(rfModel) +predictions <- predict(rfModel, df) +``` + +#### Gradient-Boosted Trees + +`spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. + +Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions: + +```{r, warning=FALSE} +df <- createDataFrame(longley) +gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2) +summary(gbtModel) +predictions <- predict(gbtModel, df) +``` + #### Naive Bayes Model Naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. @@ -564,8 +599,6 @@ head(aftPredictions) #### Gaussian Mixture Model -(Coming in 2.1.0) - `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. We use a simulated example to demostrate the usage. 
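The simulated GMM example referenced above sits outside the hunks shown in this patch. As a non-authoritative sketch only (the `gmmFitted` object and the `V1`/`V2` column names are taken from the context visible in the next hunk; the sampling code is an assumption), fitting a two-component GMM on simulated data might look like:

```{r, eval=FALSE}
# Sketch: mix draws from two bivariate Gaussians and fit a 2-component GMM.
set.seed(100)
x1 <- data.frame(V1 = rnorm(50), V2 = rnorm(50, sd = 0.5))
x2 <- data.frame(V1 = rnorm(50, mean = 3), V2 = rnorm(50, mean = 3, sd = 0.5))
gmmDF <- createDataFrame(rbind(x1, x2))
gmmModel <- spark.gaussianMixture(gmmDF, ~ V1 + V2, k = 2)
summary(gmmModel)  # mixing weights, means and covariances of the two components
gmmFitted <- predict(gmmModel, gmmDF)
head(select(gmmFitted, "V1", "V2", "prediction"))
```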
@@ -583,8 +616,6 @@ head(select(gmmFitted, "V1", "V2", "prediction")) #### Latent Dirichlet Allocation -(Coming in 2.1.0) - `spark.lda` fits a [Latent Dirichlet Allocation](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) model on a `SparkDataFrame`. It is often used in topic modeling in which topics are inferred from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: * Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. @@ -599,22 +630,6 @@ To use LDA, we need to specify a `features` column in `data` where each entry re * libSVM: Each entry is a collection of words and will be processed directly. -There are several parameters LDA takes for fitting the model. - -* `k`: number of topics (default 10). - -* `maxIter`: maximum iterations (default 20). - -* `optimizer`: optimizer to train an LDA model, "online" (default) uses [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf). "em" uses [expectation-maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm). - -* `subsamplingRate`: For `optimizer = "online"`. Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1] (default 0.05). - -* `topicConcentration`: concentration parameter (commonly named beta or eta) for the prior placed on topic distributions over terms, default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective topicConcentration. Only 1-size numeric is accepted. - -* `docConcentration`: concentration parameter (commonly named alpha) for the prior placed on documents distributions over topics (theta), default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective docConcentration. Only 1-size or k-size numeric is accepted. - -* `maxVocabSize`: maximum vocabulary size, default 1 << 18. - Two more functions are provided for the fitted model. * `spark.posterior` returns a `SparkDataFrame` containing a column of posterior probabilities vectors named "topicDistribution". @@ -653,11 +668,8 @@ perplexity <- spark.perplexity(model, corpusDF) perplexity ``` - #### Multilayer Perceptron -(Coming in 2.1.0) - Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs by a linear combination of the inputs with the node’s weights $w$ and bias $b$ and applying an activation function. This can be written in matrix form for MLPC with $K+1$ layers as follows: $$ y(x)=f_K(\ldots f_2(w_2^T f_1(w_1^T x + b_1) + b_2) \ldots + b_K). @@ -677,24 +689,35 @@ The number of nodes $N$ in the output layer corresponds to the number of classes MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine. -`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. According to the description above, there are several additional parameters that can be set: +`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. 
-* `layers`: integer vector containing the number of nodes for each layer. - -* `solver`: solver parameter, supported options: `"gd"` (minibatch gradient descent) or `"l-bfgs"`. - -* `maxIter`: maximum iteration number. - -* `tol`: convergence tolerance of iterations. - -* `stepSize`: step size for `"gd"`. +We use iris data set to show how to use `spark.mlp` in classification. +```{r, warning=FALSE} +df <- createDataFrame(iris) +# fit a Multilayer Perceptron Classification Model +model <- spark.mlp(df, Species ~ ., blockSize = 128, layers = c(4, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) +``` -* `seed`: seed parameter for weights initialization. +To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. +```{r, include=FALSE} +ops <- options() +options(max.print=5) +``` +```{r} +# check the summary of the fitted model +summary(model) +``` +```{r, include=FALSE} +options(ops) +``` +```{r} +# make predictions use the fitted model +predictions <- predict(model, df) +head(select(predictions, predictions$prediction)) +``` #### Collaborative Filtering -(Coming in 2.1.0) - `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. @@ -724,8 +747,6 @@ head(predicted) #### Isotonic Regression Model -(Coming in 2.1.0) - `spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize $$ \ell(f) = \sum_{i=1}^n w_i (y_i - f(x_i))^2. @@ -767,8 +788,60 @@ newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) head(predict(isoregModel, newDF)) ``` -#### What's More? -We also expect Decision Tree, Random Forest, Kolmogorov-Smirnov Test coming in the next version 2.1.0. +#### Logistic Regression Model + +[Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression) is a widely-used model when the response is categorical. It can be seen as a special case of the [Generalized Linear Predictive Model](https://en.wikipedia.org/wiki/Generalized_linear_model). +We provide `spark.logit` on top of `spark.glm` to support logistic regression with advanced hyper-parameters. +It supports both binary and multiclass classification with elastic-net regularization and feature standardization, similar to `glmnet`. + +We use a simple example to demonstrate `spark.logit` usage. In general, there are three steps of using `spark.logit`: +1). Create a dataframe from a proper data source; 2). Fit a logistic regression model using `spark.logit` with a proper parameter setting; +and 3). Obtain the coefficient matrix of the fitted model using `summary` and use the model for prediction with `predict`. 
+ +Binomial logistic regression +```{r, warning=FALSE} +df <- createDataFrame(iris) +# Create a DataFrame containing two classes +training <- df[df$Species %in% c("versicolor", "virginica"), ] +model <- spark.logit(training, Species ~ ., regParam = 0.00042) +summary(model) +``` + +Predict values on training data +```{r} +fitted <- predict(model, training) +``` + +Multinomial logistic regression against three classes +```{r, warning=FALSE} +df <- createDataFrame(iris) +# Note in this case, Spark infers it is multinomial logistic regression, so family = "multinomial" is optional. +model <- spark.logit(df, Species ~ ., regParam = 0.056) +summary(model) +``` + +#### Kolmogorov-Smirnov Test + +`spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). +Given a `SparkDataFrame`, the test compares continuous data in a given column `testCol` with the theoretical distribution +specified by parameter `nullHypothesis`. +Users can call `summary` to get a summary of the test results. + +In the following example, we test whether the `longley` dataset's `Armed_Forces` column +follows a normal distribution. We set the parameters of the normal distribution using +the mean and standard deviation of the sample. + +```{r, warning=FALSE} +df <- createDataFrame(longley) +afStats <- head(select(df, mean(df$Armed_Forces), sd(df$Armed_Forces))) +afMean <- afStats[1] +afStd <- afStats[2] + +test <- spark.kstest(df, "Armed_Forces", "norm", c(afMean, afStd)) +testSummary <- summary(test) +testSummary +``` + ### Model Persistence The following example shows how to save/load an ML model by SparkR. diff --git a/README.md b/README.md index dd7d0e22495b3..cb747225a11d4 100644 --- a/README.md +++ b/README.md @@ -1,104 +1,34 @@ -# Apache Spark +# Apache Spark On Kubernetes -Spark is a fast and general cluster computing system for Big Data. It provides -high-level APIs in Scala, Java, Python, and R, and an optimized engine that -supports general computation graphs for data analysis. It also supports a -rich set of higher-level tools including Spark SQL for SQL and DataFrames, -MLlib for machine learning, GraphX for graph processing, -and Spark Streaming for stream processing. +This repository, located at https://github.com/apache-spark-on-k8s/spark, contains a fork of Apache Spark that enables running Spark jobs natively on a Kubernetes cluster. - +## What is this? +This is a collaboratively maintained project working on [SPARK-18278](https://issues.apache.org/jira/browse/SPARK-18278). The goal is to bring native support for Spark to use Kubernetes as a cluster manager, in a fully supported way on par with the Spark Standalone, Mesos, and Apache YARN cluster managers. -## Online Documentation +## Getting Started -You can find the latest Spark documentation, including a programming -guide, on the [project web page](http://spark.apache.org/documentation.html) -and [project wiki](https://cwiki.apache.org/confluence/display/SPARK). -This README file only contains basic setup instructions. +- [Usage guide](https://apache-spark-on-k8s.github.io/userdocs/) shows how to run the code +- [Development docs](resource-managers/kubernetes/README.md) shows how to get set up for development +- Code is primarily located in the [resource-managers/kubernetes](resource-managers/kubernetes) folder -## Building Spark +## Why does this fork exist? -Spark is built using [Apache Maven](http://maven.apache.org/). 
-To build Spark and its example programs, run: +Adding native integration for a new cluster manager is a large undertaking. If poorly executed, it could introduce bugs into Spark when run on other cluster managers, cause release blockers slowing down the overall Spark project, or require hotfixes which divert attention away from development towards managing additional releases. Any work this deep inside Spark needs to be done carefully to minimize the risk of those negative externalities. - build/mvn -DskipTests clean package +At the same time, an increasing number of people from various companies and organizations desire to work together to natively run Spark on Kubernetes. The group needs a code repository, communication forum, issue tracking, and continuous integration, all in order to work together effectively on an open source product. -(You do not need to do this if you downloaded a pre-built package.) +We've been asked by an Apache Spark Committer to work outside of the Apache infrastructure for a short period of time to allow this feature to be hardened and improved without creating risk for Apache Spark. The aim is to rapidly bring it to the point where it can be brought into the mainline Apache Spark repository for continued development within the Apache umbrella. If all goes well, this should be a short-lived fork rather than a long-lived one. -You can build Spark using more than one thread by using the -T option with Maven, see ["Parallel builds in Maven 3"](https://cwiki.apache.org/confluence/display/MAVEN/Parallel+builds+in+Maven+3). -More detailed documentation is available from the project site, at -["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) -and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). +## Who are we? -## Interactive Scala Shell +This is a collaborative effort by several folks from different companies who are interested in seeing this feature be successful. Companies active in this project include (alphabetically): -The easiest way to start using Spark is through the Scala shell: - - ./bin/spark-shell - -Try the following command, which should return 1000: - - scala> sc.parallelize(1 to 1000).count() - -## Interactive Python Shell - -Alternatively, if you prefer Python, you can use the Python shell: - - ./bin/pyspark - -And run the following command, which should also return 1000: - - >>> sc.parallelize(range(1000)).count() - -## Example Programs - -Spark also comes with several sample programs in the `examples` directory. -To run one of them, use `./bin/run-example [params]`. For example: - - ./bin/run-example SparkPi - -will run the Pi example locally. - -You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn" to run on YARN, and "local" to run -locally with one thread, or "local[N]" to run locally with N threads. You -can also use an abbreviated class name if the class is in the `examples` -package. For instance: - - MASTER=spark://host:7077 ./bin/run-example SparkPi - -Many of the example programs print usage help if no params are given. - -## Running Tests - -Testing first requires [building Spark](#building-spark). 
Once Spark is built, tests -can be run using: - - ./dev/run-tests - -Please see the guidance on how to -[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). - -## A Note About Hadoop Versions - -Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported -storage systems. Because the protocols have changed in different versions of -Hadoop, you must build Spark against the same version that your cluster runs. - -Please refer to the build documentation at -["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) -for detailed guidance on building for a particular distribution of Hadoop, including -building for particular Hive and Hive Thriftserver distributions. - -## Configuration - -Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) -in the online documentation for an overview on how to configure Spark. - -## Contributing - -Please review the [Contribution to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -wiki for information on how to get started contributing to the project. +- Bloomberg +- Google +- Haiwen +- Hyperpilot +- Intel +- Palantir +- Pepperdata +- Red Hat \ No newline at end of file diff --git a/assembly/pom.xml b/assembly/pom.xml index ec243eaebaea7..a4f695e790ce3 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.0-k8s-0.2.0-SNAPSHOT ../pom.xml @@ -148,6 +148,16 @@ + + kubernetes + + + org.apache.spark + spark-kubernetes_${scala.binary.version} + ${project.version} + + + hive diff --git a/bin/beeline b/bin/beeline index 1627626941a73..058534699e44b 100755 --- a/bin/beeline +++ b/bin/beeline @@ -25,7 +25,7 @@ set -o posix # Figure out if SPARK_HOME is set if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi CLASS="org.apache.hive.beeline.BeeLine" diff --git a/bin/find-spark-home b/bin/find-spark-home new file mode 100755 index 0000000000000..fa78407d4175a --- /dev/null +++ b/bin/find-spark-home @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Attempts to find a proper value for SPARK_HOME. Should be included using "source" directive. + +FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" + +# Short cirtuit if the user already has this set. +if [ ! -z "${SPARK_HOME}" ]; then + exit 0 +elif [ ! 
-f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then + # If we are not in the same directory as find_spark_home.py we are not pip installed so we don't + # need to search the different Python directories for a Spark installation. + # Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or + # spark-submit in another directory we want to use that version of PySpark rather than the + # pip installed version of PySpark. + export SPARK_HOME="$(cd "$(dirname "$0")"/..; pwd)" +else + # We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + # Default to standard python interpreter unless told otherwise + if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"python"}" + fi + export SPARK_HOME=$($PYSPARK_DRIVER_PYTHON "$FIND_SPARK_HOME_PYTHON_SCRIPT") +fi diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index eaea964ed5b3d..8a2f709960a25 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -23,7 +23,7 @@ # Figure out where Spark is installed if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi if [ -z "$SPARK_ENV_LOADED" ]; then diff --git a/bin/pyspark b/bin/pyspark index d6b3ab0a44321..98387c2ec5b8a 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh @@ -46,7 +46,7 @@ WORKS_WITH_IPYTHON=$(python -c 'import sys; print(sys.version_info >= (2, 7, 0)) # Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! WORKS_WITH_IPYTHON ]]; then + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! $WORKS_WITH_IPYTHON ]]; then echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 exit 1 else @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m $1 + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/bin/run-example b/bin/run-example index dd0e3c4120260..4ba5399311d33 100755 --- a/bin/run-example +++ b/bin/run-example @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/run-example [options] example-class [example args]" diff --git a/bin/spark-class b/bin/spark-class index 377c8d1add3f6..77ea40cc37946 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi . "${SPARK_HOME}"/bin/load-spark-env.sh @@ -27,7 +27,7 @@ fi if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" else - if [ `command -v java` ]; then + if [ "$(command -v java)" ]; then RUNNER="java" else echo "JAVA_HOME is not set" >&2 @@ -36,7 +36,7 @@ else fi # Find Spark jars. 
-if [ -f "${SPARK_HOME}/RELEASE" ]; then +if [ -d "${SPARK_HOME}/jars" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" diff --git a/bin/spark-shell b/bin/spark-shell index 6583b5bd880ee..421f36cac3d47 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -21,7 +21,7 @@ # Shell script for starting the Spark Shell REPL cygwin=false -case "`uname`" in +case "$(uname)" in CYGWIN*) cygwin=true;; esac @@ -29,7 +29,7 @@ esac set -o posix if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" diff --git a/bin/spark-sql b/bin/spark-sql index 970d12cbf51dd..b08b944ebd319 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" diff --git a/bin/spark-submit b/bin/spark-submit index 023f9c162f4b8..4e9d3614e6370 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi # disable randomized hash for string in Python 3.3+ diff --git a/bin/sparkR b/bin/sparkR index 2c07a82e2173b..29ab10df8ab6d 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index fcefe64d59c91..58889a55cf651 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.0-k8s-0.2.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 511e1f29de368..2daacc14d42b5 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.0-k8s-0.2.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/kubernetes/KubernetesExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/kubernetes/KubernetesExternalShuffleClient.java new file mode 100644 index 0000000000000..49cb5243e32dc --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/kubernetes/KubernetesExternalShuffleClient.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.kubernetes; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.shuffle.ExternalShuffleClient; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; +import org.apache.spark.network.util.TransportConf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** + * A client for talking to the external shuffle service in Kubernetes cluster mode. + * + * This is used by each Spark executor to register with a corresponding external + * shuffle service on the cluster. The purpose is to clean up shuffle files + * reliably if the application exits unexpectedly. + */ +public class KubernetesExternalShuffleClient extends ExternalShuffleClient { + private static final Logger logger = LoggerFactory + .getLogger(KubernetesExternalShuffleClient.class); + + /** + * Creates a Kubernetes external shuffle client that wraps the {@link ExternalShuffleClient}. + * Please refer to docs on {@link ExternalShuffleClient} for more information. + */ + public KubernetesExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled, + boolean saslEncryptionEnabled) { + super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); + } + + public void registerDriverWithShuffleService(String host, int port) throws IOException { + checkInit(); + ByteBuffer registerDriver = new RegisterDriver(appId, 0).toByteBuffer(); + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(registerDriver, new RegisterDriverCallback()); + } + + private class RegisterDriverCallback implements RpcResponseCallback { + @Override + public void onSuccess(ByteBuffer response) { + logger.info("Successfully registered app " + appId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register app " + appId + " with external shuffle service. " + + "Please manually remove shuffle data after driver exit. 
Error: " + e); + } + } + + @Override + public void close() { + super.close(); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 42cedd9943150..e36cfd165db30 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -32,7 +32,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.shuffle.ExternalShuffleClient; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; import org.apache.spark.network.util.TransportConf; /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 9af6759f5d5f3..6012a84599368 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -23,7 +23,6 @@ import io.netty.buffer.Unpooled; import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java similarity index 91% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java index d5f53ccb7f741..ac606e6539f3e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle.protocol.mesos; +package org.apache.spark.network.shuffle.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; // Needed by ScalaDoc. See SPARK-7726 import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** - * A message sent from the driver to register with the MesosExternalShuffleService. + * A message sent from the driver to register with an ExternalShuffleService. 
*/ public class RegisterDriver extends BlockTransferMessage { private final String appId; diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 606ad15739617..e14b4748efca9 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.0-k8s-0.2.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 626f023a5b99c..24fd97315ef4e 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.0-k8s-0.2.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 1c60d510e5703..e07e51c34ec93 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.0-k8s-0.2.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 45af98d94ef91..0bf7005b32eeb 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.0-k8s-0.2.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 518ed6470a753..fd6e95c3e0a38 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -252,6 +252,10 @@ public static long parseSecondNano(String secondNano) throws IllegalArgumentExce public final int months; public final long microseconds; + public long milliseconds() { + return this.microseconds / MICROS_PER_MILLI; + } + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; diff --git a/conf/kubernetes-resource-staging-server.yaml b/conf/kubernetes-resource-staging-server.yaml new file mode 100644 index 0000000000000..025b9b125d9e0 --- /dev/null +++ b/conf/kubernetes-resource-staging-server.yaml @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
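+#
+# Usage sketch (an editor's assumption, not part of the original patch): the Deployment,
+# ConfigMap and NodePort Service defined below can typically be created with
+#
+#   kubectl create -f conf/kubernetes-resource-staging-server.yaml
+#
+# after which the staging server listens on the spark.kubernetes.resourceStagingServer.port
+# configured in the ConfigMap and is reachable on every node at the Service's nodePort.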
+# +--- +apiVersion: extensions/v1beta1 +kind: Deployment +metadata: + name: spark-resource-staging-server +spec: + replicas: 1 + template: + metadata: + labels: + resource-staging-server-instance: default + spec: + volumes: + - name: resource-staging-server-properties + configMap: + name: spark-resource-staging-server-config + containers: + - name: spark-resource-staging-server + image: kubespark/spark-resource-staging-server:v2.1.0-kubernetes-0.2.0 + resources: + requests: + cpu: 100m + memory: 256Mi + limits: + cpu: 100m + memory: 256Mi + volumeMounts: + - name: resource-staging-server-properties + mountPath: '/etc/spark-resource-staging-server' + args: + - '/etc/spark-resource-staging-server/resource-staging-server.properties' +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: spark-resource-staging-server-config +data: + resource-staging-server.properties: | + spark.kubernetes.resourceStagingServer.port=10000 + spark.ssl.kubernetes.resourceStagingServer.enabled=false +# Other possible properties are listed below, primarily for setting up TLS. The paths given by KeyStore, password, and PEM files here should correspond to +# files that are securely mounted into the resource staging server container, via e.g. secret volumes. +# spark.ssl.kubernetes.resourceStagingServer.keyStore=/mnt/secrets/resource-staging-server/keyStore.jks +# spark.ssl.kubernetes.resourceStagingServer.keyStorePassword=changeit +# spark.ssl.kubernetes.resourceStagingServer.keyPassword=changeit +# spark.ssl.kubernetes.resourceStagingServer.keyStorePasswordFile=/mnt/secrets/resource-staging-server/keystore-password.txt +# spark.ssl.kubernetes.resourceStagingServer.keyPasswordFile=/mnt/secrets/resource-staging-server/keystore-key-password.txt +# spark.ssl.kubernetes.resourceStagingServer.keyPem=/mnt/secrets/resource-staging-server/key.pem +# spark.ssl.kubernetes.resourceStagingServer.serverCertPem=/mnt/secrets/resource-staging-server/cert.pem +--- +apiVersion: v1 +kind: Service +metadata: + name: spark-resource-staging-service +spec: + type: NodePort + selector: + resource-staging-server-instance: default + ports: + - protocol: TCP + port: 10000 + targetPort: 10000 + nodePort: 31000 diff --git a/conf/kubernetes-shuffle-service.yaml b/conf/kubernetes-shuffle-service.yaml new file mode 100644 index 0000000000000..55c170b01a4f5 --- /dev/null +++ b/conf/kubernetes-shuffle-service.yaml @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
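+#
+# Usage sketch (an editor's assumption, not part of the original patch): the DaemonSet below runs
+# one external shuffle service pod per node and can typically be created with
+#
+#   kubectl create -f conf/kubernetes-shuffle-service.yaml
+#
+# Jobs that use it must point spark.kubernetes.shuffle.dir at the hostPath directory mounted
+# below, as described in the in-line comments on the volume mounts.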
+# + +apiVersion: extensions/v1beta1 +kind: DaemonSet +metadata: + labels: + app: spark-shuffle-service + spark-version: 2.1.0 + name: shuffle +spec: + template: + metadata: + labels: + app: spark-shuffle-service + spark-version: 2.1.0 + spec: + volumes: + - name: temp-volume + hostPath: + path: '/var/tmp' # change this path according to your cluster configuration. + containers: + - name: shuffle + # This is an official image that is built + # from the dockerfiles/shuffle directory + # in the spark distribution. + image: kubespark/spark-shuffle:v2.1.0-kubernetes-0.2.0 + imagePullPolicy: IfNotPresent + volumeMounts: + - mountPath: '/tmp' + name: temp-volume + # more volumes can be mounted here. + # The spark job must be configured to use these + # mounts using the configuration: + # spark.kubernetes.shuffle.dir=,,... + resources: + requests: + cpu: "1" + limits: + cpu: "1" \ No newline at end of file diff --git a/core/pom.xml b/core/pom.xml index eac99ab82a2e4..9cac063dc62e7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.0-k8s-0.2.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263d..ea5f1a9abf69b 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,8 +130,10 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } + //checkstyle.on: NoFinalizer } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 1a700aa37554e..c40974b54cb47 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -378,14 +378,14 @@ public long cleanUpAllAllocatedMemory() { for (MemoryConsumer c: consumers) { if (c != null && c.getUsed() > 0) { // In case of failed task, it's normal to see leaked memory - logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c); } } consumers.clear(); for (MemoryBlock page : pageTable) { if (page != null) { - logger.warn("leak a page: " + page + " in task " + taskAttemptId); + logger.debug("unreleased page: " + page + " in task " + taskAttemptId); memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index f235c434be7b1..8a1771848dee6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -40,6 +40,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; +import org.apache.commons.io.output.CloseShieldOutputStream; +import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -264,6 +266,7 @@ private long[] 
mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file @@ -289,7 +292,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. - if (transferToEnabled) { + if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); } else { @@ -320,9 +323,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti /** * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in - * cases where the IO compression codec does not support concatenation of compressed data, or in - * cases where users have explicitly disabled use of {@code transferTo} in order to work around - * kernel bugs. + * cases where the IO compression codec does not support concatenation of compressed data, when + * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in + * order to work around kernel bugs. * * @param spills the spills to merge. * @param outputFile the file to write the merged data to. @@ -337,7 +340,11 @@ private long[] mergeSpillsWithFileStream( final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new FileInputStream[spills.length]; - OutputStream mergedFileOutputStream = null; + + // Use a counting output stream to avoid having to close the underlying file and ask + // the file system for its size after each partition is written. + final CountingOutputStream mergedFileOutputStream = new CountingOutputStream( + new FileOutputStream(outputFile)); boolean threwException = true; try { @@ -345,34 +352,35 @@ private long[] mergeSpillsWithFileStream( spillInputStreams[i] = new FileInputStream(spills[i].file); } for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = outputFile.length(); - mergedFileOutputStream = - new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + final long initialFileLength = mergedFileOutputStream.getByteCount(); + // Shield the underlying output stream from close() calls, so that we can close the higher + // level streams to make sure all data is really flushed and internal state is cleaned. 
+ OutputStream partitionOutput = new CloseShieldOutputStream( + new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { - mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); } - for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = null; - boolean innerThrewException = true; + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); try { - partitionInputStream = - new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); if (compressionCodec != null) { partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } - ByteStreams.copy(partitionInputStream, mergedFileOutputStream); - innerThrewException = false; + ByteStreams.copy(partitionInputStream, partitionOutput); } finally { - Closeables.close(partitionInputStream, innerThrewException); + partitionInputStream.close(); } } } - mergedFileOutputStream.flush(); - mergedFileOutputStream.close(); - partitionLengths[partition] = (outputFile.length() - initialFileLength); + partitionOutput.flush(); + partitionOutput.close(); + partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); } threwException = false; } finally { diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d2fcdea4f2cee..44120e591f2fb 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -170,6 +170,8 @@ public final class BytesToBytesMap extends MemoryConsumer { private long peakMemoryUsedBytes = 0L; + private final int initialCapacity; + private final BlockManager blockManager; private final SerializerManager serializerManager; private volatile MapIterator destructiveIterator = null; @@ -202,6 +204,7 @@ public BytesToBytesMap( throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } + this.initialCapacity = initialCapacity; allocate(initialCapacity); } @@ -902,12 +905,12 @@ public LongArray getArray() { public void reset() { numKeys = 0; numValues = 0; - longArray.zeroOut(); - + freeArray(longArray); while (dataPages.size() > 0) { MemoryBlock dataPage = dataPages.removeLast(); freePage(dataPage); } + allocate(initialCapacity); currentPage = null; pageCursor = 0; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 404361734a55b..3dd318471008b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -40,14 +42,14 @@ public class RadixSort { * 
of always copying the data back to position zero for efficiency. */ public static int sort( - LongArray array, int numRecords, int startByteIndex, int endByteIndex, + LongArray array, long numRecords, int startByteIndex, int endByteIndex, boolean desc, boolean signed) { assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 2 <= array.size(); - int inIndex = 0; - int outIndex = numRecords; + long inIndex = 0; + long outIndex = numRecords; if (numRecords > 0) { long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { @@ -55,13 +57,13 @@ public static int sort( sortAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -78,14 +80,14 @@ public static int sort( * @param signed whether this is a signed (two's complement) sort (only applies to last byte). */ private static void sortAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8; - long maxOffset = baseOffset + numRecords * 8; + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; for (long offset = baseOffset; offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); @@ -106,13 +108,13 @@ private static void sortAtByte( * significant byte. If the byte does not need sorting the array will be null. */ private static long[][] getCounts( - LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. // If all the byte values at a particular index are the same we don't need to count it. long bitwiseMax = 0; long bitwiseMin = -1L; - long maxOffset = array.getBaseOffset() + numRecords * 8; + long maxOffset = array.getBaseOffset() + numRecords * 8L; Object baseObject = array.getBaseObject(); for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); @@ -146,18 +148,18 @@ private static long[][] getCounts( * @return the input counts array. */ private static long[] transformCountsToOffsets( - long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, boolean desc, boolean signed) { assert counts.length == 256; int start = signed ? 128 : 0; // output the negative records first (values 129-255). 
if (desc) { - int pos = numRecords; + long pos = numRecords; for (int i = start; i < start + 256; i++) { pos -= counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; } } else { - int pos = 0; + long pos = 0; for (int i = start; i < start + 256; i++) { long tmp = counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; @@ -176,8 +178,8 @@ private static long[] transformCountsToOffsets( */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, - int numRecords, + long startIndex, + long numRecords, int startByteIndex, int endByteIndex, boolean desc, @@ -186,8 +188,8 @@ public static int sortKeyPrefixArray( assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + long inIndex = startIndex; + long outIndex = startIndex + numRecords * 2L; if (numRecords > 0) { long[][] counts = getKeyPrefixArrayCounts( array, startIndex, numRecords, startByteIndex, endByteIndex); @@ -196,13 +198,13 @@ public static int sortKeyPrefixArray( sortKeyPrefixArrayAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -210,7 +212,7 @@ public static int sortKeyPrefixArray( * getCounts with some added parameters but that seems to hurt in benchmarks. */ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; @@ -238,11 +240,11 @@ private static long[][] getKeyPrefixArrayCounts( * Specialization of sortAtByte() for key-prefix arrays. 
*/ private static void sortKeyPrefixArrayAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed); Object baseObject = array.getBaseObject(); long baseOffset = array.getBaseOffset() + inIndex * 8L; long maxOffset = baseOffset + numRecords * 16L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 2a71e68adafad..252a35ec6bdf5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -322,7 +322,7 @@ public UnsafeSorterIterator getSortedIterator() { if (sortComparator != null) { if (this.radixSortSupport != null) { offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 1df67337ea031..fa0282678d1f4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -54,7 +54,28 @@ $(document).ajaxStart(function () { $.blockUI({message: '

Loading Executors Page...
'}); }); +function findKubernetesServiceBaseURI() { + var k8sProxyPattern = '/api/v1/proxy/namespaces/'; + var k8sProxyPatternPos = document.baseURI.indexOf(k8sProxyPattern); + if (k8sProxyPatternPos > 0) { + // Spark is running in a kubernetes cluster, and the web ui is served + // through the kubectl proxy. + var remaining = document.baseURI.substr(k8sProxyPatternPos + k8sProxyPattern.length); + var urlSlashesCount = remaining.split('/').length - 3; + var words = document.baseURI.split('/'); + var baseURI = words.slice(0, words.length - urlSlashesCount).join('/'); + return baseURI; + } + + return null; +} + function createTemplateURI(appId) { + var kubernetesBaseURI = findKubernetesServiceBaseURI(); + if (kubernetesBaseURI) { + return kubernetesBaseURI + '/static/executorspage-template.html'; + } + var words = document.baseURI.split('/'); var ind = words.indexOf("proxy"); if (ind > 0) { @@ -70,6 +91,14 @@ function createTemplateURI(appId) { } function getStandAloneppId(cb) { + var kubernetesBaseURI = findKubernetesServiceBaseURI(); + if (kubernetesBaseURI) { + var appIdAndPort = kubernetesBaseURI.split('/').slice(-1)[0]; + var appId = appIdAndPort.split(':')[0]; + cb(appId); + return; + } + var words = document.baseURI.split('/'); var ind = words.indexOf("proxy"); if (ind > 0) { @@ -95,6 +124,11 @@ function getStandAloneppId(cb) { } function createRESTEndPoint(appId) { + var kubernetesBaseURI = findKubernetesServiceBaseURI(); + if (kubernetesBaseURI) { + return kubernetesBaseURI + "/api/v1/applications/" + appId + "/allexecutors"; + } + var words = document.baseURI.split('/'); var ind = words.indexOf("proxy"); if (ind > 0) { @@ -411,10 +445,6 @@ $(document).ready(function () { } ], "columnDefs": [ - { - "targets": [ 15 ], - "visible": logsExist(response) - }, { "targets": [ 16 ], "visible": getThreadDumpEnabled() @@ -423,7 +453,8 @@ $(document).ready(function () { "order": [[0, "asc"]] }; - $(selector).DataTable(conf); + var dt = $(selector).DataTable(conf); + dt.column(15).visible(logsExist(response)); $('#active-executors [data-toggle="tooltip"]').tooltip(); var sumSelector = "#summary-execs-table"; diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js new file mode 100644 index 0000000000000..55d540d8317a0 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +$(document).ready(function() { + if ($('#last-updated').length) { + var lastUpdatedMillis = Number($('#last-updated').text()); + var updatedDate = new Date(lastUpdatedMillis); + $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString()) + } +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 1fd6ef4a71253..42e2d9abdeb5e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -68,16 +68,16 @@ {{#applications}} - {{id}} + {{id}} {{name}} {{#attempts}} - {{attemptId}} + {{attemptId}} {{startTime}} {{endTime}} {{duration}} {{sparkUser}} {{lastUpdated}} - Download + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2a32e18672a22..8fd91865b0429 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -119,7 +119,11 @@ $(document).ready(function() { } } - var data = {"applications": array} + var data = { + "uiroot": uiRoot, + "applications": array + } + $.get("static/historypage-template.html", function(template) { historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); var selector = "#history-summary-table"; @@ -135,6 +139,9 @@ $(document).ready(function() { {name: 'eighth'}, {name: 'ninth'}, ], + "columnDefs": [ + {"searchable": false, "targets": [5]} + ], "autoWidth": false, "order": [[ 4, "desc" ]] }; diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 14b06bfe860ed..0315ebf5c48a9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -36,7 +36,7 @@ function toggleThreadStackTrace(threadId, forceAdd) { if (stackTrace.length == 0) { var stackTraceText = $('#' + threadId + "_td_stacktrace").html() var threadCell = $("#thread_" + threadId + "_tr") - threadCell.after("
" +
+        threadCell.after("
" +
             stackTraceText +  "
") } else { if (!forceAdd) { @@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) { $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover"); } function onSearchStringChange() { diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e37307aa1f705..0fa1fcf25f8b9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -15,6 +15,12 @@ * limitations under the License. */ +var uiRoot = ""; + +function setUIRoot(val) { + uiRoot = val; +} + function collapseTablePageLoad(name, table){ if (window.localStorage.getItem(name) == "true") { // Set it to false so that the click function can revert it diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala index 9d1f1d59dbce1..7bea636c94aa0 100644 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -26,7 +26,7 @@ package org.apache.spark * * An accumulator is created from an initial value `v` by calling * [[SparkContext#accumulator SparkContext.accumulator]]. - * Tasks running on the cluster can then add to it using the [[Accumulable#+= +=]] operator. + * Tasks running on the cluster can then add to it using the `+=` operator. * However, they cannot read its value. Only the driver program can read the accumulator's value, * using its [[#value]] method. * diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5678d790e9e76..af913454fce69 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -139,7 +139,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { periodicGCService.shutdown() } - /** Register a RDD for cleanup when it is garbage collected. */ + /** Register an RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 7f8f0f513134f..6f5c31d7ab71c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -322,7 +322,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, if (minSizeForBroadcast > maxRpcMessageSize) { val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " + - "message that is to large." + "message that is too large." 
logError(msg) throw new IllegalArgumentException(msg) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 93dfbc0e6ed65..f83f5278e8b8f 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -101,7 +101,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. * - * Note that the actual number of partitions created by the RangePartitioner might not be the same + * @note The actual number of partitions created by the RangePartitioner might not be the same * as the `partitions` parameter, in the case where the number of sampled records is less than * the value of `partitions`. */ diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index be19179b00a49..5f14102c3c366 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -150,8 +150,8 @@ private[spark] object SSLOptions extends Logging { * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers * * For a list of protocols and ciphers supported by particular Java versions, you may go to - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. + * + * Oracle blog page. * * You can optionally specify the default configuration. If you do, for each setting which is * missing in SparkConf, the corresponding setting is used from the default configuration. diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 199365ad925a3..87fe56315203e 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -21,7 +21,6 @@ import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate -import javax.crypto.KeyGenerator import javax.net.ssl._ import com.google.common.hash.HashCodes @@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.sasl.SecretKeyHolder -import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.util.Utils /** @@ -185,7 +183,9 @@ import org.apache.spark.util.Utils * setting `spark.ssl.useNodeLocalConf` to `true`. */ -private[spark] class SecurityManager(sparkConf: SparkConf) +private[spark] class SecurityManager( + sparkConf: SparkConf, + ioEncryptionKey: Option[Array[Byte]] = None) extends Logging with SecretKeyHolder { import SecurityManager._ @@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing acls enabled to: " + aclsOn) } + def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey + /** * Generates or looks up the secret key. * @@ -559,19 +561,4 @@ private[spark] object SecurityManager { // key used to store the spark secret in the Hadoop UGI val SECRET_LOOKUP_KEY = "sparkCookie" - /** - * Setup the cryptographic key used by IO encryption in credentials. The key is generated using - * [[KeyGenerator]]. 
The algorithm and key length is specified by the [[SparkConf]]. - */ - def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { - if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { - val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) - val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) - val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) - keyGen.init(keyLen) - - val ioKey = keyGen.generateKey() - credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded) - } - } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c9c342df82c97..d78b9f1b29685 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -42,10 +42,10 @@ import org.apache.spark.util.Utils * All setter methods in this class support chaining. For example, you can write * `new SparkConf().setMaster("local").setAppName("My app")`. * - * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified - * by the user. Spark does not support modifying the configuration at runtime. - * * @param loadDefaults whether to also load values from Java system properties + * + * @note Once a SparkConf object is passed to Spark, it is cloned and can no longer be modified + * by the user. Spark does not support modifying the configuration at runtime. */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable { @@ -262,7 +262,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getTimeAsSeconds(key: String): Long = { Utils.timeStringAsSeconds(get(key)) @@ -279,7 +279,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then milliseconds are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getTimeAsMs(key: String): Long = { Utils.timeStringAsMs(get(key)) @@ -296,7 +296,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then bytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getSizeAsBytes(key: String): Long = { Utils.byteStringAsBytes(get(key)) @@ -320,7 +320,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getSizeAsKb(key: String): Long = { Utils.byteStringAsKb(get(key)) @@ -337,7 +337,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. 
- * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getSizeAsMb(key: String): Long = { Utils.byteStringAsMb(get(key)) @@ -354,7 +354,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getSizeAsGb(key: String): Long = { Utils.byteStringAsGb(get(key)) @@ -378,7 +378,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray } - /** Get all parameters that start with `prefix` */ + /** + * Get all parameters that start with `prefix` + */ def getAllWithPrefix(prefix: String): Array[(String, String)] = { getAll.filter { case (k, v) => k.startsWith(prefix) } .map { case (k, v) => (k.substring(prefix.length), v) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4694790c72cd8..b6aeeb9559ec8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io._ import java.lang.reflect.Constructor -import java.net.{MalformedURLException, URI} +import java.net.{URI} import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} @@ -183,6 +183,8 @@ class SparkContext(config: SparkConf) extends Logging { // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") + warnDeprecatedVersions() + /* ------------------------------------------------------------------------------------- * | Private variables. These variables keep the internal state of the context, and are | | not accessible by the outside world. They're mutable since we want to initialize all | @@ -279,7 +281,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration: Configuration = _hadoopConfiguration @@ -346,6 +348,16 @@ class SparkContext(config: SparkConf) extends Logging { value } + private def warnDeprecatedVersions(): Unit = { + val javaVersion = System.getProperty("java.version").split("[+.\\-]+", 3) + if (javaVersion.length >= 2 && javaVersion(1).toInt == 7) { + logWarning("Support for Java 7 is deprecated as of Spark 2.0.0") + } + if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.10"))) { + logWarning("Support for Scala 2.10 is deprecated as of Spark 2.1.0") + } + } + /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. 
* Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN @@ -410,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging { } if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") - if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { - throw new SparkException("IO encryption is only supported in YARN mode, please disable it " + - s"by setting ${IO_ENCRYPTION_ENABLED.key} to false") - } // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. @@ -633,7 +641,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get a local property set in this thread, or null if it is missing. See - * [[org.apache.spark.SparkContext.setLocalProperty]]. + * `org.apache.spark.SparkContext.setLocalProperty`. */ def getLocalProperty(key: String): String = Option(localProperties.get).map(_.getProperty(key)).orNull @@ -651,7 +659,7 @@ class SparkContext(config: SparkConf) extends Logging { * Application programmers can use this method to group all those jobs together and give a * group description. Once set, the Spark web UI will associate such jobs with this group. * - * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all + * The application can also use `org.apache.spark.SparkContext.cancelJobGroup` to cancel all * running jobs in this group. For example, * {{{ * // In the main thread: @@ -688,7 +696,7 @@ class SparkContext(config: SparkConf) extends Logging { * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. * - * Note: Return statements are NOT allowed in the given body. + * @note Return statements are NOT allowed in the given body. */ private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) @@ -915,7 +923,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Load data from a flat binary file, assuming the length of each record is constant. * - * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * @note We ensure that the byte array for each record in the resulting RDD * has the provided record length. * * @param path Directory to the input data files, the path can be comma separated paths as the @@ -958,7 +966,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. 
* If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -983,7 +991,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1022,7 +1030,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minPartitions) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1046,7 +1054,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1072,7 +1080,7 @@ class SparkContext(config: SparkConf) extends Logging { * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1112,7 +1120,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1138,7 +1146,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. 
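Editorial aside, not part of the patch: the Writable-reuse warning that these hunks convert into @note tags is easiest to see with sequenceFile. A minimal sketch, assuming a live SparkContext sc and a hypothetical path string:

```scala
import org.apache.hadoop.io.{LongWritable, Text}

// Hadoop reuses the same Writable instances for every record, so copy the data out
// (k.get(), v.toString) before caching, sorting, or aggregating the RDD.
val safeToCache = sc.sequenceFile(path, classOf[LongWritable], classOf[Text])
  .map { case (k, v) => (k.get(), v.toString) }
  .cache()
```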
* - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1157,7 +1165,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1187,7 +1195,7 @@ class SparkContext(config: SparkConf) extends Logging { * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1318,16 +1326,18 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Register the given accumulator. Note that accumulators must be registered before use, or it - * will throw exception. + * Register the given accumulator. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _]): Unit = { acc.register(this) } /** - * Register the given accumulator with given name. Note that accumulators must be registered - * before use, or it will throw exception. + * Register the given accumulator with given name. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { acc.register(this, name = Some(name)) @@ -1370,7 +1380,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Create and register a [[CollectionAccumulator]], which starts with empty list and accumulates + * Create and register a `CollectionAccumulator`, which starts with empty list and accumulates * inputs by adding them into the list. */ def collectionAccumulator[T]: CollectionAccumulator[T] = { @@ -1380,7 +1390,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Create and register a [[CollectionAccumulator]], which starts with empty list and accumulates + * Create and register a `CollectionAccumulator`, which starts with empty list and accumulates * inputs by adding them into the list. */ def collectionAccumulator[T](name: String): CollectionAccumulator[T] = { @@ -1538,7 +1548,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. 
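Editorial aside, not part of the patch: the @note added to register(...) a few hunks up ("Accumulators must be registered before use") is simple to see in user code. A minimal sketch, assuming a live SparkContext sc:

```scala
import org.apache.spark.util.LongAccumulator

val acc = new LongAccumulator
sc.register(acc, "failedRecords")  // registration must happen before the accumulator is used in a job

sc.parallelize(1 to 100).foreach { i => if (i % 10 == 0) acc.add(1) }
println(acc.value)                 // only the driver can read the value
```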
* - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1560,7 +1570,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executor. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executor it kills * through this method with a new one, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1578,7 +1588,7 @@ class SparkContext(config: SparkConf) extends Logging { * this request. This assumes the cluster manager will automatically and eventually * fulfill all missing application resource requests. * - * Note: The replace is by no means guaranteed; another application on the same cluster + * @note The replace is by no means guaranteed; another application on the same cluster * can steal the window of opportunity and acquire this application's resources in the * mean time. * @@ -1627,7 +1637,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap @@ -1716,29 +1727,12 @@ class SparkContext(config: SparkConf) extends Logging { key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (master == "yarn" && deployMode == "cluster") { - // In order for this to work in yarn cluster mode the user must specify the - // --addJars option to the client to upload the file into the distributed cache - // of the AM to make it show up in the current working directory. - val fileName = new Path(uri.getPath).getName() - try { - env.rpcEnv.fileServer.addJar(new File(fileName)) - } catch { - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. - logError("Error adding jar (" + e + "), was the --addJars option used?") - null - } - } else { - try { - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") - null - } + try { + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null } // A JAR file which exists locally on every worker node case "local" => @@ -1762,8 +1756,31 @@ class SparkContext(config: SparkConf) extends Logging { */ def listJars(): Seq[String] = addedJars.keySet.toSeq - // Shut down the SparkContext. - def stop() { + /** + * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark + * may wait for some internal threads to finish. It's better to use this method to stop + * SparkContext instead. 
+ */ + private[spark] def stopInNewThread(): Unit = { + new Thread("stop-spark-context") { + setDaemon(true) + + override def run(): Unit = { + try { + SparkContext.this.stop() + } catch { + case e: Throwable => + logError(e.getMessage, e) + throw e + } + } + }.start() + } + + /** + * Shut down the SparkContext. + */ + def stop(): Unit = { if (LiveListenerBus.withinListenerThread.value) { throw new SparkException( s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") @@ -2027,7 +2044,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] + * Cancel active jobs for the specified group. See `org.apache.spark.SparkContext.setJobGroup` * for more information. */ def cancelJobGroup(groupId: String) { @@ -2045,7 +2062,7 @@ class SparkContext(config: SparkConf) extends Logging { * Cancel a given job if it's scheduled or running. * * @param jobId the job ID to cancel - * @throws InterruptedException if the cancel message cannot be sent + * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int) { dagScheduler.cancelJob(jobId) @@ -2055,7 +2072,7 @@ class SparkContext(config: SparkConf) extends Logging { * Cancel a given stage and all jobs associated with it. * * @param stageId the stage ID to cancel - * @throws InterruptedException if the cancel message cannot be sent + * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelStage(stageId: Int) { dagScheduler.cancelStage(stageId) @@ -2285,7 +2302,7 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -2310,7 +2327,7 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. 
*/ def getOrCreate(): SparkContext = { @@ -2550,8 +2567,8 @@ object SparkContext extends Logging { val serviceLoaders = ServiceLoader.load(classOf[ExternalClusterManager], loader).asScala.filter(_.canCreate(url)) if (serviceLoaders.size > 1) { - throw new SparkException(s"Multiple Cluster Managers ($serviceLoaders) registered " + - s"for the url $url:") + throw new SparkException( + s"Multiple external cluster managers registered for the url $url: $serviceLoaders") } serviceLoaders.headOption } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 1ffeb129880f9..1296386ac9bd3 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ @@ -165,15 +166,20 @@ object SparkEnv extends Logging { val bindAddress = conf.get(DRIVER_BIND_ADDRESS) val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) val port = conf.get("spark.driver.port").toInt + val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(conf)) + } else { + None + } create( conf, SparkContext.DRIVER_IDENTIFIER, bindAddress, advertiseAddress, port, - isDriver = true, - isLocal = isLocal, - numUsableCores = numCores, + isLocal, + numCores, + ioEncryptionKey, listenerBus = listenerBus, mockOutputCommitCoordinator = mockOutputCommitCoordinator ) @@ -189,6 +195,7 @@ object SparkEnv extends Logging { hostname: String, port: Int, numCores: Int, + ioEncryptionKey: Option[Array[Byte]], isLocal: Boolean): SparkEnv = { val env = create( conf, @@ -196,9 +203,9 @@ object SparkEnv extends Logging { hostname, hostname, port, - isDriver = false, - isLocal = isLocal, - numUsableCores = numCores + isLocal, + numCores, + ioEncryptionKey ) SparkEnv.set(env) env @@ -213,18 +220,26 @@ object SparkEnv extends Logging { bindAddress: String, advertiseAddress: String, port: Int, - isDriver: Boolean, isLocal: Boolean, numUsableCores: Int, + ioEncryptionKey: Option[Array[Byte]], listenerBus: LiveListenerBus = null, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { + val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER + // Listener bus is only used on the driver if (isDriver) { assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") } - val securityManager = new SecurityManager(conf) + val securityManager = new SecurityManager(conf, ioEncryptionKey) + ioEncryptionKey.foreach { _ => + if (!securityManager.isSaslEncryptionEnabled()) { + logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + + "wire.") + } + } val systemName = if (isDriver) driverSystemName else executorSystemName val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, @@ -270,7 +285,7 @@ object SparkEnv extends Logging { "spark.serializer", "org.apache.spark.serializer.JavaSerializer") logDebug(s"Using serializer: ${serializer.getClass}") - val serializerManager = new 
SerializerManager(serializer, conf) + val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey) val closureSerializer = new JavaSerializer(conf) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6550d703bc860..7f75a393bf8ff 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.IOException import java.text.NumberFormat import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path @@ -67,12 +67,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { def setup(jobid: Int, splitid: Int, attemptid: Int) { setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now), + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), jobid, splitID, attemptID, conf.value) } def open() { - val numfmt = NumberFormat.getInstance() + val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) @@ -162,7 +162,7 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { private[spark] object SparkHadoopWriter { def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(time) new JobID(jobtrackerID, id) } diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 52c4656c271bc..22a553e68439a 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -112,7 +112,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getExecutorInfos: Array[SparkExecutorInfo] = { val executorIdToRunningTasks: Map[String, Int] = - sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors() + sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors sc.getExecutorStorageStatus.map { status => val bmId = status.blockManagerId diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 27abccf5ac2a9..0fd777ed12829 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -164,7 +164,7 @@ abstract class TaskContext extends Serializable { /** * Get a local property set upstream in the driver, or null if it is missing. See also - * [[org.apache.spark.SparkContext.setLocalProperty]]. + * `org.apache.spark.SparkContext.setLocalProperty`. */ def getLocalProperty(key: String): String @@ -174,7 +174,7 @@ abstract class TaskContext extends Serializable { /** * ::DeveloperApi:: * Returns all metrics sources with the given name which are associated with the instance - * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]]. + * which runs the task. For more information see `org.apache.spark.metrics.MetricsSystem`. 
*/ @DeveloperApi def getMetricsSources(sourceName: String): Seq[Source] diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7ca3c103dbf5b..7745387dbceba 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -65,7 +65,7 @@ sealed trait TaskFailedReason extends TaskEndReason { /** * :: DeveloperApi :: - * A [[org.apache.spark.scheduler.ShuffleMapTask]] that completed successfully earlier, but we + * A `org.apache.spark.scheduler.ShuffleMapTask` that completed successfully earlier, but we * lost the executor before the stage completed. This means Spark needs to reschedule the task * to be re-executed on a different executor. */ diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 871b9d1ad575b..2909191bd6f14 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -186,7 +186,7 @@ private[spark] object TestUtils { /** - * A [[SparkListener]] that detects whether spills have occurred in Spark jobs. + * A `SparkListener` that detects whether spills have occurred in Spark jobs. */ private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 0026fc9dad517..b71af0d42cdb0 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -45,7 +45,9 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) import JavaDoubleRDD.fromRDD - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) /** @@ -153,7 +155,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) @@ -256,7 +258,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 1c95bc4bfcaaf..766aea213a972 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -54,7 +54,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) // Common RDD functions - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). 
*/ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) /** @@ -206,7 +208,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.intersection(other.rdd)) @@ -223,9 +225,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -234,6 +236,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple * items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -255,9 +260,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -265,6 +270,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -398,8 +406,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. 
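Editorial aside, not part of the patch: the groupByKey @note rewritten in these hunks is worth one concrete contrast. A minimal sketch of the two per-key sums it compares, assuming a live SparkContext sc:

```scala
val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)))

// Preferred: partial sums are combined map-side, so one value per key per partition is shuffled.
val viaReduce = pairs.reduceByKey(_ + _)

// Same result, but every individual value crosses the network before being summed.
val viaGroup = pairs.groupByKey().mapValues(_.sum)
```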
*/ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JIterable[V]] = @@ -409,8 +417,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(numPartitions: Int): JavaPairRDD[K, JIterable[V]] = @@ -448,13 +456,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(rdd.subtractByKey(other)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, V] = { implicit val ctag: ClassTag[W] = fakeClassTag fromRDD(rdd.subtractByKey(other, numPartitions)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W](other: JavaPairRDD[K, W], p: Partitioner): JavaPairRDD[K, V] = { implicit val ctag: ClassTag[W] = fakeClassTag fromRDD(rdd.subtractByKey(other, p)) @@ -539,8 +551,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(): JavaPairRDD[K, JIterable[V]] = diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 20d6c9341bf7a..41b5cab601c36 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -34,7 +34,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) // Common RDD functions - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): JavaRDD[T] = wrapRDD(rdd.cache()) /** @@ -98,24 +100,32 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD with a random seed. 
* * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD, with a user-supplied seed. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -153,7 +163,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) @@ -161,7 +171,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be less than or equal to us. */ def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index a37c52cbaf210..eda16d957cc58 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -47,7 +47,8 @@ private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This /** * Defines operations common to several Java RDD implementations. - * Note that this trait is not intended to be implemented by user code. + * + * @note This trait is not intended to be implemented by user code. 
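Editorial aside, not part of the patch: the reworded sample() docs above stress that fraction is an expected proportion, not a guaranteed count. A minimal sketch, assuming a live SparkContext sc:

```scala
val rdd = sc.parallelize(1 to 1000)

// Roughly 100 elements, but the exact size varies by run and by seed.
val approx = rdd.sample(withReplacement = false, fraction = 0.1, seed = 42L)

// When an exact sample size is required, takeSample collects that many elements to the driver.
val exact: Array[Int] = rdd.takeSample(withReplacement = false, num = 100, seed = 42L)
```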
*/ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4e50c2686dd53..9481156bc93a5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -238,7 +238,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}} * * then `rdd` contains * {{{ @@ -270,7 +272,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}}, * * then `rdd` contains * {{{ @@ -298,7 +302,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -316,7 +320,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -366,7 +370,7 @@ class JavaSparkContext(val sc: SparkContext) * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -396,7 +400,7 @@ class JavaSparkContext(val sc: SparkContext) * @param keyClass Class of the keys * @param valueClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -416,7 +420,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. 
* If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -437,7 +441,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -458,7 +462,7 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -487,7 +491,7 @@ class JavaSparkContext(val sc: SparkContext) * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -694,7 +698,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { @@ -749,7 +753,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get a local property set in this thread, or null if it is missing. See - * [[org.apache.spark.api.java.JavaSparkContext.setLocalProperty]]. + * `org.apache.spark.api.java.JavaSparkContext.setLocalProperty`. */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) @@ -769,7 +773,7 @@ class JavaSparkContext(val sc: SparkContext) * Application programmers can use this method to group all those jobs together and give a * group description. Once set, the Spark web UI will associate such jobs with this group. * - * The application can also use [[org.apache.spark.api.java.JavaSparkContext.cancelJobGroup]] + * The application can also use `org.apache.spark.api.java.JavaSparkContext.cancelJobGroup` * to cancel all running jobs in this group. For example, * {{{ * // In the main thread: @@ -802,7 +806,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Cancel active jobs for the specified group. See - * [[org.apache.spark.api.java.JavaSparkContext.setJobGroup]] for more information. + * `org.apache.spark.api.java.JavaSparkContext.setJobGroup` for more information. 
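The repeated `@note` about Hadoop's reused `Writable` objects boils down to one pattern: copy the values out before caching. A sketch under the assumption of an existing `SparkContext` named `sc` and a hypothetical HDFS path:

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat

// Hadoop's RecordReader reuses one Writable instance per partition, so caching the
// raw (LongWritable, Text) pairs would cache many references to the same mutated object.
val raw = sc.hadoopFile[LongWritable, Text, TextInputFormat]("hdfs://namenode/logs", 4)

// Copy into immutable JVM values first; then it is safe to cache.
val safeToCache = raw.map { case (offset, line) => (offset.get(), line.toString) }.cache()
println(safeToCache.count())
```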
*/ def cancelJobGroup(groupId: String): Unit = sc.cancelJobGroup(groupId) @@ -811,7 +815,8 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns a Java map of JavaRDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { sc.getPersistentRDDs.mapValues(s => JavaRDD.fromRDD(s)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 99ca3c77cced0..6aa290ecd7bb5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} * will provide information for the last `spark.ui.retainedStages` stages and * `spark.ui.retainedJobs` jobs. * - * NOTE: this class's constructor should be considered private and may be subject to change. + * @note This class's constructor should be considered private and may be subject to change. */ class JavaSparkStatusTracker private[spark] (sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala new file mode 100644 index 0000000000000..3432700f11602 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ConcurrentHashMap + +/** JVM object ID wrapper */ +private[r] case class JVMObjectId(id: String) { + require(id != null, "Object ID cannot be null.") +} + +/** + * Counter that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] class JVMObjectTracker { + + private[this] val objMap = new ConcurrentHashMap[JVMObjectId, Object]() + private[this] val objCounter = new AtomicInteger() + + /** + * Returns the JVM object associated with the input key or None if not found. + */ + final def get(id: JVMObjectId): Option[Object] = this.synchronized { + if (objMap.containsKey(id)) { + Some(objMap.get(id)) + } else { + None + } + } + + /** + * Returns the JVM object associated with the input key or throws an exception if not found. 
+ */ + @throws[NoSuchElementException]("if key does not exist.") + final def apply(id: JVMObjectId): Object = { + get(id).getOrElse( + throw new NoSuchElementException(s"$id does not exist.") + ) + } + + /** + * Adds a JVM object to track and returns assigned ID, which is unique within this tracker. + */ + final def addAndGetId(obj: Object): JVMObjectId = { + val id = JVMObjectId(objCounter.getAndIncrement().toString) + objMap.put(id, obj) + id + } + + /** + * Removes and returns a JVM object with the specific ID from the tracker, or None if not found. + */ + final def remove(id: JVMObjectId): Option[Object] = this.synchronized { + if (objMap.containsKey(id)) { + Some(objMap.remove(id)) + } else { + None + } + } + + /** + * Number of JVM objects being tracked. + */ + final def size: Int = objMap.size() + + /** + * Clears the tracker. + */ + final def clear(): Unit = objMap.clear() +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 550746c552d02..2d1152a036449 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -22,7 +22,7 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel @@ -42,6 +42,9 @@ private[spark] class RBackend { private[this] var bootstrap: ServerBootstrap = null private[this] var bossGroup: EventLoopGroup = null + /** Tracks JVM objects returned to R for this RBackend instance. 
*/ + private[r] val jvmObjectTracker = new JVMObjectTracker + def init(): Int = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( @@ -94,6 +97,7 @@ private[spark] class RBackend { bootstrap.childGroup().shutdownGracefully() } bootstrap = null + jvmObjectTracker.clear() } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 9f5afa29d6d22..cfd37ac54ba23 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,7 +20,6 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util.concurrent.TimeUnit -import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} @@ -62,7 +61,7 @@ private[r] class RBackendHandler(server: RBackend) assert(numArgs == 1) writeInt(dos, 0) - writeObject(dos, args(0)) + writeObject(dos, args(0), server.jvmObjectTracker) case "stopBackend" => writeInt(dos, 0) writeType(dos, "void") @@ -72,9 +71,9 @@ private[r] class RBackendHandler(server: RBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) + server.jvmObjectTracker.remove(JVMObjectId(objToRemove)) writeInt(dos, 0) - writeObject(dos, null) + writeObject(dos, null, server.jvmObjectTracker) } catch { case e: Exception => logError(s"Removing $objId failed", e) @@ -143,12 +142,8 @@ private[r] class RBackendHandler(server: RBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { - case None => throw new IllegalArgumentException("Object not found " + objId) - case Some(o) => - obj = o - o.getClass - } + obj = server.jvmObjectTracker(JVMObjectId(objId)) + obj.getClass } val args = readArgs(numArgs, dis) @@ -173,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend) // Write status bit writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) + writeObject(dos, ret.asInstanceOf[AnyRef], server.jvmObjectTracker) } else if (methodName == "") { // methodName should be "" for constructor val ctors = cls.getConstructors @@ -193,7 +188,7 @@ private[r] class RBackendHandler(server: RBackend) val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + writeObject(dos, obj.asInstanceOf[AnyRef], server.jvmObjectTracker) } else { throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) } @@ -210,7 +205,7 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { _ => - readObject(dis) + readObject(dis, server.jvmObjectTracker) }.toArray } @@ -286,37 +281,4 @@ private[r] class RBackendHandler(server: RBackend) } } -/** - * Helper singleton that tracks JVM objects returned to R. - * This is useful for referencing these objects in RPC calls. - */ -private[r] object JVMObjectTracker { - - // TODO: This map should be thread-safe if we want to support multiple - // connections at the same time - private[this] val objMap = new HashMap[String, Object] - - // TODO: We support only one connection now, so an integer is fine. - // Investigate using use atomic integer in the future. 
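For orientation, a sketch of how the new per-backend tracker introduced above is exercised. It is `private[r]`, so this pseudo-test would have to live in the same package; the tracked `ArrayList` is just a stand-in object:

```scala
package org.apache.spark.api.r

object JVMObjectTrackerSketch {
  def main(args: Array[String]): Unit = {
    val tracker = new JVMObjectTracker
    val id: JVMObjectId = tracker.addAndGetId(new java.util.ArrayList[String]())

    assert(tracker.get(id).isDefined)   // Option-returning lookup
    val obj = tracker(id)               // apply() throws NoSuchElementException if absent
    println(s"$id -> ${obj.getClass.getName}, size = ${tracker.size}")

    tracker.remove(id)                  // returns Option[Object]
    tracker.clear()
  }
}
```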
- private[this] var objCounter: Int = 0 - - def getObject(id: String): Object = { - objMap(id) - } - - def get(id: String): Option[Object] = { - objMap.get(id) - } - - def put(obj: Object): String = { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - def remove(id: String): Option[Object] = { - objMap.remove(id) - } - -} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 7ef64723d9593..29e21b3b1aa8a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -152,7 +152,7 @@ private[spark] class RRunner[U]( dataOut.writeInt(mode) if (isDataFrame) { - SerDe.writeObject(dataOut, colNames) + SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null) } if (!iter.hasNext) { diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 77825e75e5136..fdd8cf62f0e5f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -84,7 +84,6 @@ private[spark] object RUtils { } } else { // Otherwise, assume the package is local - // TODO: support this for Mesos val sparkRPkgPath = localSparkRPackagePath.getOrElse { throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") } diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 550e075a95129..dad928cdcfd0f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -28,13 +28,20 @@ import scala.collection.mutable.WrappedArray * Utility functions to serialize, deserialize objects to / from R */ private[spark] object SerDe { - type ReadObject = (DataInputStream, Char) => Object - type WriteObject = (DataOutputStream, Object) => Boolean + type SQLReadObject = (DataInputStream, Char) => Object + type SQLWriteObject = (DataOutputStream, Object) => Boolean - var sqlSerDe: (ReadObject, WriteObject) = _ + private[this] var sqlReadObject: SQLReadObject = _ + private[this] var sqlWriteObject: SQLWriteObject = _ - def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = { - this.sqlSerDe = sqlSerDe + def setSQLReadObject(value: SQLReadObject): this.type = { + sqlReadObject = value + this + } + + def setSQLWriteObject(value: SQLWriteObject): this.type = { + sqlWriteObject = value + this } // Type mapping from R to Java @@ -56,32 +63,33 @@ private[spark] object SerDe { dis.readByte().toChar } - def readObject(dis: DataInputStream): Object = { + def readObject(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Object = { val dataType = readObjectType(dis) - readTypedObject(dis, dataType) + readTypedObject(dis, dataType, jvmObjectTracker) } def readTypedObject( dis: DataInputStream, - dataType: Char): Object = { + dataType: Char, + jvmObjectTracker: JVMObjectTracker): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) case 'd' => new java.lang.Double(readDouble(dis)) case 'b' => new java.lang.Boolean(readBoolean(dis)) case 'c' => readString(dis) - case 'e' => readMap(dis) + case 'e' => readMap(dis, jvmObjectTracker) case 'r' => readBytes(dis) - case 'a' => readArray(dis) - case 'l' => readList(dis) + case 'a' => readArray(dis, jvmObjectTracker) + case 'l' => readList(dis, jvmObjectTracker) 
case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => jvmObjectTracker(JVMObjectId(readString(dis))) case _ => - if (sqlSerDe == null || sqlSerDe._1 == null) { + if (sqlReadObject == null) { throw new IllegalArgumentException (s"Invalid type $dataType") } else { - val obj = (sqlSerDe._1)(dis, dataType) + val obj = sqlReadObject(dis, dataType) if (obj == null) { throw new IllegalArgumentException (s"Invalid type $dataType") } else { @@ -181,28 +189,28 @@ private[spark] object SerDe { } // All elements of an array must be of the same type - def readArray(dis: DataInputStream): Array[_] = { + def readArray(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) case 'c' => readStringArr(dis) case 'd' => readDoubleArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => jvmObjectTracker(JVMObjectId(x))) case 'r' => readBytesArr(dis) case 'a' => val len = readInt(dis) - (0 until len).map(_ => readArray(dis)).toArray + (0 until len).map(_ => readArray(dis, jvmObjectTracker)).toArray case 'l' => val len = readInt(dis) - (0 until len).map(_ => readList(dis)).toArray + (0 until len).map(_ => readList(dis, jvmObjectTracker)).toArray case _ => - if (sqlSerDe == null || sqlSerDe._1 == null) { + if (sqlReadObject == null) { throw new IllegalArgumentException (s"Invalid array type $arrType") } else { val len = readInt(dis) (0 until len).map { _ => - val obj = (sqlSerDe._1)(dis, arrType) + val obj = sqlReadObject(dis, arrType) if (obj == null) { throw new IllegalArgumentException (s"Invalid array type $arrType") } else { @@ -215,17 +223,19 @@ private[spark] object SerDe { // Each element of a list can be of different type. 
They are all represented // as Object on JVM side - def readList(dis: DataInputStream): Array[Object] = { + def readList(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[Object] = { val len = readInt(dis) - (0 until len).map(_ => readObject(dis)).toArray + (0 until len).map(_ => readObject(dis, jvmObjectTracker)).toArray } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + def readMap( + in: DataInputStream, + jvmObjectTracker: JVMObjectTracker): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { // Keys is an array of String - val keys = readArray(in).asInstanceOf[Array[Object]] - val values = readList(in) + val keys = readArray(in, jvmObjectTracker).asInstanceOf[Array[Object]] + val values = readList(in, jvmObjectTracker) keys.zip(values).toMap.asJava } else { @@ -272,7 +282,11 @@ private[spark] object SerDe { } } - private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + private def writeKeyValue( + dos: DataOutputStream, + key: Object, + value: Object, + jvmObjectTracker: JVMObjectTracker): Unit = { if (key == null) { throw new IllegalArgumentException("Key in map can't be null.") } else if (!key.isInstanceOf[String]) { @@ -280,10 +294,10 @@ private[spark] object SerDe { } writeString(dos, key.asInstanceOf[String]) - writeObject(dos, value) + writeObject(dos, value, jvmObjectTracker) } - def writeObject(dos: DataOutputStream, obj: Object): Unit = { + def writeObject(dos: DataOutputStream, obj: Object, jvmObjectTracker: JVMObjectTracker): Unit = { if (obj == null) { writeType(dos, "void") } else { @@ -373,14 +387,14 @@ private[spark] object SerDe { case v: Array[Object] => writeType(dos, "list") writeInt(dos, v.length) - v.foreach(elem => writeObject(dos, elem)) + v.foreach(elem => writeObject(dos, elem, jvmObjectTracker)) // Handle Properties // This must be above the case java.util.Map below. 
// (Properties implements Map and will be serialized as map otherwise) case v: java.util.Properties => writeType(dos, "jobj") - writeJObj(dos, value) + writeJObj(dos, value, jvmObjectTracker) // Handle map case v: java.util.Map[_, _] => @@ -392,19 +406,21 @@ private[spark] object SerDe { val key = entry.getKey val value = entry.getValue - writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + writeKeyValue( + dos, key.asInstanceOf[Object], value.asInstanceOf[Object], jvmObjectTracker) } case v: scala.collection.Map[_, _] => writeType(dos, "map") writeInt(dos, v.size) - v.foreach { case (key, value) => - writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + v.foreach { case (k1, v1) => + writeKeyValue(dos, k1.asInstanceOf[Object], v1.asInstanceOf[Object], jvmObjectTracker) } case _ => - if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) { + val sqlWriteSucceeded = sqlWriteObject != null && sqlWriteObject(dos, value) + if (!sqlWriteSucceeded) { writeType(dos, "jobj") - writeJObj(dos, value) + writeJObj(dos, value, jvmObjectTracker) } } } @@ -447,9 +463,9 @@ private[spark] object SerDe { out.write(value) } - def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) - writeString(out, objId) + def writeJObj(out: DataOutputStream, value: Object, jvmObjectTracker: JVMObjectTracker): Unit = { + val JVMObjectId(id) = jvmObjectTracker.addAndGetId(value) + writeString(out, id) } def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e8d6d587b4824..f350784378795 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io._ import java.nio.ByteBuffer +import java.util.zip.Adler32 import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -77,6 +78,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 + checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true) } setConf(SparkEnv.get.conf) @@ -85,10 +87,27 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** Total number of blocks this broadcast variable contains. */ private val numBlocks: Int = writeBlocks(obj) + /** Whether to generate checksum for blocks or not. */ + private var checksumEnabled: Boolean = false + /** The checksum for all the blocks. */ + private var checksums: Array[Int] = _ + override protected def getValue() = { _value } + private def calcChecksum(block: ByteBuffer): Int = { + val adler = new Adler32() + if (block.hasArray) { + adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position) + } else { + val bytes = new Array[Byte](block.remaining()) + block.duplicate.get(bytes) + adler.update(bytes) + } + adler.getValue.toInt + } + /** * Divide the object into multiple blocks and put those blocks in the block manager. 
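The new `spark.broadcast.checksum` flag (default `true`) guards the Adler32 verification added above; a sketch of toggling it in an illustrative local app:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Checksums are computed per broadcast block on write and re-verified when a block
// is fetched from a remote BlockManager; set the flag to false to skip that check.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("broadcast-checksum-sketch")
  .set("spark.broadcast.checksum", "true")   // the default introduced by this change

val sc = new SparkContext(conf)
val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))
println(lookup.value("a"))
sc.stop()
```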
* @@ -105,7 +124,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) + if (checksumEnabled) { + checksums = new Array[Int](blocks.length) + } blocks.zipWithIndex.foreach { case (block, i) => + if (checksumEnabled) { + checksums(i) = calcChecksum(block) + } val pieceId = BroadcastBlockId(id, "piece" + i) val bytes = new ChunkedByteBuffer(block.duplicate()) if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) { @@ -135,6 +160,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) case None => bm.getRemoteBytes(pieceId) match { case Some(b) => + if (checksumEnabled) { + val sum = calcChecksum(b.chunks(0)) + if (sum != checksums(pid)) { + throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" + + s" $sum != ${checksums(pid)}") + } + } // We found the block from remote executors/driver's BlockManager, so put the block // in this executor's BlockManager. if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index ee276e1b71138..a4de3d7eaf458 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -221,7 +221,9 @@ object Client { val conf = new SparkConf() val driverArgs = new ClientArguments(args) - conf.set("spark.rpc.askTimeout", "10") + if (!conf.contains("spark.rpc.askTimeout")) { + conf.set("spark.rpc.askTimeout", "10s") + } Logger.getRootLogger.setLevel(driverArgs.logLevel) val rpcEnv = diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 3f54ecc17ac33..23156072c3ebe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.io.IOException import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -357,7 +357,7 @@ class SparkHadoopUtil extends Logging { * @return a printable string value. 
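This hunk and the later Master/Worker ones pin `Locale.US` on date formatting; the point is that formatted timestamps feed into IDs and log output that must not vary with the JVM's default locale. A tiny illustration (the pattern and prefix are illustrative):

```scala
import java.text.SimpleDateFormat
import java.util.{Date, Locale}

// Same idea as the Master/Worker ID generation: pinning the locale keeps the
// output stable (ASCII digits) regardless of the JVM's default locale.
val fmt = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
println("app-" + fmt.format(new Date()))   // e.g. app-20161208143015
```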
*/ private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { - val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US) val buffer = new StringBuilder(128) buffer.append(token.toString) try { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5c052286099f5..67e0a13e6d0b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -70,7 +70,8 @@ object SparkSubmit { private val STANDALONE = 2 private val MESOS = 4 private val LOCAL = 8 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL + private val KUBERNETES = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | KUBERNETES | LOCAL // Deploy modes private val CLIENT = 1 @@ -239,9 +240,10 @@ object SparkSubmit { YARN case m if m.startsWith("spark") => STANDALONE case m if m.startsWith("mesos") => MESOS + case m if m.startsWith("k8s") => KUBERNETES case m if m.startsWith("local") => LOCAL case _ => - printErrorAndExit("Master must either be yarn or start with spark, mesos, local") + printErrorAndExit("Master must either be yarn or start with spark, mesos, k8s, or local") -1 } @@ -284,6 +286,7 @@ object SparkSubmit { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER + val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code @@ -322,7 +325,7 @@ object SparkSubmit { } // Require all R files to be local - if (args.isR && !isYarnCluster) { + if (args.isR && !isYarnCluster && !isMesosCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") } @@ -330,9 +333,10 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + - "applications on Mesos clusters.") + case (KUBERNETES, CLIENT) => + printErrorAndExit("Client mode is currently not supported for Kubernetes.") + case (KUBERNETES, CLUSTER) if args.isR => + printErrorAndExit("Kubernetes does not currently support R applications.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -410,9 +414,9 @@ object SparkSubmit { printErrorAndExit("Distributing R packages with standalone cluster is not supported.") } - // TODO: Support SparkR with mesos cluster - if (args.isR && clusterManager == MESOS) { - printErrorAndExit("SparkR is not supported for Mesos cluster.") + // TODO: Support distributing R packages with mesos cluster + if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { + printErrorAndExit("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -466,17 +470,21 @@ object SparkSubmit { OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.principal"), OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, sysProp = 
"spark.yarn.keytab"), - // Other options + OptionAssigner(args.kubernetesNamespace, KUBERNETES, ALL_DEPLOY_MODES, + sysProp = "spark.kubernetes.namespace"), + + // Other options OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, sysProp = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, sysProp = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + sysProp = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN, CLUSTER, sysProp = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN, CLUSTER, @@ -509,8 +517,9 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" + // In Kubernetes cluster mode, the jar will be uploaded by the client separately. // For python and R files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !args.isPython && !args.isR) { + if (!isYarnCluster && !isKubernetesCluster && !args.isPython && !args.isR) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) @@ -550,7 +559,7 @@ object SparkSubmit { } // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL) { + if (clusterManager == YARN || clusterManager == KUBERNETES || clusterManager == LOCAL) { if (args.principal != null) { require(args.keytab != null, "Keytab must be specified when principal is specified") if (!new File(args.keytab).exists()) { @@ -598,6 +607,9 @@ object SparkSubmit { if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } + } else if (args.isR) { + // Second argument is main class + childArgs += (args.primaryResource, "") } else { childArgs += (args.primaryResource, args.mainClass) } @@ -606,6 +618,22 @@ object SparkSubmit { } } + if (isKubernetesCluster) { + childMainClass = "org.apache.spark.deploy.kubernetes.submit.Client" + if (args.isPython) { + childArgs ++= Array("--primary-py-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.PythonRunner") + childArgs ++= Array("--other-py-files", args.pyFiles) + } else { + childArgs ++= Array("--primary-java-resource", args.primaryResource) + childArgs ++= Array("--main-class", args.mainClass) + } + args.childArgs.foreach { arg => + childArgs += "--arg" + childArgs += arg + } + } + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f1761e7c1ec92..4e297fe3b0e3b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -71,6 +71,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var principal: String = null var keytab: String = null + // Kubernetes only + var kubernetesNamespace: String = null + // Standalone cluster mode only var supervise: Boolean = false var driverCores: String = null @@ -186,6 +189,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .getOrElse(sparkProperties.get("spark.executor.instances").orNull) keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + kubernetesNamespace = Option(kubernetesNamespace) + .orElse(sparkProperties.get("spark.kubernetes.namespace")) + .orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -426,6 +432,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case KEYTAB => keytab = value + case KUBERNETES_NAMESPACE => + kubernetesNamespace = value + case HELP => printUsageAndExit(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 06530ff836466..d7d82800b8b55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -74,6 +74,30 @@ private[history] case class LoadedAppUI( private[history] abstract class ApplicationHistoryProvider { + /** + * Returns the count of application event logs that the provider is currently still processing. + * History Server UI can use this to indicate to a user that the application listing on the UI + * can be expected to list additional known applications once the processing of these + * application event logs completes. + * + * A History Provider that does not have a notion of count of event logs that may be pending + * for processing need not override this method. + * + * @return Count of application event logs that are currently under process + */ + def getEventLogsUnderProcess(): Int = { + return 0; + } + + /** + * Returns the time the history provider last updated the application history information + * + * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis + */ + def getLastUpdatedTime(): Long = { + return 0; + } + /** * Returns a list of applications available for the history server to show. 
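The `FsHistoryProvider` hunk further below backs these two new hooks with an `AtomicInteger` of in-flight replay tasks and an `AtomicLong` scan timestamp. The bookkeeping pattern in isolation (names follow the patch; the task type is simplified):

```scala
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

// Readers (the HistoryServer UI) and the scanning thread touch these concurrently,
// hence atomics instead of a plain var.
val pendingReplayTasksCount = new AtomicInteger(0)
val lastScanTime = new AtomicLong(-1L)

def runScan(tasks: Seq[() => Unit]): Unit = {
  pendingReplayTasksCount.addAndGet(tasks.size)
  tasks.foreach { task =>
    try task() finally pendingReplayTasksCount.decrementAndGet()
  }
  lastScanTime.set(System.currentTimeMillis())
}

def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get()
def getLastUpdatedTime(): Long = lastScanTime.get()
```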
* diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index dfc1aad64c818..8ef69b142cd15 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -108,7 +108,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) - private var lastScanTime = -1L + private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -120,6 +120,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -226,6 +228,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) applications.get(appId) } + override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() + + override def getLastUpdatedTime(): Long = lastScanTime.get() + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => @@ -329,26 +335,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") } - logInfos.map { file => - replayExecutor.submit(new Runnable { + + var tasks = mutable.ListBuffer[Future[_]]() + + try { + for (file <- logInfos) { + tasks += replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(file) }) } - .foreach { task => - try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. - task.get() - } catch { - case e: InterruptedException => - throw e - case e: Exception => - logError("Exception while merging application listings", e) - } + } catch { + // let the iteration over logInfos break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + } + + pendingReplayTasksCount.addAndGet(tasks.size) + + tasks.foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } finally { + pendingReplayTasksCount.decrementAndGet() } + } - lastScanTime = newLastScanTime + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } @@ -365,7 +388,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } catch { case e: Exception => logError("Exception encountered when attempting to update last scan time", e) - lastScanTime + lastScanTime.get() } finally { if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") @@ -640,9 +663,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) false } - // For testing. private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) + /* true to check only for Active NNs status */ + dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET, true) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 96b9ecf43b14c..0e7a6c24d4fa5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -30,13 +30,30 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) + val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() + val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = +
            {providerConfig.map { case (k, v) =>
              <li><strong>{k}:</strong> {v}</li>
            }}
+           {
+             if (eventLogsUnderProcessCount > 0) {
+               <p>There are {eventLogsUnderProcessCount} event log(s) currently being
+                 processed which may result in additional applications getting listed on this page.
+                 Refresh the page to view updates.</p>
+             }
+           }
+
+           {
+             if (lastUpdatedTime > 0) {
+               <p>Last updated: {lastUpdatedTime}</p>
+             }
+           }
+
            {
            if (allAppsSize > 0) {
              ++
@@ -46,6 +63,8 @@
            } else if (requestedIncomplete) {
              <h4>No incomplete applications found!</h4>
+           } else if (eventLogsUnderProcessCount > 0) {
+             <h4>No completed applications found!</h4>
            } else {
              <h4>No completed applications found!</h4>
++ parent.emptyListingHtml } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 3175b36b3e56f..7e21fa681aa1e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -179,6 +179,14 @@ class HistoryServer( provider.getListing() } + def getEventLogsUnderProcess(): Int = { + provider.getEventLogsUnderProcess() + } + + def getLastUpdatedTime(): Long = { + provider.getLastUpdatedTime() + } + def getApplicationInfoList: Iterator[ApplicationInfo] = { getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 2eddb5ff54479..080ba12c2f0d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Command-line parser for the master. + * Command-line parser for the [[HistoryServer]]. */ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8c91aa15167c4..4618e6117a4fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -51,7 +51,8 @@ private[deploy] class Master( private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index b30c980e95a9a..524726c2ccf92 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -19,16 +19,16 @@ package org.apache.spark.deploy.rest import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} -import scala.io.Source - import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector} +import org.eclipse.jetty.http.HttpVersion +import org.eclipse.jetty.server.{HttpConfiguration, HttpConnectionFactory, Server, ServerConnector, SslConnectionFactory} import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s._ import 
org.json4s.jackson.JsonMethods._ +import scala.io.Source -import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf, SSLOptions} import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -50,7 +50,8 @@ import org.apache.spark.util.Utils private[spark] abstract class RestSubmissionServer( val host: String, val requestedPort: Int, - val masterConf: SparkConf) extends Logging { + val masterConf: SparkConf, + val sslOptions: SSLOptions = SSLOptions()) extends Logging { protected val submitRequestServlet: SubmitRequestServlet protected val killRequestServlet: KillRequestServlet protected val statusRequestServlet: StatusRequestServlet @@ -79,19 +80,32 @@ private[spark] abstract class RestSubmissionServer( * Return a 2-tuple of the started server and the bound port. */ private def doStart(startPort: Int): (Server, Int) = { + // TODO consider using JettyUtils#startServer to do this instead val threadPool = new QueuedThreadPool threadPool.setDaemon(true) val server = new Server(threadPool) + val resolvedConnectionFactories = sslOptions + .createJettySslContextFactory() + .map(sslFactory => { + val sslConnectionFactory = new SslConnectionFactory( + sslFactory, HttpVersion.HTTP_1_1.asString()) + val rawHttpConfiguration = new HttpConfiguration() + rawHttpConfiguration.setSecureScheme("https") + rawHttpConfiguration.setSecurePort(startPort) + val rawHttpConnectionFactory = new HttpConnectionFactory(rawHttpConfiguration) + Array(sslConnectionFactory, rawHttpConnectionFactory) + }).getOrElse(Array(new HttpConnectionFactory())) + val connector = new ServerConnector( - server, - null, - // Call this full constructor to set this, which forces daemon threads: - new ScheduledExecutorScheduler("RestSubmissionServer-JettyScheduler", true), - null, - -1, - -1, - new HttpConnectionFactory()) + server, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler("RestSubmissionServer-JettyScheduler", true), + null, + -1, + -1, + resolvedConnectionFactories: _*) connector.setHost(host) connector.setPort(startPort) server.addConnector(connector) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0bedd9a20a969..0940f3c55844c 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{Date, UUID} +import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -68,7 +68,7 @@ private[deploy] class Worker( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -187,8 +187,7 @@ private[deploy] class Worker( webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() - val scheme = if (webUi.sslOptions.enabled) "https" else "http" - workerWebUiUrl = s"$scheme://$publicAddress:${webUi.boundPort}" + 
workerWebUiUrl = s"http://$publicAddress:${webUi.boundPort}" registerWithMaster() metricsSystem.registerSource(workerSource) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 7eec4ae64f296..f0e13aa6bf109 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { new SecurityManager(executorConf), clientMode = true) val driver = fetcher.setupEndpointRefByURI(driverUrl) - val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ - Seq[(String, String)](("spark.app.id", appId)) + val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig(executorId)) + val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. @@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, port, cores, isLocal = false) + driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index f66510b6f977f..59404e08895a3 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,6 +27,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} +import org.apache.spark.internal.config +import org.apache.spark.SparkContext + /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -40,9 +43,14 @@ private[spark] abstract class StreamFileInputFormat[T] * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API * which is set through setMaxSplitSize */ - def setMinPartitions(context: JobContext, minPartitions: Int) { - val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum - val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong + def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) { + val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) + val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) + val defaultParallelism = sc.defaultParallelism + val files = listStatus(context).asScala + val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum + val bytesPerCore = totalBytes / defaultParallelism + val maxSplitSize = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)) super.setMaxSplitSize(maxSplitSize) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 497ca92c7bc60..f4844dee62ef4 100644 --- 
a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -198,12 +198,26 @@ package object config { .createWithDefault(0) private[spark] val DRIVER_BLOCK_MANAGER_PORT = ConfigBuilder("spark.driver.blockManager.port") - .doc("Port to use for the block managed on the driver.") + .doc("Port to use for the block manager on the driver.") .fallbackConf(BLOCK_MANAGER_PORT) private[spark] val IGNORE_CORRUPT_FILES = ConfigBuilder("spark.files.ignoreCorruptFiles") .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + - "encountering corrupt files and contents that have been read will still be returned.") + "encountering corrupted or non-existing files and contents that have been read will still " + + "be returned.") .booleanConf .createWithDefault(false) + + private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes") + .doc("The maximum number of bytes to pack into a single partition when reading files.") + .longConf + .createWithDefault(128 * 1024 * 1024) + + private[spark] val FILES_OPEN_COST_IN_BYTES = ConfigBuilder("spark.files.openCostInBytes") + .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" + + " the same time. This is used when putting multiple files into a partition. It's better to" + + " over estimate, then the partitions with small files will be faster than partitions with" + + " bigger files.") + .longConf + .createWithDefault(4 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala new file mode 100644 index 0000000000000..afd2250c93a8a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.util.Utils + + +/** + * An interface to define how a single Spark job commits its outputs. Two notes: + * + * 1. Implementations must be serializable, as the committer instance instantiated on the driver + * will be used for tasks on executors. + * 2. Implementations should have a constructor with either 2 or 3 arguments: + * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean). + * 3. A committer should not be reused across multiple Spark jobs. + * + * The proper call sequence is: + * + * 1. Driver calls setupJob. + * 2. As part of each task's execution, executor calls setupTask and then commitTask + * (or abortTask if task failed). + * 3. 
When all necessary tasks completed successfully, the driver calls commitJob. If the job + * failed to execute (e.g. too many failed tasks), the job should call abortJob. + */ +abstract class FileCommitProtocol { + import FileCommitProtocol._ + + /** + * Setups up a job. Must be called on the driver before any other methods can be invoked. + */ + def setupJob(jobContext: JobContext): Unit + + /** + * Commits a job after the writes succeed. Must be called on the driver. + */ + def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit + + /** + * Aborts a job after the writes fail. Must be called on the driver. + * + * Calling this function is a best-effort attempt, because it is possible that the driver + * just crashes (or killed) before it can call abort. + */ + def abortJob(jobContext: JobContext): Unit + + /** + * Sets up a task within a job. + * Must be called before any other task related methods can be invoked. + */ + def setupTask(taskContext: TaskAttemptContext): Unit + + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. + * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. + * + * A full file path consists of the following parts: + * 1. the base path + * 2. some sub-directory within the base path, used to specify partitioning + * 3. file prefix, usually some unique job id with the task id + * 4. bucket id + * 5. source specific file extension, e.g. ".snappy.parquet" + * + * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest + * are left to the commit protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + + /** + * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + + /** + * Commits a task after the writes succeed. Must be called on the executors when running tasks. + */ + def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage + + /** + * Aborts a task after the writes have failed. Must be called on the executors when running tasks. + * + * Calling this function is a best-effort attempt, because it is possible that the executor + * just crashes (or killed) before it can call abort. + */ + def abortTask(taskContext: TaskAttemptContext): Unit +} + + +object FileCommitProtocol { + class TaskCommitMessage(val obj: Any) extends Serializable + + object EmptyTaskCommitMessage extends TaskCommitMessage(null) + + /** + * Instantiates a FileCommitProtocol using the given className. 
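To make the contract spelled out above concrete, a hypothetical no-op implementation; it only logs the driver-side and executor-side calls and is not the committer this patch ships:

```scala
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage

// Constructor shape (jobId, path) matches what FileCommitProtocol.instantiate expects.
class LoggingCommitProtocol(jobId: String, path: String)
  extends FileCommitProtocol with Serializable {

  override def setupJob(jobContext: JobContext): Unit =
    println(s"driver: setupJob $jobId")
  override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit =
    println(s"driver: commitJob $jobId with ${taskCommits.size} task commit(s)")
  override def abortJob(jobContext: JobContext): Unit =
    println(s"driver: abortJob $jobId")

  override def setupTask(taskContext: TaskAttemptContext): Unit =
    println("executor: setupTask")
  override def newTaskTempFile(
      taskContext: TaskAttemptContext, dir: Option[String], ext: String): String =
    dir.map(d => s"$path/$d/part-$ext").getOrElse(s"$path/part-$ext")
  override def newTaskTempFileAbsPath(
      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String =
    s"$absoluteDir/part-$ext"
  override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage =
    FileCommitProtocol.EmptyTaskCommitMessage
  override def abortTask(taskContext: TaskAttemptContext): Unit =
    println("executor: abortTask")
}
```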
+ */ + def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean) + : FileCommitProtocol = { + val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] + + // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala new file mode 100644 index 0000000000000..c99b75e52325e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import java.util.{Date, UUID} + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.SparkHadoopWriter +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the newer mapreduce API, not the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopMapReduceCommitProtocol(jobId: String, path: String) + extends FileCommitProtocol with Serializable with Logging { + + import FileCommitProtocol._ + + /** OutputCommitter from Hadoop is not serializable so marking it transient. */ + @transient private var committer: OutputCommitter = _ + + /** + * Tracks files staged by this task for absolute output paths. These outputs are not managed by + * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. + * + * The mapping is from the temp output path to the final desired output path of the file. + */ + @transient private var addedAbsPathFiles: mutable.Map[String, String] = null + + /** + * The staging directory for all files committed with absolute output paths. 
+ */ + private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { + context.getOutputFormatClass.newInstance().getOutputCommitter(context) + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + val filename = getFilename(taskContext, ext) + + val stagingDir: String = committer match { + // For FileOutputCommitter it has its own staging path called "work path". + case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case _ => path + } + + dir.map { d => + new Path(new Path(stagingDir, d), filename).toString + }.getOrElse { + new Path(stagingDir, filename).toString + } + } + + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + val filename = getFilename(taskContext, ext) + val absOutputPath = new Path(absoluteDir, filename).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path( + absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + + private def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + f"part-$split%05d-$jobId$ext" + } + + override def setupJob(jobContext: JobContext): Unit = { + // Setup IDs + val jobId = SparkHadoopWriter.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapred.job.id", jobId.toString) + jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) + jobContext.getConfiguration.setInt("mapred.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) + committer = setupCommitter(taskAttemptContext) + committer.setupJob(jobContext) + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + committer.commitJob(jobContext) + val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) + .foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) + } + + override def abortJob(jobContext: JobContext): Unit = { + committer.abortJob(jobContext, JobStatus.State.FAILED) + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + committer = setupCommitter(taskContext) + committer.setupTask(taskContext) + addedAbsPathFiles = mutable.Map[String, String]() + } 
+ + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + val attemptId = taskContext.getTaskAttemptID + SparkHadoopMapRedUtil.commitTask( + committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) + new TaskCommitMessage(addedAbsPathFiles.toMap) + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + committer.abortTask(taskContext) + // best effort cleanup of other staged files + for ((src, _) <- addedAbsPathFiles) { + val tmp = new Path(src) + tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ae014becef755..2e991ce394c42 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -32,9 +32,8 @@ import org.apache.spark.util.Utils * CompressionCodec allows the customization of choosing different compression implementations * to be used in block storage. * - * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark. - * This is intended for use as an internal compression utility within a single - * Spark application. + * @note The wire protocol for a codec is not guaranteed compatible across versions of Spark. + * This is intended for use as an internal compression utility within a single Spark application. */ @DeveloperApi trait CompressionCodec { @@ -103,9 +102,9 @@ private[spark] object CompressionCodec { * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.lz4.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -123,9 +122,9 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { * :: DeveloperApi :: * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -143,9 +142,9 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.snappy.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. 
This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -173,7 +172,7 @@ private final object SnappyCompressionCodec { } /** - * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * Wrapper over `SnappyOutputStream` which guards against write-after-close and double-close * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. */ diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index 3f7cfd9d2c11f..99ec78633ab75 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -85,6 +85,17 @@ object HiveCatalogMetrics extends Source { */ val METRIC_FILE_CACHE_HITS = metricRegistry.counter(MetricRegistry.name("fileCacheHits")) + /** + * Tracks the total number of Hive client calls (e.g. to lookup a table). + */ + val METRIC_HIVE_CLIENT_CALLS = metricRegistry.counter(MetricRegistry.name("hiveClientCalls")) + + /** + * Tracks the total number of Spark jobs launched for parallel file listing. + */ + val METRIC_PARALLEL_LISTING_JOB_COUNT = metricRegistry.counter( + MetricRegistry.name("parallelListingJobCount")) + /** * Resets the values of all metrics to zero. This is useful in tests. */ @@ -92,10 +103,14 @@ object HiveCatalogMetrics extends Source { METRIC_PARTITIONS_FETCHED.dec(METRIC_PARTITIONS_FETCHED.getCount()) METRIC_FILES_DISCOVERED.dec(METRIC_FILES_DISCOVERED.getCount()) METRIC_FILE_CACHE_HITS.dec(METRIC_FILE_CACHE_HITS.getCount()) + METRIC_HIVE_CLIENT_CALLS.dec(METRIC_HIVE_CLIENT_CALLS.getCount()) + METRIC_PARALLEL_LISTING_JOB_COUNT.dec(METRIC_PARALLEL_LISTING_JOB_COUNT.getCount()) } // clients can use these to avoid classloader issues with the codahale classes def incrementFetchedPartitions(n: Int): Unit = METRIC_PARTITIONS_FETCHED.inc(n) def incrementFilesDiscovered(n: Int): Unit = METRIC_FILES_DISCOVERED.inc(n) def incrementFileCacheHits(n: Int): Unit = METRIC_FILE_CACHE_HITS.inc(n) + def incrementHiveClientCalls(n: Int): Unit = METRIC_HIVE_CLIENT_CALLS.inc(n) + def incrementParallelListingJobCount(n: Int): Unit = METRIC_PARALLEL_LISTING_JOB_COUNT.inc(n) } diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index ab6aba6fc7d6a..8f579c5a3033c 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,7 +28,7 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false + * @note Consistent with Double, any NaN value will make equality false */ override def equals(that: Any): Boolean = that match { diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 41832e8354741..50d977a92da51 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ 
b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.StreamFileInputFormat private[spark] class BinaryFileRDD[T]( - sc: SparkContext, + @transient private val sc: SparkContext, inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], @@ -43,7 +43,7 @@ private[spark] class BinaryFileRDD[T]( case _ => } val jobContext = new JobContextImpl(conf, jobId) - inputFormat.setMinPartitions(jobContext, minPartitions) + inputFormat.setMinPartitions(sc, jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 2381f54ee3f06..a091f06b4ed7c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -66,14 +66,14 @@ private[spark] class CoGroupPartition( /** * :: DeveloperApi :: - * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a + * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. * - * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of - * instantiating this directly. - * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output + * + * @note This is an internal API. We recommend users use RDD.cogroup(...) instead of + * instantiating this directly. */ @DeveloperApi class CoGroupedRDD[K: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a05a770b40c57..14331dfd0c987 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -152,13 +152,13 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** * Compute a histogram using the provided buckets. The buckets are all open - * to the right except for the last which is closed + * to the right except for the last which is closed. * e.g. for the array * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] - * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 + * e.g. {@code 1<=x<10, 10<=x<20, 20<=x<=50} * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates.
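To make the half-open bucket semantics described in the DoubleRDDFunctions scaladoc above concrete, here is a minimal sketch of calling `histogram` with custom buckets; the application name, local master and sample values are illustrative assumptions, not part of this patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object HistogramBucketsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("histogram-buckets").setMaster("local[2]"))
    try {
      // Buckets [1, 10), [10, 20), [20, 50]: open to the right except for the last one.
      val data = sc.parallelize(Seq(1.0, 9.9, 10.0, 50.0))
      val counts = data.histogram(Array(1.0, 10.0, 20.0, 50.0))
      // 1.0 and 9.9 fall in [1, 10), 10.0 in [10, 20), and 50.0 in [20, 50]
      // because the last bucket is closed, so this prints "2, 1, 1".
      println(counts.mkString(", "))
    } finally {
      sc.stop()
    }
  }
}
```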
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1cf3938de098..b56ebf4df06e9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.immutable.Map import scala.reflect.ClassTag @@ -84,9 +84,6 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.hadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. @@ -97,6 +94,9 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate. + * + * @note Instantiating this class directly is not recommended, please use + * `org.apache.spark.SparkContext.hadoopRDD()` */ @DeveloperApi class HadoopRDD[K, V]( @@ -210,12 +210,12 @@ class HadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] + private val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - val jobConf = getJobConf() + private val jobConf = getJobConf() - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.inputSplit.value match { @@ -225,7 +225,7 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { case _: FileSplit | _: CombineFileSplit => SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None @@ -235,28 +235,39 @@ class HadoopRDD[K, V]( // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). 
- def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - var reader: RecordReader[K, V] = null - val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime), + private var reader: RecordReader[K, V] = null + private val inputFormat = getInputFormat(jobConf) + HadoopRDD.addLocalConfiguration( + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + reader = + try { + inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + } catch { + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) + finished = true + null + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() + private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey() + private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue() override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { - case e: IOException if ignoreCorruptFiles => finished = true + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) + finished = true } if (!finished) { inputMetrics.incRecordsRead(1) diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala index f40d4c8e0a4d0..960c91a154db1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala @@ -22,6 +22,8 @@ import org.apache.spark.unsafe.types.UTF8String /** * This holds file names of the current Spark task. This is used in HadoopRDD, * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL. + * + * The returned value should never be null but empty string if it is unknown. */ private[spark] object InputFileNameHolder { /** @@ -32,9 +34,15 @@ private[spark] object InputFileNameHolder { override protected def initialValue(): UTF8String = UTF8String.fromString("") } + /** + * Returns the holding file name or empty string if it is unknown. + */ def getInputFileName(): UTF8String = inputFileName.get() - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + private[spark] def setInputFileName(file: String) = { + require(file != null, "The input file name cannot be null") + inputFileName.set(UTF8String.fromString(file)) + } private[spark] def unsetInputFileName(): Unit = inputFileName.remove() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 0970b98071675..aab46b8954bf7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -41,7 +41,10 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e * The RDD takes care of closing the connection. * @param sql the text of the query. * The query must contain two ? 
placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" + * For example, + * {{{ + * select title, author from books where ? <= id and id <= ? + * }}} * @param lowerBound the minimum value of the first placeholder * @param upperBound the maximum value of the second placeholder * The lower and upper bounds are inclusive. @@ -151,7 +154,10 @@ object JdbcRDD { * The RDD takes care of closing the connection. * @param sql the text of the query. * The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" + * For example, + * {{{ + * select title, author from books where ? <= id and id <= ? + * }}} * @param lowerBound the minimum value of the first placeholder * @param upperBound the maximum value of the second placeholder * The lower and upper bounds are inclusive. @@ -191,7 +197,10 @@ object JdbcRDD { * The RDD takes care of closing the connection. * @param sql the text of the query. * The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" + * For example, + * {{{ + * select title, author from books where ? <= id and id <= ? + * }}} * @param lowerBound the minimum value of the first placeholder * @param upperBound the maximum value of the second placeholder * The lower and upper bounds are inclusive. diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index baf31fb658870..6168d979032aa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.reflect.ClassTag @@ -57,13 +57,13 @@ private[spark] class NewHadoopPartition( * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. 
+ * + * @note Instantiating this class directly is not recommended, please use + * `org.apache.spark.SparkContext.newAPIHadoopRDD()` */ @DeveloperApi class NewHadoopRDD[K, V]( @@ -79,7 +79,7 @@ class NewHadoopRDD[K, V]( // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) formatter.format(new Date()) } @@ -132,12 +132,12 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] + private val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf + private val conf = getConf - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { @@ -147,46 +147,62 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } + private val getBytesReadCallback: Option[() => Long] = + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). 
- def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - val format = inputFormatClass.newInstance + private val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } - val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + private val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + private var finished = false + private var reader = + try { + val _reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + _reader + } catch { + case e: IOException if ignoreCorruptFiles => + logWarning( + s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", + e) + finished = true + null + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + private var havePair = false + private var recordsSinceMetricsUpdate = 0 override def hasNext: Boolean = { if (!finished && !havePair) { try { finished = !reader.nextKeyValue } catch { - case e: IOException if ignoreCorruptFiles => finished = true + case e: IOException if ignoreCorruptFiles => + logWarning( + s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", + e) + finished = true } if (finished) { // Close and release the reader here; close() will also be called when the task diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 068f4ed8ad745..dc123e23b781c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.nio.ByteBuffer import java.text.SimpleDateFormat -import java.util.{Date, HashMap => JHashMap} +import java.util.{Date, HashMap => JHashMap, Locale} import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ @@ -59,8 +59,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). 
Users provide three functions: + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -68,6 +68,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). */ @Experimental def combineByKeyWithClassTag[C]( @@ -363,7 +366,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Count the number of elements for each key, collecting the results to a local Map. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -398,9 +401,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` - * would trigger sparse representation of registers, which may reduce the memory consumption - * and increase accuracy when the cardinality is small. + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is + * greater than `p`) would trigger sparse representation of registers, which may reduce the + * memory consumption and increase accuracy when the cardinality is small. * * @param p The precision value for the normal set. * `p` must be a value between 4 and `sp` if `sp` is not zero (32 max). @@ -490,11 +493,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * The ordering of elements within each group is not guaranteed, and may even differ * each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { @@ -514,11 +517,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. The ordering of elements within * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. 
If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { @@ -635,9 +638,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * within each group is not guaranteed, and may even differ each time the resulting RDD is * evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupByKey(): RDD[(K, Iterable[V])] = self.withScope { groupByKey(defaultPartitioner(self)) @@ -907,20 +910,24 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Return an RDD with the pairs from `this` whose keys are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be less than or equal to us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = self.withScope { subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W: ClassTag]( other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = self.withScope { subtractByKey(other, new HashPartitioner(numPartitions)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = self.withScope { new SubtractedRDD[K, V, W](self, other, p) } @@ -1016,7 +1023,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1070,7 +1077,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. 
a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1079,7 +1086,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val job = NewAPIHadoopJob.getInstance(hadoopConf) - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(new Date()) val stageId = self.id val jobConfiguration = job.getConfiguration diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 0c6ddda52cee9..ce75a16031a3f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -48,7 +48,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => /** * :: DeveloperApi :: - * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on + * An RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index 3b1acacf409b9..6a89ea8786464 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -32,7 +32,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) } /** - * A RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, + * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain * a random sample of the records in the partition. The random seeds assigned to the samplers * are guaranteed to have different values. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index db535de9e9bb3..374abccf6ad55 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -70,8 +70,8 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://people.csail.mit.edu/matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details - * on RDD internals. + * Spark paper + * for more details on RDD internals. 
*/ abstract class RDD[T: ClassTag]( @transient private var _sc: SparkContext, @@ -195,10 +195,14 @@ abstract class RDD[T: ClassTag]( } } - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def persist(): this.type = persist(StorageLevel.MEMORY_ONLY) - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): this.type = persist() /** @@ -428,7 +432,7 @@ abstract class RDD[T: ClassTag]( * current upstream partitions will be executed in parallel (per whatever * the current partitioning is). * - * Note: With shuffle = true, you can actually coalesce to a larger number + * @note With shuffle = true, you can actually coalesce to a larger number * of partitions. This is useful if you have a small number of partitions, * say 100, potentially with a few partitions being abnormally large. Calling * coalesce(1000, shuffle = true) will result in 1000 partitions with the @@ -469,8 +473,12 @@ abstract class RDD[T: ClassTag]( * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample( withReplacement: Boolean, @@ -534,13 +542,13 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * - * @note this method should only be used if the resulting array is expected to be small, as - * all the data is loaded into the driver's memory. - * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator * @return sample of specified size in an array + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeSample( withReplacement: Boolean, @@ -615,7 +623,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) @@ -627,7 +635,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param partitioner Partitioner to use for the resulting RDD */ @@ -643,7 +651,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. 
Performs a hash partition across the cluster * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param numPartitions How many partitions to use in the resulting RDD */ @@ -671,9 +679,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy[K](f, defaultPartitioner(this)) @@ -684,9 +692,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K]( f: T => K, @@ -699,9 +707,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) : RDD[(K, Iterable[T])] = withScope { @@ -747,8 +755,10 @@ abstract class RDD[T: ClassTag]( * print line function (like out.println()) as the 2nd parameter. * An example of pipe the RDD data of groupBy() in a streaming way, * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2) {f(e)} + * {{{ + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2) {f(e)} + * }}} * @param separateWorkingDir Use separate working directories for each task. * @param bufferSize Buffer size for the stdin writer for the piped process. 
* @param encoding Char encoding used for interacting (via stdin, stdout and stderr) with @@ -788,14 +798,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. */ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. + */ private[spark] def mapPartitionsInternal[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { @@ -906,7 +928,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { @@ -919,7 +941,7 @@ abstract class RDD[T: ClassTag]( * * The iterator will consume as much memory as the largest partition in this RDD. * - * Note: this results in multiple Spark jobs, and if the input RDD is the result + * @note This results in multiple Spark jobs, and if the input RDD is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input RDD should be cached first. */ @@ -1167,10 +1189,15 @@ abstract class RDD[T: ClassTag]( /** * Return the count of each unique value in this RDD as a local map of (value, count) pairs. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. - * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which - * returns an RDD[T, Long] instead of a map. + * To handle very large results, consider using + * + * {{{ + * rdd.map(x => (x, 1L)).reduceByKey(_ + _) + * }}} + * + * , which returns an RDD[T, Long] instead of a map. */ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = withScope { map(value => (value, null)).countByKey() @@ -1208,9 +1235,9 @@ abstract class RDD[T: ClassTag]( * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` - * would trigger sparse representation of registers, which may reduce the memory consumption - * and increase accuracy when the cardinality is small. + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. 
Setting a nonzero (`sp` is greater + * than `p`) would trigger sparse representation of registers, which may reduce the memory + * consumption and increase accuracy when the cardinality is small. * * @param p The precision value for the normal set. * `p` must be a value between 4 and `sp` if `sp` is not zero (32 max). @@ -1257,7 +1284,7 @@ abstract class RDD[T: ClassTag]( * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The index assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1271,7 +1298,7 @@ abstract class RDD[T: ClassTag]( * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The unique ID assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1290,10 +1317,10 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * - * @note due to complications in the internal implementation, this method will raise + * @note Due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = withScope { @@ -1355,7 +1382,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of top elements to return @@ -1378,7 +1405,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of elements to return @@ -1423,7 +1450,7 @@ abstract class RDD[T: ClassTag]( } /** - * @note due to complications in the internal implementation, this method will raise an + * @note Due to complications in the internal implementation, this method will raise an * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. 
* (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) @@ -1719,7 +1746,7 @@ abstract class RDD[T: ClassTag]( /** * Clears the dependencies of this RDD. This method must ensure that all references - * to the original parent RDDs is removed to enable the parent RDDs to be garbage + * to the original parent RDDs are removed to enable the parent RDDs to be garbage * collected. Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 429514b4f6bee..6c552d4d12515 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -23,7 +23,8 @@ import org.apache.spark.Partition /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> checkpointing in progress --> checkpointed ]. + * + * [ Initialized --{@literal >} checkpointing in progress --{@literal >} checkpointed ] */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value @@ -32,7 +33,7 @@ private[spark] object CheckpointState extends Enumeration { /** * This class contains all the information related to RDD checkpointing. Each instance of this - * class is associated with a RDD. It manages process of checkpointing of the associated RDD, + * class is associated with an RDD. It manages process of checkpointing of the associated RDD, * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index eac901d10067c..7f399ecf81a08 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -151,7 +151,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { } /** - * Write a RDD partition's data to a checkpoint file. + * Write an RDD partition's data to a checkpoint file. */ def writePartitionToCheckpointFile[T: ClassTag]( path: String, diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 1311b481c7c71..86a332790fb00 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -27,9 +27,10 @@ import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. + * through an implicit conversion. * + * @note This can't be part of PairRDDFunctions because we need more implicit parameters to + * convert our keys and values to Writable. 
*/ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)], diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b0e5ba0865c63..8425b211d6ecf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -29,7 +29,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) } /** - * Represents a RDD zipped with its element indices. The ordering is first based on the partition + * Represents an RDD zipped with its element indices. The ordering is first based on the partition * index and then the ordering of items within each partition. So the first item in the first * partition gets index 0, and the last item in the last partition receives the largest index. * diff --git a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala index d8a80aa5aeb15..e00bc22aba44d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala +++ b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala @@ -35,14 +35,14 @@ trait PartitionCoalescer { * @param maxPartitions the maximum number of partitions to have after coalescing * @param parent the parent RDD whose partitions to coalesce * @return an array of [[PartitionGroup]]s, where each element is itself an array of - * [[Partition]]s and represents a partition after coalescing is performed. + * `Partition`s and represents a partition after coalescing is performed. */ def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] } /** * ::DeveloperApi:: - * A group of [[Partition]]s + * A group of `Partition`s * @param prefLoc preferred location for the partition group */ @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index f527ec86ab7b2..117f51c5b8f2a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc /** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe * and can be called in any thread. */ private[spark] trait RpcCallContext { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 579122868afc8..530743c03640b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -146,7 +146,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * @param uri URI with location of the file. */ def openChannel(uri: String): ReadableByteChannel - } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index e51649a1ecce9..e56943da1303a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -407,11 +407,9 @@ private[netty] class NettyRpcEnv( } } - } private[netty] object NettyRpcEnv extends Logging { - /** * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. 
* Use `currentEnv` to wrap the deserialization codes. E.g., diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala index 99f20da2d66aa..430dcc50ba711 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.rpc.netty import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} /** - * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists. + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an `RpcEndpoint` exists. * * This is used when setting up a remote endpoint reference. */ @@ -35,6 +35,6 @@ private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher private[netty] object RpcEndpointVerifier { val NAME = "endpoint-verifier" - /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */ + /** A message used to ask the remote [[RpcEndpointVerifier]] if an `RpcEndpoint` exists. */ case class CheckExistence(name: String) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index cedacad44afec..0a5fe5a1d3ee1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. * - * Note: once this is JSON serialized the types of `update` and `value` will be lost and be - * cast to strings. This is because the user can define an accumulator of any type and it will - * be difficult to preserve the type in consumers of the event log. This does not apply to - * internal accumulators that represent task level metrics. - * * @param id accumulator ID * @param name accumulator name * @param update partial value from a task, may be None if used on driver to describe a stage @@ -36,6 +31,11 @@ import org.apache.spark.annotation.DeveloperApi * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed * @param metadata internal metadata associated with this accumulator, if any + * + * @note Once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. 
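To make the relocated `AccumulableInfo` note concrete: accumulator updates appear in the JSON event log as strings once serialized, because the element type is user-defined. Below is a minimal, hedged sketch using the public named-accumulator API, which is the path that produces `AccumulableInfo` entries; the data and master are assumptions, not from the patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object NamedAccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("accumulator-sketch").setMaster("local[2]"))

    // A named accumulator is reported per task/stage; in the JSON event log its
    // `update`/`value` fields are rendered as strings rather than typed values.
    val parseErrors = sc.longAccumulator("parseErrors")
    sc.parallelize(Seq("1", "two", "3")).foreach { s =>
      if (scala.util.Try(s.toInt).isFailure) parseErrors.add(1L)
    }
    println(parseErrors.value) // 1

    sc.stop()
  }
}
```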
*/ @DeveloperApi case class AccumulableInfo private[spark] ( diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f2517401cb76b..01a95c06fc69c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1660,7 +1660,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } - dagScheduler.sc.stop() + dagScheduler.sc.stopInNewThread() } override def onStop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index a6b032cc0084c..66ab9a52b7781 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -153,7 +153,7 @@ object InputFormatInfo { a) For each host, count number of splits hosted on that host. b) Decrement the currently allocated containers on that host. - c) Compute rack info for each host and update rack -> count map based on (b). + c) Compute rack info for each host and update rack to count map based on (b). d) Allocate nodes based on (c) e) On the allocation result, ensure that we don't allocate "too many" jobs on a single node (even if data locality on that is very high) : this is to prevent fragility of job if a diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 0bd5a6bc59a9e..08e05ae0c095b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -22,6 +22,7 @@ import java.io.{InputStream, IOException} import scala.io.Source import com.fasterxml.jackson.core.JsonParseException +import com.fasterxml.jackson.databind.exc.UnrecognizedPropertyException import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging @@ -87,6 +88,12 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1. // It's safe since no place uses them. logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") + case e: UnrecognizedPropertyException if e.getMessage != null && e.getMessage.startsWith( + "Unrecognized field \"queryStatus\" " + + "(class org.apache.spark.sql.streaming.StreamingQueryListener$") => + // Ignore events generated by Structured Streaming in Spark 2.0.2 + // It's safe since no place uses them. 
+ logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") case jpe: JsonParseException => // We can only ignore exception from last line of the file that might be truncated // the last entry may not be the very last line in the event log, but we treat it diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 1e7c63af2e797..d19353f2a9930 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -42,7 +42,7 @@ import org.apache.spark.rdd.RDD * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. + * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. * * The parameters below are optional: * @param jobId id of the job this task belongs to diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 66d6790e168f2..31011de85bf7e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -42,7 +42,7 @@ import org.apache.spark.shuffle.ShuffleWriter * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. + * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. * * The parameters below are optional: diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9385e3c31e1e4..112b08f2c03a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,7 +45,7 @@ import org.apache.spark.util._ * @param stageId id of the stage this task belongs to * @param stageAttemptId attempt id of the stage this task belongs to * @param partitionId index of the number in the RDD - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. + * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. * * The parameters below are optional: diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 1c7c81c488c3a..45c742cbff5e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.SerializableBuffer /** * Description of a task that gets passed onto executors to be executed, usually created by - * [[TaskSetManager.resourceOffer]]. + * `TaskSetManager.resourceOffer`. 
*/ private[spark] class TaskDescription( val taskId: Long, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 3e3f1ad031e66..b03cfe4f0dc49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -93,10 +93,12 @@ private[spark] class TaskSchedulerImpl( // Incrementing task IDs val nextTaskId = new AtomicLong(0) - // Number of tasks running on each executor - private val executorIdToTaskCount = new HashMap[String, Int] + // IDs of the tasks running on each executor + private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]] - def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap + def runningTasksByExecutors: Map[String, Int] = synchronized { + executorIdToRunningTaskIds.toMap.mapValues(_.size) + } // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host @@ -264,7 +266,7 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId - executorIdToTaskCount(execId) += 1 + executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) launchedTask = true @@ -294,11 +296,11 @@ private[spark] class TaskSchedulerImpl( if (!hostToExecutors.contains(o.host)) { hostToExecutors(o.host) = new HashSet[String]() } - if (!executorIdToTaskCount.contains(o.executorId)) { + if (!executorIdToRunningTaskIds.contains(o.executorId)) { hostToExecutors(o.host) += o.executorId executorAdded(o.executorId, o.host) executorIdToHost(o.executorId) = o.host - executorIdToTaskCount(o.executorId) = 0 + executorIdToRunningTaskIds(o.executorId) = HashSet[Long]() newExecAvail = true } for (rack <- getRackForHost(o.host)) { @@ -349,38 +351,34 @@ private[spark] class TaskSchedulerImpl( var reason: Option[ExecutorLossReason] = None synchronized { try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - - if (executorIdToTaskCount.contains(execId)) { - reason = Some( - SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) - removeExecutor(execId, reason.get) - failedExecutor = Some(execId) - } - } taskIdToTaskSetManager.get(tid) match { case Some(taskSet) => - if (TaskState.isFinished(state)) { - taskIdToTaskSetManager.remove(tid) - taskIdToExecutorId.remove(tid).foreach { execId => - if (executorIdToTaskCount.contains(execId)) { - executorIdToTaskCount(execId) -= 1 - } + if (state == TaskState.LOST) { + // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode, + // where each executor corresponds to a single task, so mark the executor as failed. 
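The `TaskSchedulerImpl` change above replaces a per-executor task count with the set of running task ids, deriving counts on demand. The standalone sketch below mirrors the resulting lookups (`runningTasksByExecutors`, `isExecutorBusy`); it is an illustration of the data-structure choice, not the internal class itself.

```scala
import scala.collection.mutable

object RunningTaskBookkeepingSketch {
  // IDs of the tasks running on each executor, mirroring executorIdToRunningTaskIds.
  private val executorIdToRunningTaskIds = mutable.HashMap.empty[String, mutable.HashSet[Long]]

  def registerExecutor(execId: String): Unit = {
    executorIdToRunningTaskIds.getOrElseUpdate(execId, mutable.HashSet.empty[Long])
    ()
  }

  def taskLaunched(execId: String, tid: Long): Unit = {
    executorIdToRunningTaskIds(execId).add(tid)
    ()
  }

  // Equivalent of runningTasksByExecutors: counts are derived from the sets on demand.
  def runningTasksByExecutors: Map[String, Int] =
    executorIdToRunningTaskIds.toMap.map { case (execId, tids) => execId -> tids.size }

  // Equivalent of isExecutorBusy.
  def isExecutorBusy(execId: String): Boolean =
    executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty)

  def main(args: Array[String]): Unit = {
    registerExecutor("exec-1"); registerExecutor("exec-2")
    taskLaunched("exec-1", 12L); taskLaunched("exec-1", 13L); taskLaunched("exec-2", 14L)
    println(runningTasksByExecutors)  // e.g. Map(exec-1 -> 2, exec-2 -> 1)
    println(isExecutorBusy("exec-2")) // true
  }
}
```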
+ val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException( + "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)")) + if (executorIdToRunningTaskIds.contains(execId)) { + reason = Some( + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) + removeExecutor(execId, reason.get) + failedExecutor = Some(execId) } } - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + if (TaskState.isFinished(state)) { + cleanupTaskState(tid) taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + if (state == TaskState.FINISHED) { + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + } } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") + "likely the result of receiving duplicate task finished status updates) or its " + + "executor has been marked as failed.") .format(state, tid)) } } catch { @@ -491,7 +489,7 @@ private[spark] class TaskSchedulerImpl( var failedExecutor: Option[String] = None synchronized { - if (executorIdToTaskCount.contains(executorId)) { + if (executorIdToRunningTaskIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) @@ -533,13 +531,31 @@ private[spark] class TaskSchedulerImpl( logError(s"Lost executor $executorId on $hostPort: $reason") } + /** + * Cleans up the TaskScheduler's state for tracking the given task. + */ + private def cleanupTaskState(tid: Long): Unit = { + taskIdToTaskSetManager.remove(tid) + taskIdToExecutorId.remove(tid).foreach { executorId => + executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) } + } + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status * of any running tasks, since the loss reason defines whether we'll fail those tasks. */ private def removeExecutor(executorId: String, reason: ExecutorLossReason) { - executorIdToTaskCount -= executorId + // The tasks on the lost executor may not send any more status updates (because the executor + // has been lost), so they should be cleaned up here. + executorIdToRunningTaskIds.remove(executorId).foreach { taskIds => + logDebug("Cleaning up TaskScheduler state for tasks " + + s"${taskIds.mkString("[", ",", "]")} on failed executor $executorId") + // We do not notify the TaskSetManager of the task failures because that will + // happen below in the rootPool.executorLost() call. 
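The other half of the same refactor is the cleanup path: when an executor is lost, every task id recorded for it is scrubbed from the per-task maps, since those tasks will never send further status updates. A simplified mirror of `cleanupTaskState`/`removeExecutor` (only one per-task map is kept here, so this is a sketch of the pattern, not the full scheduler state):

```scala
import scala.collection.mutable

object ExecutorLossCleanupSketch {
  private val taskIdToExecutorId = mutable.HashMap.empty[Long, String]
  private val executorIdToRunningTaskIds = mutable.HashMap.empty[String, mutable.HashSet[Long]]

  // Mirrors cleanupTaskState: forget a single task in every map that references it.
  private def cleanupTaskState(tid: Long): Unit = {
    taskIdToExecutorId.remove(tid).foreach { execId =>
      executorIdToRunningTaskIds.get(execId).foreach(_.remove(tid))
    }
  }

  // Mirrors the removeExecutor change: tasks on a lost executor are cleaned up eagerly.
  def removeExecutor(execId: String): Unit = {
    executorIdToRunningTaskIds.remove(execId).foreach(_.foreach(cleanupTaskState))
  }

  def main(args: Array[String]): Unit = {
    executorIdToRunningTaskIds("exec-1") = mutable.HashSet(7L, 8L)
    taskIdToExecutorId ++= Seq(7L -> "exec-1", 8L -> "exec-1")
    removeExecutor("exec-1")
    println(taskIdToExecutorId.isEmpty) // true
  }
}
```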
+ taskIds.foreach(cleanupTaskState) + } val host = executorIdToHost(executorId) val execs = hostToExecutors.getOrElse(host, new HashSet) @@ -577,11 +593,11 @@ private[spark] class TaskSchedulerImpl( } def isExecutorAlive(execId: String): Boolean = synchronized { - executorIdToTaskCount.contains(execId) + executorIdToRunningTaskIds.contains(execId) } def isExecutorBusy(execId: String): Boolean = synchronized { - executorIdToTaskCount.getOrElse(execId, -1) > 0 + executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty) } // By default, rack is unknown diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index b766e4148e496..30df8862c3589 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -221,7 +221,7 @@ private[spark] class TaskSetManager( * Return the pending tasks list for a given host, or an empty list if * there is no map entry for that host */ - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + protected def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { pendingTasksForHost.getOrElse(host, ArrayBuffer()) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index edc8aac5d1515..2406999f9ee92 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable private[spark] object CoarseGrainedClusterMessages { - case object RetrieveSparkProps extends CoarseGrainedClusterMessage + case class RetrieveSparkAppConfig(executorId: String) extends CoarseGrainedClusterMessage + + case class SparkAppConfig( + sparkProperties: Seq[(String, String)], + ioEncryptionKey: Option[Array[Byte]]) + extends CoarseGrainedClusterMessage case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 10d55c87fb8de..89e59353de845 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -206,8 +206,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeExecutor(executorId, reason) context.reply(true) - case RetrieveSparkProps => - context.reply(sparkProperties) + case RetrieveSparkAppConfig(executorId) => + val reply = SparkAppConfig(sparkProperties, + SparkEnv.get.securityManager.getIOEncryptionKey()) + context.reply(reply) } // Make fake resource offers on all executors diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 04d40e2907cff..4a9af80f4537b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -139,7 +139,7 @@ private[spark] 
class StandaloneSchedulerBackend( scheduler.error(reason) } finally { // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + sc.stopInNewThread() } } } diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 8f15f50bee814..8e3436f13480d 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -18,14 +18,13 @@ package org.apache.spark.security import java.io.{InputStream, OutputStream} import java.util.Properties +import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ -import org.apache.hadoop.io.Text import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -33,10 +32,6 @@ import org.apache.spark.internal.config._ * A util class for manipulating IO encryption and decryption streams. */ private[spark] object CryptoStreamUtils extends Logging { - /** - * Constants and variables for spark IO encryption - */ - val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN") // The initialization vector length in bytes. val IV_LENGTH_IN_BYTES = 16 @@ -46,32 +41,30 @@ private[spark] object CryptoStreamUtils extends Logging { val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." /** - * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption. + * Helper method to wrap `OutputStream` with `CryptoOutputStream` for encryption. */ def createCryptoOutputStream( os: OutputStream, - sparkConf: SparkConf): OutputStream = { + sparkConf: SparkConf, + key: Array[Byte]): OutputStream = { val properties = toCryptoConf(sparkConf) val iv = createInitializationVector(properties) os.write(iv) - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_IO_TOKEN) val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoOutputStream(transformationStr, properties, os, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) } /** - * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption. + * Helper method to wrap `InputStream` with `CryptoInputStream` for decryption. */ def createCryptoInputStream( is: InputStream, - sparkConf: SparkConf): InputStream = { + sparkConf: SparkConf, + key: Array[Byte]): InputStream = { val properties = toCryptoConf(sparkConf) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) is.read(iv, 0, iv.length) - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_IO_TOKEN) val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoInputStream(transformationStr, properties, is, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) @@ -91,6 +84,17 @@ private[spark] object CryptoStreamUtils extends Logging { props } + /** + * Creates a new encryption key. + */ + def createKey(conf: SparkConf): Array[Byte] = { + val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) + val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) + keyGen.init(keyLen) + keyGen.generateKey().getEncoded() + } + /** * This method to generate an IV (Initialization Vector) using secure random. 
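The `CryptoStreamUtils` hunks above change the stream wrappers to take an explicit key (created once via `createKey`) instead of reading it from Hadoop credentials. The round-trip sketch below illustrates the new shape of the API; because these helpers are `private[spark]`, it assumes compilation inside Spark (e.g. under `org.apache.spark.security` test sources), and the plaintext is arbitrary.

```scala
package org.apache.spark.security

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkConf

object CryptoRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // The key is generated once (e.g. on the driver) and passed explicitly.
    val key = CryptoStreamUtils.createKey(conf)

    val plain = "some block data".getBytes(StandardCharsets.UTF_8)

    // Encrypt: the wrapper writes the IV first, then the ciphertext.
    val bytesOut = new ByteArrayOutputStream()
    val encrypting = CryptoStreamUtils.createCryptoOutputStream(bytesOut, conf, key)
    encrypting.write(plain)
    encrypting.close()

    // Decrypt with the same key.
    val decrypting = CryptoStreamUtils.createCryptoInputStream(
      new ByteArrayInputStream(bytesOut.toByteArray), conf, key)
    val roundTripped = new Array[Byte](plain.length)
    var offset = 0
    while (offset < roundTripped.length) {
      val n = decrypting.read(roundTripped, offset, roundTripped.length - offset)
      require(n > 0, "unexpected end of stream")
      offset += n
    }
    decrypting.close()

    assert(new String(roundTripped, StandardCharsets.UTF_8) == "some block data")
  }
}
```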
*/ diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 8b72da2ee01b7..f60dcfddfdc20 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -131,7 +131,7 @@ private[spark] class JavaSerializerInstance( * :: DeveloperApi :: * A Spark serializer that uses Java's built-in serialization. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0d26281fe1076..7eb2da1c2748c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -43,9 +43,10 @@ import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, S import org.apache.spark.util.collection.CompactBuffer /** - * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. + * A Spark serializer that uses the + * Kryo serialization library. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index cb95246d5b0ca..afe6cd86059f0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.NextIterator * * 2. Java serialization interface. * - * Note that serializers are not required to be wire-compatible across different versions of Spark. + * @note Serializers are not required to be wire-compatible across different versions of Spark. * They are intended to be used to serialize/de-serialize data within a single Spark application. */ @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 2156d576f1874..686305e9335dc 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea * Component which configures serialization, compression and encryption for various Spark * components, including automatic selection of which [[Serializer]] to use for shuffles. 
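The serializer notes being converted to `@note` above stress that neither the Java nor the Kryo serializer is wire-compatible across Spark versions; they are for data that lives within a single application (shuffle files, cached blocks), not for durable storage. For reference, this is the usual public-API way to opt into Kryo; `MyRecord` is a made-up class for illustration.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object KryoConfigSketch {
  case class MyRecord(id: Int, name: String)

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("kryo-config")
      .setMaster("local[2]")
      // Serialized data only needs to be readable by this application, so a faster,
      // non-wire-stable format like Kryo is acceptable here.
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .registerKryoClasses(Array(classOf[MyRecord]))

    val sc = new SparkContext(conf)
    sc.parallelize(Seq(MyRecord(1, "a"), MyRecord(2, "b"))).map(_.id).sum()
    sc.stop()
  }
}
```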
*/ -private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) { +private[spark] class SerializerManager( + defaultSerializer: Serializer, + conf: SparkConf, + encryptionKey: Option[Array[Byte]]) { + + def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None) private[this] val kryoSerializer = new KryoSerializer(conf) @@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar // Whether to compress shuffle output temporarily spilled to disk private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - // Whether to enable IO encryption - private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED) - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -73,12 +75,17 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) + def encryptionEnabled: Boolean = encryptionKey.isDefined + def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag } - def getSerializer(ct: ClassTag[_]): Serializer = { - if (canUseKryo(ct)) { + // SPARK-18617: As feature in SPARK-13990 can not be applied to Spark Streaming now. The worst + // result is streaming job based on `Receiver` mode can not run on Spark 2.x properly. It may be + // a rational choice to close `kryo auto pick` feature for streaming in the first step. 
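The `getSerializer(ct, autoPick)` change above (SPARK-18617) makes Kryo auto-selection conditional: it is skipped for stream blocks so that `Receiver`-based streaming keeps using the configured default serializer. The standalone sketch below mirrors that decision logic only; it is not the internal `SerializerManager`, and `isStreamBlock` stands in for `blockId.isInstanceOf[StreamBlockId]`.

```scala
import scala.reflect.ClassTag

object SerializerPickSketch {
  // "Safe" element types for which Kryo can be chosen automatically, mirroring
  // primitiveAndPrimitiveArrayClassTags + stringClassTag (abbreviated here).
  private val safeClassTags: Set[ClassTag[_]] = Set(
    ClassTag.Int, ClassTag.Long, ClassTag.Double, ClassTag.Boolean,
    ClassTag(classOf[Array[Int]]), ClassTag(classOf[Array[Long]]),
    ClassTag(classOf[String]))

  def canUseKryo(ct: ClassTag[_]): Boolean = safeClassTags.contains(ct)

  def pickSerializer(ct: ClassTag[_], isStreamBlock: Boolean): String = {
    val autoPick = !isStreamBlock // stream blocks never auto-pick Kryo
    if (autoPick && canUseKryo(ct)) "kryo" else "default"
  }

  def main(args: Array[String]): Unit = {
    println(pickSerializer(ClassTag(classOf[String]), isStreamBlock = false)) // kryo
    println(pickSerializer(ClassTag(classOf[String]), isStreamBlock = true))  // default
  }
}
```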
+ def getSerializer(ct: ClassTag[_], autoPick: Boolean): Serializer = { + if (autoPick && canUseKryo(ct)) { kryoSerializer } else { defaultSerializer @@ -124,15 +131,19 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar /** * Wrap an input stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: InputStream): InputStream = { - if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s + def wrapForEncryption(s: InputStream): InputStream = { + encryptionKey + .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) } + .getOrElse(s) } /** * Wrap an output stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: OutputStream): OutputStream = { - if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s + def wrapForEncryption(s: OutputStream): OutputStream = { + encryptionKey + .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) } + .getOrElse(s) } /** @@ -155,7 +166,8 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar outputStream: OutputStream, values: Iterator[T]): Unit = { val byteStream = new BufferedOutputStream(outputStream) - val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance() ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() } @@ -171,7 +183,8 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar classTag: ClassTag[_]): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) - val ser = getSerializer(classTag).newInstance() + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = getSerializer(classTag, autoPick).newInstance() ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -185,7 +198,8 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar inputStream: InputStream) (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) - getSerializer(classTag) + val autoPick = !blockId.isInstanceOf[StreamBlockId] + getSerializer(classTag, autoPick) .newInstance() .deserializeStream(wrapStream(blockId, stream)) .asIterator.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index f6a9f9c5573db..76af33c1a18db 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -21,7 +21,7 @@ import java.lang.annotation.Annotation import java.lang.reflect.Type import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat -import java.util.{Calendar, SimpleTimeZone} +import java.util.{Calendar, Locale, SimpleTimeZone} import javax.ws.rs.Produces import javax.ws.rs.core.{MediaType, MultivaluedMap} import javax.ws.rs.ext.{MessageBodyWriter, Provider} @@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ private[spark] object JacksonMessageWriter { def makeISODateFormat: SimpleDateFormat = { - val iso8601 = new 
SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US) val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) iso8601.setCalendar(cal) iso8601 diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index 0c71cd2382225..d8d5e8958b23c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.text.{ParseException, SimpleDateFormat} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status @@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status private[v1] class SimpleDateParam(val originalValue: String) { val timestamp: Long = { - val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US) try { format.parse(originalValue).getTime() } catch { case _: ParseException => - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US) gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) try { gmtDay.parse(originalValue).getTime() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 982b83324e0fc..18f7d135acdd2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -62,7 +62,7 @@ private[spark] class BlockManager( executorId: String, rpcEnv: RpcEnv, val master: BlockManagerMaster, - serializerManager: SerializerManager, + val serializerManager: SerializerManager, val conf: SparkConf, memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, @@ -184,8 +184,14 @@ private[spark] class BlockManager( blockManagerId = if (idFromMaster != null) idFromMaster else id shuffleServerId = if (externalShuffleServiceEnabled) { - logInfo(s"external shuffle service port = $externalShuffleServicePort") - BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + val shuffleServerHostName = if (blockManagerId.isDriver) { + blockTransferService.hostName + } else { + conf.get("spark.shuffle.service.host", blockTransferService.hostName) + } + logInfo(s"external shuffle service host = $shuffleServerHostName, " + + s"port = $externalShuffleServicePort") + BlockManagerId(executorId, shuffleServerHostName, externalShuffleServicePort) } else { blockManagerId } @@ -745,9 +751,8 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream, + new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 6bded92700504..d71acbb4cf771 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -43,7 +43,7 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerSlave /** - * Driver -> Executor message to trigger a thread dump. + * Driver to Executor message to trigger a thread dump. */ case object TriggerThreadDump extends ToBlockManagerSlave diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bf087af16a5b1..bb8a684b4c7a8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -89,17 +89,18 @@ class RandomBlockReplicationPolicy prioritizedPeers } + // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage - * [[http://math.stackexchange.com/questions/178690/ - * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]] + * minimizing space usage. Please see + * here. * * @param n total number of indices * @param m number of samples needed * @param r random number generator * @return list of m random unique indices */ + // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a499827ae1598..3cb12fca7dccb 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -22,7 +22,7 @@ import java.nio.channels.FileChannel import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging -import org.apache.spark.serializer.{SerializationStream, SerializerInstance} +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.util.Utils /** @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class DiskBlockObjectWriter( val file: File, + serializerManager: SerializerManager, serializerInstance: SerializerInstance, bufferSize: Int, - wrapStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. @@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter( initialized = true } - bs = wrapStream(mcs) + bs = serializerManager.wrapStream(blockId, mcs) objOut = serializerInstance.serializeStream(bs) streamOpen = true this diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 4dc2f362329a0..269c12d6da444 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -247,7 +247,7 @@ final class ShuffleBlockFetcherIterator( /** * Fetch the local blocks while we are fetching remote blocks. 
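The `BlockReplicationPolicy` hunk above only shows the start of the Floyd sampling helper. For readers unfamiliar with the algorithm, here is a complete standalone version of the same routine: it draws `m` distinct values from `1..n` without materializing the full range, adding exactly one new index per iteration. The `assert` and the driver in `main` are illustrative additions, not part of the patch.

```scala
import scala.util.Random

object FloydSamplingSketch {
  def getSampleIds(n: Int, m: Int, r: Random): List[Int] = {
    // For i in (n-m+1)..n, pick t uniformly from 1..i; if t was already chosen,
    // take i instead. Each step adds one new element, so the result has size m.
    val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) { case (set, i) =>
      val t = r.nextInt(i) + 1
      if (set.contains(t)) set + i else set + t
    }
    assert(indices.size == m, s"expected $m distinct samples, got ${indices.size}")
    indices.toList
  }

  def main(args: Array[String]): Unit = {
    println(getSampleIds(n = 10, m = 3, r = new Random(42)))
  }
}
```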
This is ok because - * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { @@ -423,7 +423,7 @@ object ShuffleBlockFetcherIterator { * @param address BlockManager that the block was fetched from. * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. - * @param buf [[ManagedBuffer]] for the content. + * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ private[storage] case class SuccessFetchResult( diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index fb9941bbd9e0f..e12f2e6095d5a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -71,7 +71,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * contains, get, and size. */ @@ -80,7 +80,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the RDD blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * getting the memory, disk, and off-heap memory sizes occupied by this RDD. */ @@ -128,7 +128,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return whether the given block is stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. */ def containsBlock(blockId: BlockId): Boolean = { blockId match { @@ -141,7 +142,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the given block stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.get`, which is O(blocks) time. */ def getBlock(blockId: BlockId): Option[BlockStatus] = { blockId match { @@ -154,19 +156,22 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the number of blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.size`, which is O(blocks) time. */ def numBlocks: Int = _nonRddBlocks.size + numRddBlocks /** * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. 
+ * + * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. */ def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum /** * Return the number of blocks that belong to the given RDD in O(1) time. - * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * + * @note This is much faster than `this.rddBlocksById(rddId).size`, which is * O(blocks in this RDD) time. */ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 095d32407f345..fff21218b1769 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} -import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} +import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel, StreamBlockId} import org.apache.spark.unsafe.Platform import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -334,7 +334,8 @@ private[spark] class MemoryStore( val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator) redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { - val ser = serializerManager.getSerializer(classTag).newInstance() + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c0d1a2220f62a..d161843dd2230 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -36,7 +36,8 @@ private[spark] object UIUtils extends Logging { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } def formatDate(date: Date): String = dateFormat.get.format(date) @@ -170,6 +171,7 @@ private[spark] object UIUtils extends Logging { + } def vizHeaderNodes: Seq[Node] = { @@ -420,8 +422,8 @@ private[spark] object UIUtils extends Logging { * the whole string will rendered as a simple escaped text. * * Note: In terms of security, only anchor tags with root relative links are supported. So any - * attempts to embed links outside Spark UI, or other tags like