
Commit b3fb144

Merge pull request apache-spark-on-k8s#368 from palantir/pw/upstream
update
2 parents: 17f07ed + acb6c4f, commit b3fb144

46 files changed: 1310 additions, 421 deletions

common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java

Lines changed: 3 additions & 1 deletion
@@ -32,6 +32,7 @@
 import io.netty.channel.ChannelOption;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
+import org.apache.commons.lang3.SystemUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -98,7 +99,8 @@ private void init(String hostToBind, int portToBind) {
       .group(bossGroup, workerGroup)
       .channel(NettyUtils.getServerChannelClass(ioMode))
       .option(ChannelOption.ALLOCATOR, allocator)
-      .childOption(ChannelOption.ALLOCATOR, allocator);
+      .childOption(ChannelOption.ALLOCATOR, allocator)
+      .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS);

     this.metrics = new NettyMemoryMetrics(
       allocator, conf.getModuleName() + "-server", conf);
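For context: SO_REUSEADDR is left off on Windows because there the flag lets a second socket bind a port that is still actively in use, rather than merely allowing reuse of a port stuck in TIME_WAIT as on Unix. A minimal sketch, not part of this commit, of the same OS-conditional toggle on a plain server socket (SystemUtils is the same commons-lang3 helper the patch imports):

import java.net.{InetSocketAddress, ServerSocket}
import org.apache.commons.lang3.SystemUtils

object ReuseAddressSketch {
  def main(args: Array[String]): Unit = {
    val server = new ServerSocket()
    // Mirror the Netty childOption above: allow address reuse except on Windows.
    server.setReuseAddress(!SystemUtils.IS_OS_WINDOWS)
    server.bind(new InetSocketAddress("localhost", 0)) // port 0 = pick any free port
    println(s"bound to ${server.getLocalPort}, reuseAddress=${server.getReuseAddress}")
    server.close()
  }
}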

common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java

Lines changed: 5 additions & 1 deletion
@@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) {
   }

   public static int roundNumberOfBytesToNearestWord(int numBytes) {
-    int remainder = numBytes & 0x07;  // This is equivalent to `numBytes % 8`
+    return (int)roundNumberOfBytesToNearestWord((long)numBytes);
+  }
+
+  public static long roundNumberOfBytesToNearestWord(long numBytes) {
+    long remainder = numBytes & 0x07;  // This is equivalent to `numBytes % 8`
     if (remainder == 0) {
       return numBytes;
     } else {
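The new long overload rounds a byte count up to the next multiple of 8 (one word), and the int variant now delegates to it, so the arithmetic is done in long and cannot overflow near Integer.MAX_VALUE. A small standalone sketch of the same rounding, for illustration only:

object WordRoundingSketch {
  // Same arithmetic as roundNumberOfBytesToNearestWord: round up to a multiple of 8 bytes.
  def roundToWord(numBytes: Long): Long = {
    val remainder = numBytes & 0x07L // equivalent to numBytes % 8
    if (remainder == 0) numBytes else numBytes + (8 - remainder)
  }

  def main(args: Array[String]): Unit = {
    assert(roundToWord(13L) == 16L)  // 13 rounds up to the next word boundary
    assert(roundToWord(16L) == 16L)  // already word-aligned
    assert(roundToWord(Int.MaxValue.toLong + 2) == Int.MaxValue.toLong + 9) // no int overflow
  }
}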

common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java

Lines changed: 4 additions & 0 deletions
@@ -120,6 +120,8 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) {
     } catch (Exception expected) {
       Assert.assertThat(expected.getMessage(), containsString("should not be larger than"));
     }
+
+    memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER);
   }

   @Test
@@ -165,11 +167,13 @@ public void testOffHeapArrayMemoryBlock() {
     int length = 56;

     check(memory, obj, offset, length);
+    memoryAllocator.free(memory);

     long address = Platform.allocateMemory(112);
     memory = new OffHeapMemoryBlock(address, length);
     obj = memory.getBaseObject();
     offset = memory.getBaseOffset();
     check(memory, obj, offset, length);
+    Platform.freeMemory(address);
   }
 }
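The added free calls keep the suite from leaking the off-heap memory it allocates. A rough sketch, assuming Spark's unsafe memory API is on the classpath, of the allocate/free pairing the test now follows:

import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryAllocator

object OffHeapPairingSketch {
  def main(args: Array[String]): Unit = {
    // Memory obtained from the allocator is returned through the allocator...
    val block = MemoryAllocator.UNSAFE.allocate(64)
    Platform.putLong(block.getBaseObject, block.getBaseOffset, 42L)
    MemoryAllocator.UNSAFE.free(block)

    // ...while raw addresses from Platform.allocateMemory are released with freeMemory.
    val address = Platform.allocateMemory(64)
    Platform.putLong(null, address, 7L)
    Platform.freeMemory(address)
  }
}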

core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ private[spark] abstract class RestSubmissionServer(
       resolvedConnectionFactories: _*)
     connector.setHost(host)
     connector.setPort(startPort)
+    connector.setReuseAddress(!Utils.isWindows)
     server.addConnector(connector)

     val mainHandler = new ServletContextHandler

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ private[spark] class TaskSetManager(
     None
   }

-  /** Check whether a task is currently running an attempt on a given host */
+  /** Check whether a task once ran an attempt on a given host */
   private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
     taskAttempts(taskIndex).exists(_.host == host)
   }

core/src/main/scala/org/apache/spark/ui/JettyUtils.scala

Lines changed: 1 addition & 0 deletions
@@ -344,6 +344,7 @@ private[spark] object JettyUtils extends Logging {
       connectionFactories: _*)
     connector.setPort(port)
     connector.setHost(hostName)
+    connector.setReuseAddress(!Utils.isWindows)

     // Currently we only use "SelectChannelConnector"
     // Limit the max acceptor number to 8 so that we don't waste a lot of threads

core/src/test/scala/org/apache/spark/SparkContextSuite.scala

Lines changed: 26 additions & 35 deletions
@@ -20,7 +20,7 @@ package org.apache.spark
 import java.io.File
 import java.net.{MalformedURLException, URI}
 import java.nio.charset.StandardCharsets
-import java.util.concurrent.{Semaphore, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}

 import scala.concurrent.duration._

@@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu

   test("Cancelling stages/jobs with custom reasons.") {
     sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
     val REASON = "You shall not pass"
-    val slices = 10

-    val listener = new SparkListener {
-      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-        if (SparkContextSuite.cancelStage) {
-          eventually(timeout(10.seconds)) {
-            assert(SparkContextSuite.isTaskStarted)
+    for (cancelWhat <- Seq("stage", "job")) {
+      // This countdown latch used to make sure stage or job canceled in listener
+      val latch = new CountDownLatch(1)
+
+      val listener = cancelWhat match {
+        case "stage" =>
+          new SparkListener {
+            override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+              sc.cancelStage(taskStart.stageId, REASON)
+              latch.countDown()
+            }
           }
-          sc.cancelStage(taskStart.stageId, REASON)
-          SparkContextSuite.cancelStage = false
-          SparkContextSuite.semaphore.release(slices)
-        }
-      }
-
-      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
-        if (SparkContextSuite.cancelJob) {
-          eventually(timeout(10.seconds)) {
-            assert(SparkContextSuite.isTaskStarted)
+        case "job" =>
+          new SparkListener {
+            override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+              sc.cancelJob(jobStart.jobId, REASON)
+              latch.countDown()
+            }
           }
-          sc.cancelJob(jobStart.jobId, REASON)
-          SparkContextSuite.cancelJob = false
-          SparkContextSuite.semaphore.release(slices)
-        }
       }
-    }
-    sc.addSparkListener(listener)
-
-    for (cancelWhat <- Seq("stage", "job")) {
-      SparkContextSuite.semaphore.drainPermits()
-      SparkContextSuite.isTaskStarted = false
-      SparkContextSuite.cancelStage = (cancelWhat == "stage")
-      SparkContextSuite.cancelJob = (cancelWhat == "job")
+      sc.addSparkListener(listener)

       val ex = intercept[SparkException] {
-        sc.range(0, 10000L, numSlices = slices).mapPartitions { x =>
-          SparkContextSuite.isTaskStarted = true
-          // Block waiting for the listener to cancel the stage or job.
-          SparkContextSuite.semaphore.acquire()
+        sc.range(0, 10000L, numSlices = 10).mapPartitions { x =>
+          x.synchronized {
+            x.wait()
+          }
           x
         }.count()
       }
@@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
         fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
       }

+      latch.await(20, TimeUnit.SECONDS)
       eventually(timeout(20.seconds)) {
         assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
       }
+      sc.removeSparkListener(listener)
     }
   }

@@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
 }

 object SparkContextSuite {
-  @volatile var cancelJob = false
-  @volatile var cancelStage = false
   @volatile var isTaskStarted = false
   @volatile var taskKilled = false
   @volatile var taskSucceeded = false
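Outside the test harness, the cancellation API exercised here can be used roughly as in the following sketch (the app name, sleep, and property key are illustrative): a listener cancels the job with a custom reason, and that reason surfaces in the SparkException seen by the caller.

import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

object CancelWithReasonSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cancel-demo").setMaster("local"))
    // Ask executors to interrupt running task threads when the job is cancelled.
    sc.setLocalProperty("spark.job.interruptOnCancel", "true")
    sc.addSparkListener(new SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        sc.cancelJob(jobStart.jobId, "You shall not pass")
      }
    })
    try {
      sc.range(0, 10000L, numSlices = 10).map { i => Thread.sleep(10); i }.count()
    } catch {
      case e: SparkException => println(e.getMessage) // message includes the custom reason
    } finally {
      sc.stop()
    }
  }
}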

external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ class KafkaRDD[

   private def fetchBatch: Iterator[MessageAndOffset] = {
     val req = new FetchRequestBuilder()
+      .clientId(consumer.clientId)
       .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes)
       .build()
     val resp = consumer.fetch(req)
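A sketch of what this one-line fix changes, with placeholder broker, topic, and size values and Kafka 0.8's SimpleConsumer API assumed on the classpath: the fetch request now carries the consumer's clientId instead of the builder's default, so broker-side request logging and metrics attribute the fetch to this consumer.

import kafka.api.FetchRequestBuilder
import kafka.consumer.SimpleConsumer

object ClientIdSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder connection values; SimpleConsumer connects lazily, so nothing is sent here.
    val consumer = new SimpleConsumer("localhost", 9092, 30000, 65536, "my-client-id")
    val req = new FetchRequestBuilder()
      .clientId(consumer.clientId) // propagate the consumer's clientId, as in the fix above
      .addFetch("my-topic", 0, 0L, 1024 * 1024)
      .build()
    println(req.clientId) // "my-client-id" rather than the builder's default
    consumer.close()
  }
}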
Lines changed: 3 additions & 1 deletion
@@ -1,2 +1,4 @@
 org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
-org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
+org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
+org.apache.spark.ml.clustering.InternalKMeansModelWriter
+org.apache.spark.ml.clustering.PMMLKMeansModelWriter

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 50 additions & 25 deletions
@@ -17,11 +17,13 @@

 package org.apache.spark.ml.clustering

+import scala.collection.mutable
+
 import org.apache.hadoop.fs.Path

 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Since("1.5.0")
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel)
-  extends Model[KMeansModel] with KMeansParams with MLWritable {
+    private[clustering] val parentModel: MLlibKMeansModel)
+  extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {

   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
@@ -152,14 +154,14 @@ class KMeansModel private[ml] (
   }

   /**
-   * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+   * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
    *
    * For [[KMeansModel]], this does NOT currently save the training [[summary]].
    * An option to save [[summary]] may be added in the future.
    *
    */
   @Since("1.6.0")
-  override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+  override def write: GeneralMLWriter = new GeneralMLWriter(this)

   private var trainingSummary: Option[KMeansSummary] = None

@@ -185,6 +187,47 @@ class KMeansModel private[ml] (
   }
 }

+/** Helper class for storing model data */
+private case class ClusterData(clusterIdx: Int, clusterCenter: Vector)
+
+
+/** A writer for KMeans that handles the "internal" (or default) format */
+private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+  override def format(): String = "internal"
+  override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    val instance = stage.asInstanceOf[KMeansModel]
+    val sc = sparkSession.sparkContext
+    // Save metadata and Params
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    // Save model data: cluster centers
+    val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
+      case (center, idx) =>
+        ClusterData(idx, center)
+    }
+    val dataPath = new Path(path, "data").toString
+    sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
+  }
+}
+
+/** A writer for KMeans that handles the "pmml" format */
+private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+  override def format(): String = "pmml"
+  override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    val instance = stage.asInstanceOf[KMeansModel]
+    val sc = sparkSession.sparkContext
+    instance.parentModel.toPMML(sc, path)
+  }
+}
+
+
 @Since("1.6.0")
 object KMeansModel extends MLReadable[KMeansModel] {

@@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
   @Since("1.6.0")
   override def load(path: String): KMeansModel = super.load(path)

-  /** Helper class for storing model data */
-  private case class Data(clusterIdx: Int, clusterCenter: Vector)
-
   /**
    * We store all cluster centers in a single row and use this class to store model data by
    * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility.
    */
   private case class OldData(clusterCenters: Array[OldVector])

-  /** [[MLWriter]] instance for [[KMeansModel]] */
-  private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
-
-    override protected def saveImpl(path: String): Unit = {
-      // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
-      // Save model data: cluster centers
-      val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
-        Data(idx, center)
-      }
-      val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
-    }
-  }
-
   private class KMeansModelReader extends MLReader[KMeansModel] {

     /** Checked against metadata when loading model */
@@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val dataPath = new Path(path, "data").toString

       val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
-        val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
+        val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {
         // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
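With KMeansModel now extending GeneralMLWritable and the two writers registered for it, callers choose the export format at save time. A usage sketch with placeholder paths and toy data:

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object KMeansExportSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("kmeans-export").master("local").getOrCreate()
    val data = spark.createDataFrame(Seq(
      Tuple1(Vectors.dense(0.0, 0.0)),
      Tuple1(Vectors.dense(1.0, 1.0)),
      Tuple1(Vectors.dense(9.0, 8.0)),
      Tuple1(Vectors.dense(8.0, 9.0))
    )).toDF("features")

    val model = new KMeans().setK(2).setSeed(1L).fit(data)

    // Default ("internal") format: Spark metadata plus parquet cluster centers.
    model.write.overwrite().save("/tmp/kmeans-internal")
    // PMML format, routed to PMMLKMeansModelWriter through the service registration above.
    model.write.format("pmml").save("/tmp/kmeans-pmml")

    spark.stop()
  }
}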
