1
+ /*
2
+ * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4
+ *
5
+ * This code is free software; you can redistribute it and/or modify it
6
+ * under the terms of the GNU General Public License version 2 only, as
7
+ * published by the Free Software Foundation.
8
+ *
9
+ * This code is distributed in the hope that it will be useful, but WITHOUT
10
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12
+ * version 2 for more details (a copy is included in the LICENSE file that
13
+ * accompanied this code).
14
+ *
15
+ * You should have received a copy of the GNU General Public License version
16
+ * 2 along with this work; if not, write to the Free Software Foundation,
17
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
+ *
19
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20
+ * or visit www.oracle.com if you need additional information or have any
21
+ * questions.
22
+ */
23
+
24
+ /*
25
+ * @test
26
+ * @bug 8358764
27
+ * @summary Test closing a socket while a thread is blocked in read. The connection
28
+ * should be closed gracefuly so that the peer reads EOF.
29
+ * @run junit PeerReadsAfterAsyncClose
30
+ */
31
+
32
+ import java .io .IOException ;
33
+ import java .net .InetAddress ;
34
+ import java .net .InetSocketAddress ;
35
+ import java .net .ServerSocket ;
36
+ import java .net .Socket ;
37
+ import java .net .SocketException ;
38
+ import java .nio .ByteBuffer ;
39
+ import java .nio .channels .ClosedChannelException ;
40
+ import java .nio .channels .SocketChannel ;
41
+ import java .util .Arrays ;
42
+ import java .util .Objects ;
43
+ import java .util .concurrent .ThreadFactory ;
44
+ import java .util .concurrent .atomic .AtomicBoolean ;
45
+ import java .util .stream .Stream ;
46
+
47
+ import org .junit .jupiter .params .ParameterizedTest ;
48
+ import org .junit .jupiter .params .provider .MethodSource ;
49
+ import static org .junit .jupiter .api .Assertions .*;
50
+
51
+ class PeerReadsAfterAsyncClose {
52
+
53
+ static Stream <ThreadFactory > factories () {
54
+ return Stream .of (Thread .ofPlatform ().factory (), Thread .ofVirtual ().factory ());
55
+ }
56
+
57
+ /**
58
+ * Close SocketChannel while a thread is blocked reading from the channel's socket.
59
+ */
60
+ @ ParameterizedTest
61
+ @ MethodSource ("factories" )
62
+ void testCloseDuringSocketChannelRead (ThreadFactory factory ) throws Exception {
63
+ var loopback = InetAddress .getLoopbackAddress ();
64
+ try (var listener = new ServerSocket ()) {
65
+ listener .bind (new InetSocketAddress (loopback , 0 ));
66
+
67
+ try (SocketChannel sc = SocketChannel .open (listener .getLocalSocketAddress ());
68
+ Socket peer = listener .accept ()) {
69
+
70
+ // start thread to read from channel
71
+ var cceThrown = new AtomicBoolean ();
72
+ Thread thread = factory .newThread (() -> {
73
+ try {
74
+ sc .read (ByteBuffer .allocate (1 ));
75
+ fail ();
76
+ } catch (ClosedChannelException e ) {
77
+ cceThrown .set (true );
78
+ } catch (Throwable e ) {
79
+ e .printStackTrace ();
80
+ }
81
+ });
82
+ thread .start ();
83
+ try {
84
+ // close SocketChannel when thread sampled in implRead
85
+ onReach (thread , "sun.nio.ch.SocketChannelImpl.implRead" , () -> {
86
+ try {
87
+ sc .close ();
88
+ } catch (IOException ignore ) { }
89
+ });
90
+
91
+ // peer should read EOF
92
+ int n = peer .getInputStream ().read ();
93
+ assertEquals (-1 , n );
94
+ } finally {
95
+ thread .join ();
96
+ }
97
+ assertEquals (true , cceThrown .get (), "ClosedChannelException not thrown" );
98
+ }
99
+ }
100
+ }
101
+
102
+ /**
103
+ * Close Socket while a thread is blocked reading from the socket.
104
+ */
105
+ @ ParameterizedTest
106
+ @ MethodSource ("factories" )
107
+ void testCloseDuringSocketUntimedRead (ThreadFactory factory ) throws Exception {
108
+ testCloseDuringSocketRead (factory , 0 );
109
+ }
110
+
111
+ /**
112
+ * Close Socket while a thread is blocked reading from the socket with a timeout.
113
+ */
114
+ @ ParameterizedTest
115
+ @ MethodSource ("factories" )
116
+ void testCloseDuringSockeTimedRead (ThreadFactory factory ) throws Exception {
117
+ testCloseDuringSocketRead (factory , 60_000 );
118
+ }
119
+
120
+ private void testCloseDuringSocketRead (ThreadFactory factory , int timeout ) throws Exception {
121
+ var loopback = InetAddress .getLoopbackAddress ();
122
+ try (var listener = new ServerSocket ()) {
123
+ listener .bind (new InetSocketAddress (loopback , 0 ));
124
+
125
+ try (Socket s = new Socket (loopback , listener .getLocalPort ());
126
+ Socket peer = listener .accept ()) {
127
+
128
+ // start thread to read from socket
129
+ var seThrown = new AtomicBoolean ();
130
+ Thread thread = factory .newThread (() -> {
131
+ try {
132
+ s .setSoTimeout (timeout );
133
+ s .getInputStream ().read ();
134
+ fail ();
135
+ } catch (SocketException e ) {
136
+ seThrown .set (true );
137
+ } catch (Throwable e ) {
138
+ e .printStackTrace ();
139
+ }
140
+ });
141
+ thread .start ();
142
+ try {
143
+ // close Socket when thread sampled in implRead
144
+ onReach (thread , "sun.nio.ch.NioSocketImpl.implRead" , () -> {
145
+ try {
146
+ s .close ();
147
+ } catch (IOException ignore ) { }
148
+ });
149
+
150
+ // peer should read EOF
151
+ int n = peer .getInputStream ().read ();
152
+ assertEquals (-1 , n );
153
+ } finally {
154
+ thread .join ();
155
+ }
156
+ assertEquals (true , seThrown .get (), "SocketException not thrown" );
157
+ }
158
+ }
159
+ }
160
+
161
+ /**
162
+ * Runs the given action when the given target thread is sampled at the given
163
+ * location. The location takes the form "{@code c.m}" where
164
+ * {@code c} is the fully qualified class name and {@code m} is the method name.
165
+ */
166
+ private void onReach (Thread target , String location , Runnable action ) {
167
+ int index = location .lastIndexOf ('.' );
168
+ String className = location .substring (0 , index );
169
+ String methodName = location .substring (index + 1 );
170
+ Thread .ofPlatform ().daemon (true ).start (() -> {
171
+ try {
172
+ boolean found = false ;
173
+ while (!found ) {
174
+ found = contains (target .getStackTrace (), className , methodName );
175
+ if (!found ) {
176
+ Thread .sleep (20 );
177
+ }
178
+ }
179
+ action .run ();
180
+ } catch (Exception e ) {
181
+ e .printStackTrace ();
182
+ }
183
+ });
184
+ }
185
+
186
+ /**
187
+ * Returns true if the given stack trace contains an element for the given class
188
+ * and method name.
189
+ */
190
+ private boolean contains (StackTraceElement [] stack , String className , String methodName ) {
191
+ return Arrays .stream (stack )
192
+ .anyMatch (e -> className .equals (e .getClassName ())
193
+ && methodName .equals (e .getMethodName ()));
194
+ }
195
+ }
0 commit comments