|
18 | 18 | package org.apache.spark.broadcast
|
19 | 19 |
|
20 | 20 | import java.io._
|
| 21 | +import java.lang.ref.SoftReference |
21 | 22 | import java.nio.ByteBuffer
|
22 | 23 | import java.util.zip.Adler32
|
23 | 24 |
|
@@ -61,9 +62,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
|
61 | 62 | * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
|
62 | 63 | * which builds this value by reading blocks from the driver and/or other executors.
|
63 | 64 | *
|
64 |
| - * On the driver, if the value is required, it is read lazily from the block manager. |
| 65 | + * On the driver, if the value is required, it is read lazily from the block manager. We hold |
| 66 | + * a soft reference so that it can be garbage collected if required, as we can always reconstruct |
| 67 | + * it in the future. |
65 | 68 | */
|
66 |
| - @transient private lazy val _value: T = readBroadcastBlock() |
| 69 | + @transient private var _value: SoftReference[T] = _ |
67 | 70 |
|
68 | 71 | /** The compression codec to use, or None if compression is disabled */
|
69 | 72 | @transient private var compressionCodec: Option[CompressionCodec] = _
|
@@ -92,8 +95,15 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
|
92 | 95 | /** The checksum for all the blocks. */
|
93 | 96 | private var checksums: Array[Int] = _
|
94 | 97 |
|
95 |
| - override protected def getValue() = { |
96 |
| - _value |
| 98 | + override protected def getValue() = synchronized { |
    // Returns the broadcast value, re-reading it via readBroadcastBlock() when no memoized
    // copy is available. The whole body runs inside `synchronized`, so the check of `_value`
    // and its replacement happen as one step for callers of this method on the same instance.
    //
    // `memoized` is null in two cases: `_value` itself has not been set yet, or the
    // SoftReference it points to no longer holds a referent (`get` returned null).
| 99 | + val memoized: T = if (_value == null) null.asInstanceOf[T] else _value.get |
| 100 | + if (memoized != null) { |
| 101 | + memoized |
| 102 | + } else { |
    // Nothing memoized: read the value again and remember it behind a fresh soft reference.
    // The strongly-held local `newlyRead` is what gets returned, not `_value.get`.
| 103 | + val newlyRead = readBroadcastBlock() |
| 104 | + _value = new SoftReference[T](newlyRead) |
| 105 | + newlyRead |
| 106 | + } |
97 | 107 | }
|
98 | 108 |
|
99 | 109 | private def calcChecksum(block: ByteBuffer): Int = {
|
@@ -205,8 +215,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
|
205 | 215 | }
|
206 | 216 |
|
207 | 217 | private def readBroadcastBlock(): T = Utils.tryOrIOException {
|
208 |
| - TorrentBroadcast.synchronized { |
209 |
| - val broadcastCache = SparkEnv.get.broadcastManager.cachedValues |
| 218 | + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues |
| 219 | + broadcastCache.synchronized { |
210 | 220 |
|
211 | 221 | Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
|
212 | 222 | setConf(SparkEnv.get.conf)
|
|
0 commit comments