Skip to content

Commit fc65e0f

Browse files
committed
[SPARK-27839][SQL] Change UTF8String.replace() to operate on UTF8 bytes
## What changes were proposed in this pull request?

This PR significantly improves the performance of `UTF8String.replace()` by performing direct replacement over UTF8 bytes instead of decoding those bytes into Java Strings. In cases where the search string is not found (i.e. no replacements are performed, a case which I expect to be common) this new implementation performs no object allocation or memory copying.

My implementation is modeled after `commons-lang3`'s `StringUtils.replace()` method. As part of my implementation, I needed a StringBuilder / resizable buffer, so I moved `UTF8StringBuilder` from the `catalyst` package to `unsafe`.

## How was this patch tested?

Copied tests from `StringExpressionSuite` to `UTF8StringSuite` and added a couple of new cases.

To evaluate performance, I did some quick local benchmarking by running the following code in `spark-shell` (with Java 1.8.0_191):

```scala
import org.apache.spark.unsafe.types.UTF8String

def benchmark(text: String, search: String, replace: String) {
  val utf8Text = UTF8String.fromString(text)
  val utf8Search = UTF8String.fromString(search)
  val utf8Replace = UTF8String.fromString(replace)

  val start = System.currentTimeMillis
  var i = 0
  while (i < 1000 * 1000 * 100) {
    utf8Text.replace(utf8Search, utf8Replace)
    i += 1
  }
  val end = System.currentTimeMillis

  println(end - start)
}

benchmark("ABCDEFGH", "DEF", "ZZZZ")  // replacement occurs
benchmark("ABCDEFGH", "Z", "")  // no replacement occurs
```

On my laptop this took ~54 / ~40 seconds before this patch's changes and ~6.5 / ~3.8 seconds afterwards.

Closes apache#24707 from JoshRosen/faster-string-replace.

Authored-by: Josh Rosen <[email protected]>
Signed-off-by: Josh Rosen <[email protected]>
1 parent fe5145e commit fc65e0f

File tree

5 files changed

+86
-7
lines changed

5 files changed

+86
-7
lines changed
Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.sql.catalyst.expressions.codegen;
18+
package org.apache.spark.unsafe;
1919

20-
import org.apache.spark.unsafe.Platform;
2120
import org.apache.spark.unsafe.array.ByteArrayMethods;
2221
import org.apache.spark.unsafe.types.UTF8String;
2322

