@@ -28,8 +28,17 @@ class WebTransportStreamVisitorAdapter
28
28
: visitor_(visitor) {}
29
29
void OnCanRead () override { visitor_->OnCanRead (); }
30
30
void OnCanWrite () override { visitor_->OnCanWrite (); }
31
- void OnResetStreamReceived (::quic::WebTransportStreamError error) override {}
32
- void OnStopSendingReceived (::quic::WebTransportStreamError error) override {}
31
+ void OnResetStreamReceived (::quic::WebTransportStreamError error) override {
32
+ LOG (INFO)<<" OnResetStream received." ;
33
+ if (visitor_) {
34
+ visitor_->OnResetStreamReceived (error);
35
+ }
36
+ }
37
+ void OnStopSendingReceived (::quic::WebTransportStreamError error) override {
38
+ if (visitor_) {
39
+ visitor_->OnStopSendingReceived (error);
40
+ }
41
+ }
33
42
void OnWriteSideInDataRecvdState () override {}
34
43
35
44
private:
@@ -45,7 +54,8 @@ WebTransportStreamImpl::WebTransportStreamImpl(
45
54
quic_stream_ (quic_stream),
46
55
io_runner_(io_runner),
47
56
event_runner_(event_runner),
48
- visitor_(nullptr ) {
57
+ visitor_(nullptr ),
58
+ write_side_closed_(false ) {
49
59
CHECK (stream_);
50
60
CHECK (quic_stream_);
51
61
CHECK (io_runner_);
@@ -63,17 +73,24 @@ size_t WebTransportStreamImpl::Write(const uint8_t* data, size_t length) {
63
73
DCHECK_EQ (sizeof (uint8_t ), sizeof (char ));
64
74
CHECK (io_runner_);
65
75
if (io_runner_->BelongsToCurrentThread ()) {
76
+ if (write_side_closed_) {
77
+ return 0 ;
78
+ }
66
79
return stream_->Write (
67
80
absl::string_view (reinterpret_cast <const char *>(data), length));
68
81
}
69
- bool result;
82
+ bool result = false ;
70
83
base::WaitableEvent done (base::WaitableEvent::ResetPolicy::AUTOMATIC,
71
84
base::WaitableEvent::InitialState::NOT_SIGNALED);
72
85
io_runner_->PostTask (
73
86
FROM_HERE,
74
87
base::BindOnce (
75
- [](WebTransportStreamImpl* stream, const uint8_t * data, size_t & length,
76
- bool & result, base::WaitableEvent* event) {
88
+ [](base::WeakPtr<WebTransportStreamImpl> stream, const uint8_t * data,
89
+ size_t & length, bool & result, base::WaitableEvent* event) {
90
+ if (!stream || stream->write_side_closed_ ) {
91
+ event->Signal ();
92
+ return ;
93
+ }
77
94
if (stream->stream_ ->CanWrite ()) {
78
95
result = stream->stream_ ->Write (absl::string_view (
79
96
reinterpret_cast <const char *>(data), length));
@@ -82,7 +99,7 @@ size_t WebTransportStreamImpl::Write(const uint8_t* data, size_t length) {
82
99
}
83
100
event->Signal ();
84
101
},
85
- base::Unretained ( this ), base::Unretained (data), std::ref (length),
102
+ weak_factory_. GetWeakPtr ( ), base::Unretained (data), std::ref (length),
86
103
std::ref (result), base::Unretained (&done)));
87
104
done.Wait ();
88
105
return result ? length : 0 ;
@@ -100,21 +117,25 @@ size_t WebTransportStreamImpl::Read(uint8_t* data, size_t length) {
100
117
// TODO: FIN is not handled.
101
118
return read_result.bytes_read ;
102
119
}
103
- size_t result;
120
+ size_t result = 0 ;
104
121
base::WaitableEvent done (base::WaitableEvent::ResetPolicy::AUTOMATIC,
105
122
base::WaitableEvent::InitialState::NOT_SIGNALED);
106
123
io_runner_->PostTask (
107
124
FROM_HERE,
108
125
base::BindOnce (
109
- [](WebTransportStreamImpl* stream, uint8_t * data, size_t & length,
110
- size_t & result, base::WaitableEvent* event) {
126
+ [](base::WeakPtr<WebTransportStreamImpl> stream, uint8_t * data,
127
+ size_t & length, size_t & result, base::WaitableEvent* event) {
128
+ if (!stream) {
129
+ event->Signal ();
130
+ return ;
131
+ }
111
132
auto read_result =
112
133
stream->stream_ ->Read (reinterpret_cast <char *>(data), length);
113
134
// TODO: FIN is not handled.
114
135
result = read_result.bytes_read ;
115
136
event->Signal ();
116
137
},
117
- base::Unretained ( this ), base::Unretained (data), std::ref (length),
138
+ weak_factory_. GetWeakPtr ( ), base::Unretained (data), std::ref (length),
118
139
std::ref (result), base::Unretained (&done)));
119
140
done.Wait ();
120
141
return result;
@@ -124,18 +145,22 @@ size_t WebTransportStreamImpl::ReadableBytes() const {
124
145
if (io_runner_->BelongsToCurrentThread ()) {
125
146
return stream_->ReadableBytes ();
126
147
}
127
- size_t result;
148
+ size_t result = 0 ;
128
149
base::WaitableEvent done (base::WaitableEvent::ResetPolicy::AUTOMATIC,
129
150
base::WaitableEvent::InitialState::NOT_SIGNALED);
130
- io_runner_->PostTask (
131
- FROM_HERE,
132
- base::BindOnce (
133
- [](WebTransportStreamImpl const * stream, size_t & result,
134
- base::WaitableEvent* event) {
135
- result = stream->stream_ ->ReadableBytes ();
136
- event->Signal ();
137
- },
138
- base::Unretained (this ), std::ref (result), base::Unretained (&done)));
151
+ io_runner_->PostTask (FROM_HERE,
152
+ base::BindOnce (
153
+ [](base::WeakPtr<WebTransportStreamImpl> stream,
154
+ size_t & result, base::WaitableEvent* event) {
155
+ if (!stream) {
156
+ event->Signal ();
157
+ return ;
158
+ }
159
+ result = stream->stream_ ->ReadableBytes ();
160
+ event->Signal ();
161
+ },
162
+ weak_factory_.GetWeakPtr (), std::ref (result),
163
+ base::Unretained (&done)));
139
164
done.Wait ();
140
165
return result;
141
166
}
@@ -150,34 +175,42 @@ void WebTransportStreamImpl::Close() {
150
175
base::WaitableEvent done (base::WaitableEvent::ResetPolicy::AUTOMATIC,
151
176
base::WaitableEvent::InitialState::NOT_SIGNALED);
152
177
io_runner_->PostTask (
153
- FROM_HERE,
154
- base::BindOnce (
155
- [](WebTransportStreamImpl* stream, base::WaitableEvent* event) {
156
- if (!stream->stream_ ->SendFin ()) {
157
- LOG (ERROR) << " Failed to send FIN." ;
158
- }
159
- event->Signal ();
160
- },
161
- base::Unretained (this ), base::Unretained (&done)));
178
+ FROM_HERE, base::BindOnce (
179
+ [](base::WeakPtr<WebTransportStreamImpl> stream,
180
+ base::WaitableEvent* event) {
181
+ if (!stream) {
182
+ event->Signal ();
183
+ return ;
184
+ }
185
+ if (!stream->stream_ ->SendFin ()) {
186
+ LOG (ERROR) << " Failed to send FIN." ;
187
+ }
188
+ event->Signal ();
189
+ },
190
+ weak_factory_.GetWeakPtr (), base::Unretained (&done)));
162
191
done.Wait ();
163
192
}
164
193
165
194
uint64_t WebTransportStreamImpl::BufferedDataBytes () const {
166
195
if (io_runner_->BelongsToCurrentThread ()) {
167
196
return quic_stream_->BufferedDataBytes ();
168
197
}
169
- uint64_t result;
198
+ uint64_t result = 0 ;
170
199
base::WaitableEvent done (base::WaitableEvent::ResetPolicy::AUTOMATIC,
171
200
base::WaitableEvent::InitialState::NOT_SIGNALED);
172
- io_runner_->PostTask (
173
- FROM_HERE,
174
- base::BindOnce (
175
- [](WebTransportStreamImpl const * stream, uint64_t & result,
176
- base::WaitableEvent* event) {
177
- result = stream->quic_stream_ ->BufferedDataBytes ();
178
- event->Signal ();
179
- },
180
- base::Unretained (this ), std::ref (result), base::Unretained (&done)));
201
+ io_runner_->PostTask (FROM_HERE,
202
+ base::BindOnce (
203
+ [](base::WeakPtr<WebTransportStreamImpl> stream,
204
+ uint64_t & result, base::WaitableEvent* event) {
205
+ if (!stream) {
206
+ event->Signal ();
207
+ return ;
208
+ }
209
+ result = stream->quic_stream_ ->BufferedDataBytes ();
210
+ event->Signal ();
211
+ },
212
+ weak_factory_.GetWeakPtr (), std::ref (result),
213
+ base::Unretained (&done)));
181
214
done.Wait ();
182
215
return result;
183
216
}
@@ -186,47 +219,49 @@ bool WebTransportStreamImpl::CanWrite() const {
186
219
if (io_runner_->BelongsToCurrentThread ()) {
187
220
return stream_->CanWrite ();
188
221
}
189
- bool result;
222
+ bool result = false ;
190
223
base::WaitableEvent done (base::WaitableEvent::ResetPolicy::AUTOMATIC,
191
224
base::WaitableEvent::InitialState::NOT_SIGNALED);
192
- io_runner_->PostTask (
193
- FROM_HERE,
194
- base::BindOnce (
195
- [](WebTransportStreamImpl const * stream, bool & result,
196
- base::WaitableEvent* event) {
197
- result = stream->stream_ ->CanWrite ();
198
- event->Signal ();
199
- },
200
- base::Unretained (this ), std::ref (result), base::Unretained (&done)));
225
+ io_runner_->PostTask (FROM_HERE,
226
+ base::BindOnce (
227
+ [](base::WeakPtr<WebTransportStreamImpl> stream,
228
+ bool & result, base::WaitableEvent* event) {
229
+ if (!stream) {
230
+ event->Signal ();
231
+ return ;
232
+ }
233
+ result = stream->stream_ ->CanWrite ();
234
+ event->Signal ();
235
+ },
236
+ weak_factory_.GetWeakPtr (), std::ref (result),
237
+ base::Unretained (&done)));
201
238
done.Wait ();
202
239
return result;
203
240
}
204
241
205
242
void WebTransportStreamImpl::OnCanRead () {
206
- event_runner_->PostTask (
207
- FROM_HERE,
208
- base::BindOnce (&WebTransportStreamImpl::OnCanReadOnCurrentThread,
209
- weak_factory_.GetWeakPtr ()));
210
- }
211
-
212
- void WebTransportStreamImpl::OnCanWrite () {
213
- event_runner_->PostTask (
214
- FROM_HERE,
215
- base::BindOnce (&WebTransportStreamImpl::OnCanWriteOnCurrentThread,
216
- weak_factory_.GetWeakPtr ()));
217
- }
218
-
219
- void WebTransportStreamImpl::OnCanReadOnCurrentThread () {
220
243
if (visitor_) {
221
244
visitor_->OnCanRead ();
222
245
}
223
246
}
224
247
225
- void WebTransportStreamImpl::OnCanWriteOnCurrentThread () {
248
+ void WebTransportStreamImpl::OnCanWrite () {
226
249
if (visitor_) {
227
250
visitor_->OnCanWrite ();
228
251
}
229
252
}
230
253
254
+ void WebTransportStreamImpl::OnResetStreamReceived (
255
+ ::quic::WebTransportStreamError error) {
256
+ write_side_closed_ = true ;
257
+ }
258
+
259
+ void WebTransportStreamImpl::OnStopSendingReceived (
260
+ ::quic::WebTransportStreamError error) {}
261
+
262
+ void WebTransportStreamImpl::OnSessionClosed () {
263
+ write_side_closed_ = true ;
264
+ }
265
+
231
266
} // namespace quic
232
267
} // namespace owt
0 commit comments