Skip to content

Commit 628c980

Browse files
committed
Adjust DummyStatsDServers
1 parent ac1935a commit 628c980

File tree

4 files changed

+79
-32
lines changed

4 files changed

+79
-32
lines changed

src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,15 @@ protected static Callable<SocketAddress> staticUnixResolution(
376376
final UnixSocketAddressWithTransport.TransportType transportType) {
377377
return new Callable<SocketAddress>() {
378378
@Override public SocketAddress call() {
379+
if (ClientChannelUtils.hasNativeUdsSupport()) {
380+
try {
381+
Class<?> udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress");
382+
Object udsAddress = udsAddressClass.getMethod("of", String.class).invoke(null, path);
383+
return new UnixSocketAddressWithTransport((SocketAddress) udsAddress, transportType);
384+
} catch (ReflectiveOperationException e) {
385+
// Fall back to JNR implementation
386+
}
387+
}
379388
final UnixSocketAddress socketAddress = new UnixSocketAddress(path);
380389
return new UnixSocketAddressWithTransport(socketAddress, transportType);
381390
}

src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,7 @@ private void connectJdkSocket(long deadline) throws IOException {
120120

121121
try {
122122
delegate.configureBlocking(false);
123-
// Use reflection to call the UDS specific connect method
124-
if (!delegate.connect(udsAddress)) {
123+
if (!delegate.connect((SocketAddress) udsAddress)) {
125124
if (connectionTimeout > 0 && System.nanoTime() > deadline) {
126125
throw new IOException("Connection timed out");
127126
}

src/test/java/com/timgroup/statsd/UnixDatagramSocketDummyStatsDServer.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import java.io.IOException;
44
import java.nio.ByteBuffer;
55
import java.nio.channels.DatagramChannel;
6+
import java.net.SocketAddress;
67
import jnr.unixsocket.UnixDatagramChannel;
78
import jnr.unixsocket.UnixSocketAddress;
89

@@ -13,8 +14,22 @@ public class UnixDatagramSocketDummyStatsDServer extends DummyStatsDServer {
1314
private volatile boolean running = true;
1415

1516
public UnixDatagramSocketDummyStatsDServer(String socketPath) throws IOException {
16-
server = UnixDatagramChannel.open();
17-
server.bind(new UnixSocketAddress(socketPath));
17+
if (ClientChannelUtils.hasNativeUdsSupport()) {
18+
try {
19+
Class<?> udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress");
20+
Object udsAddress = udsAddressClass.getMethod("of", String.class).invoke(null, socketPath);
21+
22+
DatagramChannel nativeServer = DatagramChannel.open();
23+
nativeServer.bind((SocketAddress) udsAddress);
24+
this.server = nativeServer;
25+
} catch (ReflectiveOperationException e) {
26+
throw new IOException(e);
27+
}
28+
} else {
29+
UnixDatagramChannel jnrServer = UnixDatagramChannel.open();
30+
jnrServer.bind(new UnixSocketAddress(socketPath));
31+
this.server = jnrServer;
32+
}
1833
this.listen();
1934
}
2035

src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import java.nio.Buffer;
55
import java.nio.ByteBuffer;
66
import java.nio.ByteOrder;
7+
import java.nio.channels.ServerSocketChannel;
78
import java.nio.channels.SocketChannel;
9+
import java.net.SocketAddress;
810
import java.util.concurrent.ConcurrentLinkedQueue;
911
import java.util.logging.Logger;
1012
import jnr.unixsocket.UnixServerSocketChannel;
@@ -14,21 +16,37 @@
1416
import static com.timgroup.statsd.NonBlockingStatsDClient.DEFAULT_UDS_MAX_PACKET_SIZE_BYTES;
1517

1618
public class UnixStreamSocketDummyStatsDServer extends DummyStatsDServer {
17-
private final UnixServerSocketChannel server;
18-
private final ConcurrentLinkedQueue<UnixSocketChannel> channels = new ConcurrentLinkedQueue<>();
19+
private final Object server; // Object is either ServerSocketChannel or UnixServerSocketChannel
20+
private final ConcurrentLinkedQueue<SocketChannel> channels = new ConcurrentLinkedQueue<>();
21+
private final boolean useNativeUds;
1922

2023
private final Logger logger = Logger.getLogger(UnixStreamSocketDummyStatsDServer.class.getName());
2124

2225
public UnixStreamSocketDummyStatsDServer(String socketPath) throws IOException {
23-
server = UnixServerSocketChannel.open();
24-
server.configureBlocking(true);
25-
server.socket().bind(new UnixSocketAddress(socketPath));
26+
this.useNativeUds = ClientChannelUtils.hasNativeUdsSupport();
27+
if (useNativeUds) {
28+
try {
29+
Class<?> udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress");
30+
Object udsAddress = udsAddressClass.getMethod("of", String.class).invoke(null, socketPath);
31+
32+
ServerSocketChannel nativeServer = ServerSocketChannel.open();
33+
nativeServer.bind((SocketAddress) udsAddress);
34+
this.server = nativeServer;
35+
} catch (ReflectiveOperationException e) {
36+
throw new IOException(e);
37+
}
38+
} else {
39+
UnixServerSocketChannel jnrServer = UnixServerSocketChannel.open();
40+
jnrServer.configureBlocking(true);
41+
jnrServer.socket().bind(new UnixSocketAddress(socketPath));
42+
this.server = jnrServer;
43+
}
2644
this.listen();
2745
}
2846

2947
@Override
3048
protected boolean isOpen() {
31-
return server.isOpen();
49+
return useNativeUds ? ((ServerSocketChannel)server).isOpen() : ((UnixServerSocketChannel)server).isOpen();
3250
}
3351

3452
@Override
@@ -38,39 +56,43 @@ protected void receive(ByteBuffer packet) throws IOException {
3856

3957
@Override
4058
protected void listen() {
41-
logger.info("Listening on " + server.getLocalSocketAddress());
59+
try {
60+
String localAddressMessage = useNativeUds ? "Listening on " + ((ServerSocketChannel)server).getLocalAddress() : "Listening on " + ((UnixServerSocketChannel)server).getLocalSocketAddress();
61+
logger.info(localAddressMessage);
62+
} catch (Exception e) {
63+
logger.warning("Failed to get local address: " + e);
64+
}
4265
Thread thread = new Thread(new Runnable() {
4366
@Override
4467
public void run() {
4568
while(isOpen()) {
4669
if (sleepIfFrozen()) {
4770
continue;
4871
}
49-
try {
50-
logger.info("Waiting for connection");
51-
UnixSocketChannel clientChannel = server.accept();
52-
if (clientChannel != null) {
53-
clientChannel.configureBlocking(true);
54-
try {
55-
logger.info("Accepted connection from " + clientChannel.getRemoteSocketAddress());
56-
} catch (Exception e) {
57-
logger.warning("Failed to get remote socket address");
58-
}
59-
channels.add(clientChannel);
60-
readChannel(clientChannel);
61-
}
62-
} catch (IOException e) {
72+
try {
73+
logger.info("Waiting for connection");
74+
SocketChannel clientChannel = null;
75+
clientChannel = useNativeUds ? ((ServerSocketChannel)server).accept() : ((UnixServerSocketChannel)server).accept();
76+
if (clientChannel != null) {
77+
clientChannel.configureBlocking(true);
78+
String connectionMessage = useNativeUds ? "Accepted connection from " + clientChannel.getRemoteAddress() : "Accepted connection from " + ((UnixSocketChannel)clientChannel).getRemoteSocketAddress();
79+
logger.info(connectionMessage);
80+
channels.add(clientChannel);
81+
readChannel(clientChannel);
6382
}
83+
} catch (Exception e) {
84+
// ignore
85+
}
6486
}
6587
}
6688
});
6789
thread.setDaemon(true);
6890
thread.start();
6991
}
7092

71-
public void readChannel(final UnixSocketChannel clientChannel) {
72-
logger.info("Reading from " + clientChannel);
73-
Thread thread = new Thread(new Runnable() {
93+
public void readChannel(final SocketChannel clientChannel) {
94+
logger.info("Reading from " + clientChannel);
95+
Thread thread = new Thread(new Runnable() {
7496
@Override
7597
public void run() {
7698
final ByteBuffer packet = ByteBuffer.allocate(DEFAULT_UDS_MAX_PACKET_SIZE_BYTES);
@@ -90,7 +112,6 @@ public void run() {
90112
logger.warning("Failed to close channel: " + e);
91113
}
92114
}
93-
94115
}
95116
logger.info("Disconnected from " + clientChannel);
96117
}
@@ -128,13 +149,16 @@ private boolean readPacket(SocketChannel channel, ByteBuffer packet) {
128149

129150
public void close() throws IOException {
130151
try {
131-
server.close();
132-
for (UnixSocketChannel channel : channels) {
152+
if (useNativeUds) {
153+
((ServerSocketChannel)server).close();
154+
} else {
155+
((UnixServerSocketChannel)server).close();
156+
}
157+
for (SocketChannel channel : channels) {
133158
channel.close();
134159
}
135160
} catch (Exception e) {
136161
//ignore
137162
}
138163
}
139-
140164
}

0 commit comments

Comments
 (0)