Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ enum SpringWebHttpAttributesGetter

@Nullable private static final MethodHandle GET_STATUS_CODE;
@Nullable private static final MethodHandle STATUS_CODE_VALUE;
@Nullable private static final MethodHandle GET_HEADERS;
private static final MethodHandle GET_HEADERS;

static {
MethodHandles.Lookup lookup = MethodHandles.publicLookup();
Expand Down Expand Up @@ -58,21 +58,24 @@ enum SpringWebHttpAttributesGetter

GET_STATUS_CODE = getStatusCode;
STATUS_CODE_VALUE = statusCodeValue;
GET_HEADERS =
isSpring7OrNewer()
? findGetHeadersMethod(MethodType.methodType(List.class, String.class))
: findGetHeadersMethod(MethodType.methodType(List.class, Object.class));
}

// since spring web 7.0
MethodHandle methodHandle =
findGetHeadersMethod(MethodType.methodType(List.class, String.class, List.class));
if (methodHandle == null) {
// up to spring web 7.0
methodHandle =
findGetHeadersMethod(MethodType.methodType(Object.class, Object.class, Object.class));
private static boolean isSpring7OrNewer() {
try {
Class.forName("org.springframework.core.Nullness");
return true;
} catch (ClassNotFoundException e) {
return false;
}
GET_HEADERS = methodHandle;
}

private static MethodHandle findGetHeadersMethod(MethodType methodType) {
try {
return MethodHandles.lookup().findVirtual(HttpHeaders.class, "getOrDefault", methodType);
return MethodHandles.lookup().findVirtual(HttpHeaders.class, "get", methodType);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be interesting to know why getOrDefault failed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but I don't have an answer

} catch (Throwable t) {
return null;
}
Expand All @@ -95,15 +98,16 @@ public List<String> getHttpRequestHeader(HttpRequest httpRequest, String name) {
}

@SuppressWarnings("unchecked") // casting MethodHandle.invoke result
private static List<String> getHeader(HttpHeaders headers, String name) {
if (GET_HEADERS != null) {
try {
return (List<String>) GET_HEADERS.invoke(headers, name, emptyList());
} catch (Throwable t) {
// ignore
static List<String> getHeader(HttpHeaders headers, String name) {
try {
List<String> result = (List<String>) GET_HEADERS.invoke(headers, name);
if (result == null) {
return emptyList();
}
return result;
} catch (Throwable t) {
throw new IllegalStateException(t);
}
return emptyList();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,46 @@
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.List;
import javax.annotation.Nullable;
import org.springframework.http.HttpHeaders;

class HeaderUtil {
@Nullable private static final MethodHandle GET_HEADERS;
private static final MethodHandle GET_HEADERS;

static {
// since spring web 7.0
MethodHandle methodHandle =
findGetHeadersMethod(MethodType.methodType(List.class, String.class, List.class));
if (methodHandle == null) {
// up to spring web 7.0
methodHandle =
findGetHeadersMethod(MethodType.methodType(Object.class, Object.class, Object.class));
GET_HEADERS =
isSpring7OrNewer()
? findGetHeadersMethod(MethodType.methodType(List.class, String.class))
: findGetHeadersMethod(MethodType.methodType(List.class, Object.class));
}

private static boolean isSpring7OrNewer() {
try {
Class.forName("org.springframework.core.Nullness");
return true;
} catch (ClassNotFoundException e) {
return false;
}
GET_HEADERS = methodHandle;
}

private static MethodHandle findGetHeadersMethod(MethodType methodType) {
try {
return MethodHandles.lookup().findVirtual(HttpHeaders.class, "getOrDefault", methodType);
return MethodHandles.lookup().findVirtual(HttpHeaders.class, "get", methodType);
} catch (Throwable t) {
return null;
throw new IllegalStateException(t);
}
}

@SuppressWarnings("unchecked") // casting MethodHandle.invoke result
static List<String> getHeader(HttpHeaders headers, String name) {
if (GET_HEADERS != null) {
try {
return (List<String>) GET_HEADERS.invoke(headers, name, emptyList());
} catch (Throwable t) {
// ignore
try {
List<String> result = (List<String>) GET_HEADERS.invoke(headers, name);
if (result == null) {
return emptyList();
}
return result;
} catch (Throwable t) {
throw new IllegalStateException(t);
}
return emptyList();
}

private HeaderUtil() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,26 @@ public enum WebClientHttpAttributesGetter
private static final MethodHandle GET_HEADERS;

static {
// since webflux 7.0
MethodHandle methodHandle =
findGetHeadersMethod(MethodType.methodType(List.class, String.class, List.class));
if (methodHandle == null) {
// up to webflux 7.0
methodHandle =
findGetHeadersMethod(MethodType.methodType(Object.class, Object.class, Object.class));
GET_HEADERS =
isSpring7OrNewer()
? findGetHeadersMethod(MethodType.methodType(List.class, String.class))
: findGetHeadersMethod(MethodType.methodType(List.class, Object.class));
}

private static boolean isSpring7OrNewer() {
try {
Class.forName("org.springframework.core.Nullness");
return true;
} catch (ClassNotFoundException e) {
return false;
}
GET_HEADERS = methodHandle;
}

private static MethodHandle findGetHeadersMethod(MethodType methodType) {
try {
return MethodHandles.lookup().findVirtual(HttpHeaders.class, "getOrDefault", methodType);
return MethodHandles.lookup().findVirtual(HttpHeaders.class, "get", methodType);
} catch (Throwable t) {
return null;
throw new IllegalStateException(t);
}
}

Expand All @@ -60,14 +64,15 @@ public String getHttpRequestMethod(ClientRequest request) {
@Override
@SuppressWarnings("unchecked") // casting MethodHandle.invoke result
public List<String> getHttpRequestHeader(ClientRequest request, String name) {
if (GET_HEADERS != null) {
try {
return (List<String>) GET_HEADERS.invoke(request.headers(), name, emptyList());
} catch (Throwable t) {
// ignore
try {
List<String> result = (List<String>) GET_HEADERS.invoke(request.headers(), name);
if (result == null) {
return emptyList();
}
return result;
} catch (Throwable t) {
throw new IllegalStateException(t);
}
return emptyList();
}

@Override
Expand Down
Loading