Skip to content

Commit 56c1879

Browse files
authored
Use unsafe to speed up string marshaling (#6433)
1 parent d9cef81 commit 56c1879

File tree

12 files changed

+375
-64
lines changed

12 files changed

+375
-64
lines changed

exporters/common/build.gradle.kts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies {
1414
api(project(":sdk-extensions:autoconfigure-spi"))
1515

1616
compileOnly(project(":sdk:common"))
17+
compileOnly(project(":exporters:common:compile-stub"))
1718

1819
compileOnly("org.codehaus.mojo:animal-sniffer-annotations")
1920

@@ -22,6 +23,8 @@ dependencies {
2223
// We include helpers shared by gRPC exporters but do not want to impose these
2324
// dependency on all of our consumers.
2425
compileOnly("com.fasterxml.jackson.core:jackson-core")
26+
// Note: the compile-stub project dependency declared above supplies a trimmed down
27+
// sun.misc.Unsafe to compile against, because the compiler doesn't find the one from the JDK.
2528
compileOnly("io.grpc:grpc-stub")
2629

2730
testImplementation(project(":sdk:common"))
@@ -31,6 +34,7 @@ dependencies {
3134
testImplementation("org.skyscreamer:jsonassert")
3235
testImplementation("com.google.api.grpc:proto-google-common-protos")
3336
testImplementation("io.grpc:grpc-testing")
37+
testImplementation("edu.berkeley.cs.jqf:jqf-fuzz")
3438
testRuntimeOnly("io.grpc:grpc-netty-shaded")
3539
}
3640

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
plugins {
2+
id("otel.java-conventions")
3+
}
4+
5+
description = "OpenTelemetry Exporter Compile Stub"
6+
otelJava.moduleName.set("io.opentelemetry.exporter.internal.compile-stub")
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright The OpenTelemetry Authors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package sun.misc;
7+
8+
import java.lang.reflect.Field;
9+
10+
/**
 * Compile-time stand-in for {@code sun.misc.Unsafe}.
 *
 * <p>The compiler does not find the JDK's own {@code sun.misc.Unsafe}, so this trimmed down class
 * declares just the methods we need to compile against. Every method body here is a placeholder
 * that returns a fixed value.
 */
public class Unsafe {

  /** Placeholder: always returns {@code -1}. */
  public long objectFieldOffset(Field f) {
    return -1L;
  }

  /** Placeholder: always returns {@code null}. */
  public Object getObject(Object o, long offset) {
    return null;
  }

  /** Placeholder: always returns {@code 0}. */
  public byte getByte(Object o, long offset) {
    return (byte) 0;
  }

  /** Placeholder: always returns {@code 0}. */
  public int arrayBaseOffset(Class<?> arrayClass) {
    return 0;
  }

  /** Placeholder: always returns {@code 0}. */
  public long getLong(Object o, long offset) {
    return 0L;
  }
}

exporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/MarshalerContext.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
*/
2525
public final class MarshalerContext {
2626
private final boolean marshalStringNoAllocation;
27+
private final boolean marshalStringUnsafe;
2728

2829
private int[] sizes = new int[16];
2930
private int sizeReadIndex;
@@ -32,19 +33,23 @@ public final class MarshalerContext {
3233
private int dataReadIndex;
3334
private int dataWriteIndex;
3435

35-
@SuppressWarnings("BooleanParameter")
3636
public MarshalerContext() {
37-
this(true);
37+
this(/* marshalStringNoAllocation= */ true, /* marshalStringUnsafe= */ true);
3838
}
3939

40-
public MarshalerContext(boolean marshalStringNoAllocation) {
40+
public MarshalerContext(boolean marshalStringNoAllocation, boolean marshalStringUnsafe) {
4141
this.marshalStringNoAllocation = marshalStringNoAllocation;
42+
this.marshalStringUnsafe = marshalStringUnsafe;
4243
}
4344

4445
public boolean marshalStringNoAllocation() {
4546
return marshalStringNoAllocation;
4647
}
4748

49+
public boolean marshalStringUnsafe() {
50+
return marshalStringUnsafe;
51+
}
52+
4853
public void addSize(int size) {
4954
growSizeIfNeeded();
5055
sizes[sizeWriteIndex++] = size;

exporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/StatelessMarshalerUtil.java

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,16 +299,63 @@ public static <K, V> int sizeMessageWithContext(
299299
}
300300

301301
/** Returns the size of utf8 encoded string in bytes. */
302-
@SuppressWarnings("UnusedVariable")
303302
private static int getUtf8Size(String string, MarshalerContext context) {
304-
return getUtf8Size(string);
303+
return getUtf8Size(string, context.marshalStringUnsafe());
305304
}
306305

307306
// Visible for testing
308-
static int getUtf8Size(String string) {
307+
static int getUtf8Size(String string, boolean useUnsafe) {
308+
if (useUnsafe && UnsafeString.isAvailable() && UnsafeString.isLatin1(string)) {
309+
byte[] bytes = UnsafeString.getBytes(string);
310+
// latin1 bytes with negative value (most significant bit set) are encoded as 2 bytes in utf8
311+
return string.length() + countNegative(bytes);
312+
}
313+
309314
return encodedUtf8Length(string);
310315
}
311316

317+
// Inner loop can process at most 8 * 255 bytes without overflowing counter. To process more bytes
318+
// inner loop has to be run multiple times.
319+
private static final int MAX_INNER_LOOP_SIZE = 8 * 255;
320+
// mask that selects only the most significant bit in every byte of the long
321+
private static final long MOST_SIGNIFICANT_BIT_MASK = 0x8080808080808080L;
322+
323+
/** Returns the count of bytes with negative value. */
324+
private static int countNegative(byte[] bytes) {
325+
int count = 0;
326+
int offset = 0;
327+
// We are processing one long (8 bytes) at a time. In the inner loop we are keeping counts in a
328+
// long where each byte in the long is a separate counter. Due to this the inner loop can
329+
// process a maximum of 8*255 bytes at a time without overflow.
330+
for (int i = 1; i <= bytes.length / MAX_INNER_LOOP_SIZE + 1; i++) {
331+
long tmp = 0; // each byte in this long is a separate counter
332+
int limit = Math.min(i * MAX_INNER_LOOP_SIZE, bytes.length & ~7);
333+
for (; offset < limit; offset += 8) {
334+
long value = UnsafeString.getLong(bytes, offset);
335+
// Mask the value keeping only the most significant bit in each byte and then shift this bit
336+
// to the position of the least significant bit in each byte. If the input byte was not
337+
// negative then after this transformation it will be zero, if it was negative then it will
338+
// be one.
339+
tmp += (value & MOST_SIGNIFICANT_BIT_MASK) >>> 7;
340+
}
341+
// sum up counts
342+
if (tmp != 0) {
343+
for (int j = 0; j < 8; j++) {
344+
count += (int) (tmp & 0xff);
345+
tmp = tmp >>> 8;
346+
}
347+
}
348+
}
349+
350+
// Handle remaining bytes. Previous loop processes 8 bytes a time, if the input size is not
351+
// divisible with 8 the remaining bytes are handled here.
352+
for (int i = offset; i < bytes.length; i++) {
353+
// same as if (bytes[i] < 0) count++;
354+
count += bytes[i] >>> 31;
355+
}
356+
return count;
357+
}
358+
312359
// adapted from
313360
// https://github.com/protocolbuffers/protobuf/blob/b618f6750aed641a23d5f26fbbaf654668846d24/java/core/src/main/java/com/google/protobuf/Utf8.java#L217
314361
private static int encodedUtf8Length(String string) {
@@ -376,14 +423,24 @@ private static int encodedUtf8LengthGeneral(String string, int start) {
376423
static void writeUtf8(
377424
CodedOutputStream output, String string, int utf8Length, MarshalerContext context)
378425
throws IOException {
379-
writeUtf8(output, string, utf8Length);
426+
writeUtf8(output, string, utf8Length, context.marshalStringUnsafe());
380427
}
381428

382429
// Visible for testing
383430
@SuppressWarnings("UnusedVariable") // utf8Length argument is added for future use
384-
static void writeUtf8(CodedOutputStream output, String string, int utf8Length)
431+
static void writeUtf8(CodedOutputStream output, String string, int utf8Length, boolean useUnsafe)
385432
throws IOException {
386-
encodeUtf8(output, string);
433+
// if the length of the latin1 string and the utf8 output are the same then the string must be
434+
// composed of only 7bit characters and can be directly copied to the output
435+
if (useUnsafe
436+
&& UnsafeString.isAvailable()
437+
&& string.length() == utf8Length
438+
&& UnsafeString.isLatin1(string)) {
439+
byte[] bytes = UnsafeString.getBytes(string);
440+
output.write(bytes, 0, bytes.length);
441+
} else {
442+
encodeUtf8(output, string);
443+
}
387444
}
388445

389446
// encode utf8 the same way as length is computed in encodedUtf8Length
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright The OpenTelemetry Authors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package io.opentelemetry.exporter.internal.marshal;
7+
8+
import io.opentelemetry.api.internal.ConfigUtil;
9+
import java.lang.reflect.Field;
10+
import sun.misc.Unsafe;
11+
12+
class UnsafeAccess {
13+
private static final int MAX_ENABLED_JAVA_VERSION = 22;
14+
private static final boolean available = checkUnsafe();
15+
16+
static boolean isAvailable() {
17+
return available;
18+
}
19+
20+
private static boolean checkUnsafe() {
21+
double javaVersion = getJavaVersion();
22+
boolean unsafeEnabled =
23+
Boolean.parseBoolean(
24+
ConfigUtil.getString(
25+
"otel.java.experimental.exporter.unsafe.enabled",
26+
javaVersion != -1 && javaVersion <= MAX_ENABLED_JAVA_VERSION ? "true" : "false"));
27+
if (!unsafeEnabled) {
28+
return false;
29+
}
30+
31+
try {
32+
Class.forName("sun.misc.Unsafe", false, UnsafeAccess.class.getClassLoader());
33+
return UnsafeHolder.UNSAFE != null;
34+
} catch (ClassNotFoundException e) {
35+
return false;
36+
}
37+
}
38+
39+
private static double getJavaVersion() {
40+
String specVersion = System.getProperty("java.specification.version");
41+
if (specVersion != null) {
42+
try {
43+
return Double.parseDouble(specVersion);
44+
} catch (NumberFormatException exception) {
45+
// ignore
46+
}
47+
}
48+
return -1;
49+
}
50+
51+
static long objectFieldOffset(Field field) {
52+
return UnsafeHolder.UNSAFE.objectFieldOffset(field);
53+
}
54+
55+
static Object getObject(Object object, long offset) {
56+
return UnsafeHolder.UNSAFE.getObject(object, offset);
57+
}
58+
59+
static byte getByte(Object object, long offset) {
60+
return UnsafeHolder.UNSAFE.getByte(object, offset);
61+
}
62+
63+
static int arrayBaseOffset(Class<?> arrayClass) {
64+
return UnsafeHolder.UNSAFE.arrayBaseOffset(arrayClass);
65+
}
66+
67+
static long getLong(Object o, long offset) {
68+
return UnsafeHolder.UNSAFE.getLong(o, offset);
69+
}
70+
71+
private UnsafeAccess() {}
72+
73+
private static class UnsafeHolder {
74+
public static final Unsafe UNSAFE;
75+
76+
static {
77+
UNSAFE = getUnsafe();
78+
}
79+
80+
private UnsafeHolder() {}
81+
82+
@SuppressWarnings("NullAway")
83+
private static Unsafe getUnsafe() {
84+
try {
85+
Field field = Unsafe.class.getDeclaredField("theUnsafe");
86+
field.setAccessible(true);
87+
return (Unsafe) field.get(null);
88+
} catch (Exception ignored) {
89+
return null;
90+
}
91+
}
92+
}
93+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright The OpenTelemetry Authors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package io.opentelemetry.exporter.internal.marshal;
7+
8+
import java.lang.reflect.Field;
9+
10+
class UnsafeString {
11+
private static final long valueOffset = getStringFieldOffset("value", byte[].class);
12+
private static final long coderOffset = getStringFieldOffset("coder", byte.class);
13+
private static final int byteArrayBaseOffset = UnsafeAccess.arrayBaseOffset(byte[].class);
14+
private static final boolean available = valueOffset != -1 && coderOffset != -1;
15+
16+
static boolean isAvailable() {
17+
return available;
18+
}
19+
20+
static boolean isLatin1(String string) {
21+
// 0 represents latin1, 1 utf16
22+
return UnsafeAccess.getByte(string, coderOffset) == 0;
23+
}
24+
25+
static byte[] getBytes(String string) {
26+
return (byte[]) UnsafeAccess.getObject(string, valueOffset);
27+
}
28+
29+
static long getLong(byte[] bytes, int index) {
30+
return UnsafeAccess.getLong(bytes, byteArrayBaseOffset + index);
31+
}
32+
33+
private static long getStringFieldOffset(String fieldName, Class<?> expectedType) {
34+
if (!UnsafeAccess.isAvailable()) {
35+
return -1;
36+
}
37+
38+
try {
39+
Field field = String.class.getDeclaredField(fieldName);
40+
if (field.getType() != expectedType) {
41+
return -1;
42+
}
43+
return UnsafeAccess.objectFieldOffset(field);
44+
} catch (Exception exception) {
45+
return -1;
46+
}
47+
}
48+
49+
private UnsafeString() {}
50+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright The OpenTelemetry Authors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package io.opentelemetry.exporter.internal.marshal;
7+
8+
import static io.opentelemetry.exporter.internal.marshal.StatelessMarshalerUtil.getUtf8Size;
9+
import static io.opentelemetry.exporter.internal.marshal.StatelessMarshalerUtilTest.testUtf8;
10+
import static org.assertj.core.api.Assertions.assertThat;
11+
12+
import edu.berkeley.cs.jqf.fuzz.Fuzz;
13+
import edu.berkeley.cs.jqf.fuzz.JQF;
14+
import edu.berkeley.cs.jqf.fuzz.junit.GuidedFuzzing;
15+
import edu.berkeley.cs.jqf.fuzz.random.NoGuidance;
16+
import java.nio.charset.StandardCharsets;
17+
import org.junit.jupiter.api.Test;
18+
import org.junit.runner.Result;
19+
import org.junit.runner.RunWith;
20+
21+
@SuppressWarnings("SystemOut")
22+
class StatelessMarshalerUtilFuzzTest {
23+
24+
@RunWith(JQF.class)
25+
public static class EncodeUf8 {
26+
27+
@Fuzz
28+
public void encodeRandomString(String value) {
29+
int utf8Size = value.getBytes(StandardCharsets.UTF_8).length;
30+
assertThat(getUtf8Size(value, false)).isEqualTo(utf8Size);
31+
assertThat(getUtf8Size(value, true)).isEqualTo(utf8Size);
32+
assertThat(testUtf8(value, utf8Size, /* useUnsafe= */ false)).isEqualTo(value);
33+
assertThat(testUtf8(value, utf8Size, /* useUnsafe= */ true)).isEqualTo(value);
34+
}
35+
}
36+
37+
// driver methods to avoid having to use the vintage junit engine, and to enable increasing the
38+
// number of iterations:
39+
40+
@Test
41+
void encodeUf8WithFuzzing() {
42+
Result result =
43+
GuidedFuzzing.run(
44+
EncodeUf8.class, "encodeRandomString", new NoGuidance(10000, System.out), System.out);
45+
assertThat(result.wasSuccessful()).isTrue();
46+
}
47+
}

0 commit comments

Comments
 (0)