44import java .nio .Buffer ;
55import java .nio .ByteBuffer ;
66import java .nio .ByteOrder ;
7+ import java .nio .channels .ServerSocketChannel ;
78import java .nio .channels .SocketChannel ;
9+ import java .net .SocketAddress ;
810import java .util .concurrent .ConcurrentLinkedQueue ;
911import java .util .logging .Logger ;
1012import jnr .unixsocket .UnixServerSocketChannel ;
1416import static com .timgroup .statsd .NonBlockingStatsDClient .DEFAULT_UDS_MAX_PACKET_SIZE_BYTES ;
1517
1618public class UnixStreamSocketDummyStatsDServer extends DummyStatsDServer {
17- private final UnixServerSocketChannel server ;
18- private final ConcurrentLinkedQueue <UnixSocketChannel > channels = new ConcurrentLinkedQueue <>();
19+ private final Object server ; // Object is either ServerSocketChannel or UnixServerSocketChannel
20+ private final ConcurrentLinkedQueue <SocketChannel > channels = new ConcurrentLinkedQueue <>();
21+ private final boolean useNativeUds ;
1922
2023 private final Logger logger = Logger .getLogger (UnixStreamSocketDummyStatsDServer .class .getName ());
2124
2225 public UnixStreamSocketDummyStatsDServer (String socketPath ) throws IOException {
23- server = UnixServerSocketChannel .open ();
24- server .configureBlocking (true );
25- server .socket ().bind (new UnixSocketAddress (socketPath ));
26+ this .useNativeUds = ClientChannelUtils .hasNativeUdsSupport ();
27+ if (useNativeUds ) {
28+ try {
29+ Class <?> udsAddressClass = Class .forName ("java.net.UnixDomainSocketAddress" );
30+ Object udsAddress = udsAddressClass .getMethod ("of" , String .class ).invoke (null , socketPath );
31+
32+ ServerSocketChannel nativeServer = ServerSocketChannel .open ();
33+ nativeServer .bind ((SocketAddress ) udsAddress );
34+ this .server = nativeServer ;
35+ } catch (ReflectiveOperationException e ) {
36+ throw new IOException (e );
37+ }
38+ } else {
39+ UnixServerSocketChannel jnrServer = UnixServerSocketChannel .open ();
40+ jnrServer .configureBlocking (true );
41+ jnrServer .socket ().bind (new UnixSocketAddress (socketPath ));
42+ this .server = jnrServer ;
43+ }
2644 this .listen ();
2745 }
2846
2947 @ Override
3048 protected boolean isOpen () {
31- return server .isOpen ();
49+ return useNativeUds ? (( ServerSocketChannel ) server ). isOpen () : (( UnixServerSocketChannel ) server ) .isOpen ();
3250 }
3351
3452 @ Override
@@ -38,39 +56,43 @@ protected void receive(ByteBuffer packet) throws IOException {
3856
3957 @ Override
4058 protected void listen () {
41- logger .info ("Listening on " + server .getLocalSocketAddress ());
59+ try {
60+ String localAddressMessage = useNativeUds ? "Listening on " + ((ServerSocketChannel )server ).getLocalAddress () : "Listening on " + ((UnixServerSocketChannel )server ).getLocalSocketAddress ();
61+ logger .info (localAddressMessage );
62+ } catch (Exception e ) {
63+ logger .warning ("Failed to get local address: " + e );
64+ }
4265 Thread thread = new Thread (new Runnable () {
4366 @ Override
4467 public void run () {
4568 while (isOpen ()) {
4669 if (sleepIfFrozen ()) {
4770 continue ;
4871 }
49- try {
50- logger .info ("Waiting for connection" );
51- UnixSocketChannel clientChannel = server .accept ();
52- if (clientChannel != null ) {
53- clientChannel .configureBlocking (true );
54- try {
55- logger .info ("Accepted connection from " + clientChannel .getRemoteSocketAddress ());
56- } catch (Exception e ) {
57- logger .warning ("Failed to get remote socket address" );
58- }
59- channels .add (clientChannel );
60- readChannel (clientChannel );
61- }
62- } catch (IOException e ) {
72+ try {
73+ logger .info ("Waiting for connection" );
74+ SocketChannel clientChannel = null ;
75+ clientChannel = useNativeUds ? ((ServerSocketChannel )server ).accept () : ((UnixServerSocketChannel )server ).accept ();
76+ if (clientChannel != null ) {
77+ clientChannel .configureBlocking (true );
78+ String connectionMessage = useNativeUds ? "Accepted connection from " + clientChannel .getRemoteAddress () : "Accepted connection from " + ((UnixSocketChannel )clientChannel ).getRemoteSocketAddress ();
79+ logger .info (connectionMessage );
80+ channels .add (clientChannel );
81+ readChannel (clientChannel );
6382 }
83+ } catch (Exception e ) {
84+ // ignore
85+ }
6486 }
6587 }
6688 });
6789 thread .setDaemon (true );
6890 thread .start ();
6991 }
7092
71- public void readChannel (final UnixSocketChannel clientChannel ) {
72- logger .info ("Reading from " + clientChannel );
73- Thread thread = new Thread (new Runnable () {
93+ public void readChannel (final SocketChannel clientChannel ) {
94+ logger .info ("Reading from " + clientChannel );
95+ Thread thread = new Thread (new Runnable () {
7496 @ Override
7597 public void run () {
7698 final ByteBuffer packet = ByteBuffer .allocate (DEFAULT_UDS_MAX_PACKET_SIZE_BYTES );
@@ -90,7 +112,6 @@ public void run() {
90112 logger .warning ("Failed to close channel: " + e );
91113 }
92114 }
93-
94115 }
95116 logger .info ("Disconnected from " + clientChannel );
96117 }
@@ -128,13 +149,16 @@ private boolean readPacket(SocketChannel channel, ByteBuffer packet) {
128149
129150 public void close () throws IOException {
130151 try {
131- server .close ();
132- for (UnixSocketChannel channel : channels ) {
152+ if (useNativeUds ) {
153+ ((ServerSocketChannel )server ).close ();
154+ } else {
155+ ((UnixServerSocketChannel )server ).close ();
156+ }
157+ for (SocketChannel channel : channels ) {
133158 channel .close ();
134159 }
135160 } catch (Exception e ) {
136161 //ignore
137162 }
138163 }
139-
140164}
0 commit comments