@@ -34,7 +33,18 @@ public class UTF8StringBuilder {
3433

3534
public UTF8StringBuilder() {
3635
// Since initial buffer size is 16 in `StringBuilder`, we set the same size here
37-
this.buffer = new byte[16];
36+
this(16);
37+
}
38+
39+
public UTF8StringBuilder(int initialSize) {
40+
if (initialSize < 0) {
41+
throw new IllegalArgumentException("Size must be non-negative");
42+
}
43+
if (initialSize > ARRAY_MAX) {
44+
throw new IllegalArgumentException(
45+
"Size " + initialSize + " exceeded maximum size of " + ARRAY_MAX);
46+
}
47+
this.buffer = new byte[initialSize];
3848
}
3949

4050
// Grows the buffer by at least `neededSize`
@@ -72,6 +82,17 @@ public void append(String value) {
7282
append(UTF8String.fromString(value));
7383
}
7484

85+
public void appendBytes(Object base, long offset, int length) {
86+
grow(length);
87+
Platform.copyMemory(
88+
base,
89+
offset,
90+
buffer,
91+
cursor,
92+
length);
93+
cursor += length;
94+
}
95+
7596
public UTF8String build() {
7697
return UTF8String.fromBytes(buffer, 0, totalSize());
7798
}

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import com.google.common.primitives.Ints;
3333

3434
import org.apache.spark.unsafe.Platform;
35+
import org.apache.spark.unsafe.UTF8StringBuilder;
3536
import org.apache.spark.unsafe.array.ByteArrayMethods;
3637
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
3738

@@ -1002,12 +1003,29 @@ public UTF8String[] split(UTF8String pattern, int limit) {
10021003
}
10031004

10041005
public UTF8String replace(UTF8String search, UTF8String replace) {
1005-
if (EMPTY_UTF8.equals(search)) {
1006+
// This implementation is loosely based on commons-lang3's StringUtils.replace().
1007+
if (numBytes == 0 || search.numBytes == 0) {
10061008
return this;
10071009
}
1008-
String replaced = toString().replace(
1009-
search.toString(), replace.toString());
1010-
return fromString(replaced);
1010+
// Find the first occurrence of the search string.
1011+
int start = 0;
1012+
int end = this.find(search, start);
1013+
if (end == -1) {
1014+
// Search string was not found, so string is unchanged.
1015+
return this;
1016+
}
1017+
// At least one match was found. Estimate space needed for result.
1018+
// The 16x multiplier here is chosen to match commons-lang3's implementation.
1019+
int increase = Math.max(0, replace.numBytes - search.numBytes) * 16;
1020+
final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase);
1021+
while (end != -1) {
1022+
buf.appendBytes(this.base, this.offset + start, end - start);
1023+
buf.append(replace);
1024+
start = end + search.numBytes;
1025+
end = this.find(search, start);
1026+
}
1027+
buf.appendBytes(this.base, this.offset + start, numBytes - start);
1028+
return buf.build();
10111029
}
10121030

10131031
// TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes

common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,44 @@ public void split() {
403403
new UTF8String[]{fromString("ab"), fromString("def,ghi,")}));
404404
}
405405

406+
@Test
407+
public void replace() {
408+
assertEquals(
409+
fromString("re123ace"),
410+
fromString("replace").replace(fromString("pl"), fromString("123")));
411+
assertEquals(
412+
fromString("reace"),
413+
fromString("replace").replace(fromString("pl"), fromString("")));
414+
assertEquals(
415+
fromString("replace"),
416+
fromString("replace").replace(fromString(""), fromString("123")));
417+
// tests for multiple replacements
418+
assertEquals(
419+
fromString("a12ca12c"),
420+
fromString("abcabc").replace(fromString("b"), fromString("12")));
421+
assertEquals(
422+
fromString("adad"),
423+
fromString("abcdabcd").replace(fromString("bc"), fromString("")));
424+
// tests for single character search and replacement strings
425+
assertEquals(
426+
fromString("AbcAbc"),
427+
fromString("abcabc").replace(fromString("a"), fromString("A")));
428+
assertEquals(
429+
fromString("abcabc"),
430+
fromString("abcabc").replace(fromString("Z"), fromString("A")));
431+
// Tests with non-ASCII characters
432+
assertEquals(
433+
fromString("花ab界"),
434+
fromString("花花世界").replace(fromString("花世"), fromString("ab")));
435+
assertEquals(
436+
fromString("a水c"),
437+
fromString("a火c").replace(fromString("火"), fromString("水")));
438+
// Tests for a large number of replacements, triggering UTF8StringBuilder resize
439+
assertEquals(
440+
fromString("abcd").repeat(17),
441+
fromString("a").repeat(17).replace(fromString("a"), fromString("abcd")));
442+
}
443+
406444
@Test
407445
public void levenshteinDistance() {
408446
assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8));

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2929
import org.apache.spark.sql.catalyst.util._
3030
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
3131
import org.apache.spark.sql.types._
32+
import org.apache.spark.unsafe.UTF8StringBuilder
3233
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
3334
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
3435

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._
3030
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
3131
import org.apache.spark.sql.internal.SQLConf
3232
import org.apache.spark.sql.types._
33+
import org.apache.spark.unsafe.UTF8StringBuilder
3334
import org.apache.spark.unsafe.array.ByteArrayMethods
3435
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
3536
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

0 commit comments

Comments
 (0)