Skip to content

Commit 7fe2f5e

Browse files
committed
[SPARK-53248][CORE] Support checkedCast in JavaUtils
### What changes were proposed in this pull request?

This PR aims to support `checkedCast` in `JavaUtils`. In addition, new Scalastyle and Checkstyle rules are added to prevent future regressions.

### Why are the changes needed?

To improve Spark's Java utility features.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51977 from dongjoon-hyun/SPARK-53248.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent b82957c commit 7fe2f5e

File tree

8 files changed

+35
-22
lines changed

8 files changed

+35
-22
lines changed

common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.Properties;
2323
import java.util.concurrent.TimeUnit;
2424
import com.google.common.base.Preconditions;
25-
import com.google.common.primitives.Ints;
2625
import io.netty.util.NettyRuntime;
2726

2827
/**
@@ -171,7 +170,7 @@ public int ioRetryWaitTimeMs() {
171170
* memory mapping has high overhead for blocks close to or below the page size of the OS.
172171
*/
173172
public int memoryMapBytes() {
174-
return Ints.checkedCast(JavaUtils.byteStringAsBytes(
173+
return JavaUtils.checkedCast(JavaUtils.byteStringAsBytes(
175174
conf.get("spark.storage.memoryMapThreshold", "2m")));
176175
}
177176

@@ -248,7 +247,7 @@ public boolean saslEncryption() {
248247
* Maximum number of bytes to be encrypted at a time when SASL encryption is used.
249248
*/
250249
public int maxSaslEncryptedBlockSize() {
251-
return Ints.checkedCast(JavaUtils.byteStringAsBytes(
250+
return JavaUtils.checkedCast(JavaUtils.byteStringAsBytes(
252251
conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
253252
}
254253

@@ -263,7 +262,7 @@ public boolean saslServerAlwaysEncrypt() {
263262
* When Secure (SSL/TLS) Shuffle is enabled, the Chunk size to use for shuffling files.
264263
*/
265264
public int sslShuffleChunkSize() {
266-
return Ints.checkedCast(JavaUtils.byteStringAsBytes(
265+
return JavaUtils.checkedCast(JavaUtils.byteStringAsBytes(
267266
conf.get("spark.network.ssl.maxEncryptedBlockSize", "64k")));
268267
}
269268

@@ -567,7 +566,7 @@ public String mergedShuffleFileManagerImpl() {
567566
* service unnecessarily.
568567
*/
569568
public int minChunkSizeInMergedShuffleFile() {
570-
return Ints.checkedCast(JavaUtils.byteStringAsBytes(
569+
return JavaUtils.checkedCast(JavaUtils.byteStringAsBytes(
571570
conf.get("spark.shuffle.push.server.minChunkSizeInMergedShuffleFile", "2m")));
572571
}
573572

common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import java.util.List;
2121

22-
import com.google.common.primitives.Ints;
2322
import io.netty.buffer.Unpooled;
2423
import io.netty.channel.ChannelHandlerContext;
2524
import io.netty.channel.FileRegion;
@@ -44,6 +43,7 @@
4443
import org.apache.spark.network.protocol.StreamRequest;
4544
import org.apache.spark.network.protocol.StreamResponse;
4645
import org.apache.spark.network.util.ByteArrayWritableChannel;
46+
import org.apache.spark.network.util.JavaUtils;
4747
import org.apache.spark.network.util.NettyUtils;
4848

4949
public class ProtocolSuite {
@@ -115,7 +115,8 @@ private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegio
115115
public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
116116
throws Exception {
117117

118-
ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
118+
ByteArrayWritableChannel channel =
119+
new ByteArrayWritableChannel(JavaUtils.checkedCast(in.count()));
119120
while (in.transferred() < in.count()) {
120121
in.transferTo(channel, in.transferred());
121122
}

common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
import java.nio.ByteOrder;
2121
import java.util.Arrays;
2222

23-
import com.google.common.primitives.Ints;
24-
2523
import org.apache.spark.unsafe.Platform;
24+
import org.apache.spark.network.util.JavaUtils;
2625

2726
public final class ByteArray {
2827

@@ -169,7 +168,7 @@ public static byte[] concatWS(byte[] delimiter, byte[]... inputs) {
169168
}
170169
if (totalLength > 0) totalLength -= delimiter.length;
171170
// Allocate a new byte array, and copy the inputs one by one into it
172-
final byte[] result = new byte[Ints.checkedCast(totalLength)];
171+
final byte[] result = new byte[JavaUtils.checkedCast(totalLength)];
173172
int offset = 0;
174173
for (int i = 0; i < inputs.length; i++) {
175174
byte[] input = inputs[i];

common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,4 +656,11 @@ public static String stackTraceToString(Throwable t) {
656656
}
657657
return out.toString(StandardCharsets.UTF_8);
658658
}
659+
660+
public static int checkedCast(long value) {
661+
if (value > Integer.MAX_VALUE || value < Integer.MIN_VALUE) {
662+
throw new IllegalArgumentException("Cannot cast to integer.");
663+
}
664+
return (int) value;
665+
}
659666
}

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717

1818
package org.apache.spark.util.collection.unsafe.sort;
1919

20-
import com.google.common.primitives.Ints;
21-
2220
import org.apache.spark.unsafe.Platform;
2321
import org.apache.spark.unsafe.array.LongArray;
22+
import org.apache.spark.network.util.JavaUtils;
2423

2524
public class RadixSort {
2625

@@ -63,7 +62,7 @@ public static int sort(
6362
}
6463
}
6564
}
66-
return Ints.checkedCast(inIndex);
65+
return JavaUtils.checkedCast(inIndex);
6766
}
6867

6968
/**
@@ -204,7 +203,7 @@ public static int sortKeyPrefixArray(
204203
}
205204
}
206205
}
207-
return Ints.checkedCast(inIndex);
206+
return JavaUtils.checkedCast(inIndex);
208207
}
209208

210209
/**

core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ import java.util.{Arrays, Comparator}
2222

2323
import scala.util.Random
2424

25-
import com.google.common.primitives.Ints
26-
2725
import org.apache.spark.SparkFunSuite
26+
import org.apache.spark.network.util.JavaUtils.checkedCast
2827
import org.apache.spark.unsafe.array.LongArray
2928
import org.apache.spark.unsafe.memory.MemoryBlock
3029
import org.apache.spark.util.collection.Sorter
@@ -75,21 +74,21 @@ class RadixSortSuite extends SparkFunSuite {
7574
2, 4, false, false, true))
7675

7776
private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
78-
val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand }
79-
val extended = ref ++ createArray(Ints.checkedCast(size), 0L)
77+
val ref = Array.tabulate[Long](checkedCast(size)) { i => rand }
78+
val extended = ref ++ createArray(checkedCast(size), 0L)
8079
(ref.map(i => JLong.valueOf(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
8180
}
8281

8382
private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
84-
val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand }
85-
val extended = ref ++ createArray(Ints.checkedCast(size * 2), 0L)
83+
val ref = Array.tabulate[Long](checkedCast(size * 2)) { i => rand }
84+
val extended = ref ++ createArray(checkedCast(size * 2), 0L)
8685
(new LongArray(MemoryBlock.fromLongArray(ref)),
8786
new LongArray(MemoryBlock.fromLongArray(extended)))
8887
}
8988

9089
private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
9190
var i = 0
92-
val out = new Array[Long](Ints.checkedCast(length))
91+
val out = new Array[Long](checkedCast(length))
9392
while (i < length) {
9493
out(i) = array.get(offset + i)
9594
i += 1
@@ -112,7 +111,7 @@ class RadixSortSuite extends SparkFunSuite {
112111
refCmp: PrefixComparator): Unit = {
113112
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
114113
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
115-
buf, Ints.checkedCast(lo), Ints.checkedCast(hi),
114+
buf, checkedCast(lo), checkedCast(hi),
116115
(r1: RecordPointerAndKeyPrefix, r2: RecordPointerAndKeyPrefix) =>
117116
refCmp.compare(r1.keyPrefix, r2.keyPrefix))
118117
}

dev/checkstyle.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@
199199
<property name="illegalClasses" value="com.google.common.io.BaseEncoding" />
200200
<property name="illegalClasses" value="com.google.common.io.Files" />
201201
</module>
202+
<module name="RegexpSinglelineJava">
203+
<property name="format" value="Ints\.checkedCast"/>
204+
<property name="message" value="Use JavaUtils.checkedCast instead." />
205+
</module>
202206
<module name="RegexpSinglelineJava">
203207
<property name="format" value="Charset\.defaultCharset"/>
204208
<property name="message" value="Use StandardCharsets.UTF_8 instead." />

scalastyle-config.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,4 +826,9 @@ This file is divided into 3 sections:
826826
<parameters><parameter name="regex">\bPreconditions\.checkNotNull\b</parameter></parameters>
827827
<customMessage>Use requireNonNull of java.util.Objects instead.</customMessage>
828828
</check>
829+
830+
<check customId="intscheckedcast" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
831+
<parameters><parameter name="regex">\bInts\.checkedCast\b</parameter></parameters>
832+
<customMessage>Use JavaUtils.checkedCast instead.</customMessage>
833+
</check>
829834
</scalastyle>

0 commit comments

Comments
 (0)