Skip to content

Commit 83a26aa

Browse files
authored
feat(storage): Make sure all futures are spawned on global runtime (#10918)
* feat(storage): Make sure all futures are created on global runtime Signed-off-by: Xuanwo <[email protected]> * Allow dirty buffer Signed-off-by: Xuanwo <[email protected]> * format code Signed-off-by: Xuanwo <[email protected]> * Make sure runtime layer applied first Signed-off-by: Xuanwo <[email protected]> * Fix bug Signed-off-by: Xuanwo <[email protected]> * Use copy instead Signed-off-by: Xuanwo <[email protected]> --------- Signed-off-by: Xuanwo <[email protected]>
1 parent 0266e03 commit 83a26aa

File tree

2 files changed

+184
-8
lines changed

2 files changed

+184
-8
lines changed

src/common/storage/src/operator.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ pub fn build_operator<B: Builder>(builder: B) -> Result<Operator> {
8585
let ob = Operator::new(builder)?;
8686

8787
let op = ob
88+
// NOTE
89+
//
90+
// Magic happens here. We will add a layer upon original
91+
// storage operator so that all underlying storage operations
92+
// will send to storage runtime.
93+
.layer(RuntimeLayer::new(GlobalIORuntime::instance().inner()))
8894
// Add retry
8995
.layer(RetryLayer::new().with_jitter())
9096
// Add metrics
@@ -93,12 +99,6 @@ pub fn build_operator<B: Builder>(builder: B) -> Result<Operator> {
9399
.layer(LoggingLayer::default())
94100
// Add tracing
95101
.layer(TracingLayer)
96-
// NOTE
97-
//
98-
// Magic happens here. We will add a layer upon original
99-
// storage operator so that all underlying storage operations
100-
// will send to storage runtime.
101-
.layer(RuntimeLayer::new(GlobalIORuntime::instance().inner()))
102102
.finish();
103103

104104
Ok(op)

src/common/storage/src/runtime_layer.rs

Lines changed: 178 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,28 @@
1313
// limitations under the License.
1414

1515
use std::env;
16+
use std::io::SeekFrom;
17+
use std::mem;
18+
use std::pin::Pin;
1619
use std::sync::Arc;
1720
use std::sync::LazyLock;
21+
use std::task::Context;
22+
use std::task::Poll;
1823
use std::time::Duration;
1924

2025
use async_trait::async_trait;
26+
use bytes::Bytes;
2127
use common_base::base::tokio::pin;
2228
use common_base::base::tokio::runtime::Handle;
2329
use common_base::base::tokio::select;
30+
use common_base::base::tokio::task::JoinHandle;
2431
use common_base::base::tokio::time;
2532
use common_base::runtime::TrackedFuture;
33+
use futures::ready;
34+
use futures::Future;
2635
use opendal::ops::*;
36+
use opendal::raw::oio;
37+
use opendal::raw::oio::ReadExt;
2738
use opendal::raw::Accessor;
2839
use opendal::raw::Layer;
2940
use opendal::raw::LayeredAccessor;
@@ -84,7 +95,7 @@ pub struct RuntimeAccessor<A> {
8495
#[async_trait]
8596
impl<A: Accessor> LayeredAccessor for RuntimeAccessor<A> {
8697
type Inner = A;
87-
type Reader = A::Reader;
98+
type Reader = RuntimeIO<A::Reader>;
8899
type BlockingReader = A::BlockingReader;
89100
type Writer = A::Writer;
90101
type BlockingWriter = A::BlockingWriter;
@@ -127,7 +138,14 @@ impl<A: Accessor> LayeredAccessor for RuntimeAccessor<A> {
127138
};
128139

129140
let future = TrackedFuture::create(future);
130-
self.runtime.spawn(future).await.expect("join must success")
141+
self.runtime
142+
.spawn(future)
143+
.await
144+
.expect("join must success")
145+
.map(|(rp, r)| {
146+
let r = RuntimeIO::new(r, self.runtime.clone());
147+
(rp, r)
148+
})
131149
}
132150

133151
#[async_backtrace::framed]
@@ -191,3 +209,161 @@ impl<A: Accessor> LayeredAccessor for RuntimeAccessor<A> {
191209
self.inner.blocking_scan(path, args)
192210
}
193211
}
212+
213+
pub struct RuntimeIO<R: 'static> {
214+
runtime: Handle,
215+
state: State<R>,
216+
buf: Vec<u8>,
217+
}
218+
219+
impl<R> RuntimeIO<R> {
220+
fn new(inner: R, runtime: Handle) -> Self {
221+
Self {
222+
runtime,
223+
state: State::Idle(Some(inner)),
224+
buf: vec![],
225+
}
226+
}
227+
}
228+
229+
pub enum State<R: 'static> {
230+
Idle(Option<R>),
231+
Read(JoinHandle<(R, Result<Vec<u8>>)>),
232+
Seek(JoinHandle<(R, Result<u64>)>),
233+
Next(JoinHandle<(R, Option<Result<Bytes>>)>),
234+
}
235+
236+
/// Safety: State will only be accessed under &mut.
237+
unsafe impl<R> Sync for State<R> {}
238+
239+
impl<R: oio::Read> oio::Read for RuntimeIO<R> {
240+
/// TODO: the performance of `read` could be affected, we will improve it later.
241+
fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
242+
match &mut self.state {
243+
State::Idle(r) => {
244+
let mut r = r.take().expect("Idle must have a valid reader");
245+
let mut buffer = mem::take(&mut self.buf);
246+
247+
buffer.reserve(buf.len());
248+
// Safety: buffer is reserved with buf.len() bytes.
249+
#[allow(clippy::uninit_vec)]
250+
unsafe {
251+
buffer.set_len(buf.len())
252+
}
253+
254+
let future = async move {
255+
let mut buffer = buffer;
256+
let res = r.read(&mut buffer).await;
257+
match res {
258+
Ok(size) => {
259+
// Safety: we trust our reader, the returning size is correct.
260+
unsafe { buffer.set_len(size) }
261+
(r, Ok(buffer))
262+
}
263+
Err(err) => (r, Err(err)),
264+
}
265+
};
266+
let future = TrackedFuture::create(future);
267+
self.state = State::Read(self.runtime.spawn(future));
268+
269+
self.poll_read(cx, buf)
270+
}
271+
State::Read(future) => {
272+
let (r, res) = ready!(Pin::new(future).poll(cx)).expect("join must success");
273+
self.state = State::Idle(Some(r));
274+
match res {
275+
Ok(mut buffer) => {
276+
let size = buffer.len();
277+
buf[..size].copy_from_slice(&buffer);
278+
// Safety: set length to 0 as we don't care the remaining content.
279+
unsafe { buffer.set_len(0) }
280+
// Always reuse the same buffer
281+
self.buf = buffer;
282+
Poll::Ready(Ok(size))
283+
}
284+
Err(err) => Poll::Ready(Err(err)),
285+
}
286+
}
287+
State::Seek(future) => {
288+
let (r, _) = ready!(Pin::new(future).poll(cx)).expect("join must success");
289+
self.state = State::Idle(Some(r));
290+
291+
self.poll_read(cx, buf)
292+
}
293+
State::Next(future) => {
294+
let (r, _) = ready!(Pin::new(future).poll(cx)).expect("join must success");
295+
self.state = State::Idle(Some(r));
296+
297+
self.poll_read(cx, buf)
298+
}
299+
}
300+
}
301+
302+
fn poll_seek(&mut self, cx: &mut Context<'_>, pos: SeekFrom) -> Poll<Result<u64>> {
303+
match &mut self.state {
304+
State::Idle(r) => {
305+
let mut r = r.take().expect("Idle must have a valid reader");
306+
let future = async move {
307+
let res = r.seek(pos).await;
308+
(r, res)
309+
};
310+
let future = TrackedFuture::create(future);
311+
self.state = State::Seek(self.runtime.spawn(future));
312+
313+
self.poll_seek(cx, pos)
314+
}
315+
State::Read(future) => {
316+
let (r, _) = ready!(Pin::new(future).poll(cx)).expect("join must success");
317+
self.state = State::Idle(Some(r));
318+
319+
self.poll_seek(cx, pos)
320+
}
321+
State::Seek(future) => {
322+
let (r, res) = ready!(Pin::new(future).poll(cx)).expect("join must success");
323+
self.state = State::Idle(Some(r));
324+
325+
Poll::Ready(res)
326+
}
327+
State::Next(future) => {
328+
let (r, _) = ready!(Pin::new(future).poll(cx)).expect("join must success");
329+
self.state = State::Idle(Some(r));
330+
331+
self.poll_seek(cx, pos)
332+
}
333+
}
334+
}
335+
336+
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<bytes::Bytes>>> {
337+
match &mut self.state {
338+
State::Idle(r) => {
339+
let mut r = r.take().expect("Idle must have a valid reader");
340+
let future = async move {
341+
let res = r.next().await;
342+
(r, res)
343+
};
344+
let future = TrackedFuture::create(future);
345+
self.state = State::Next(self.runtime.spawn(future));
346+
347+
self.poll_next(cx)
348+
}
349+
State::Read(future) => {
350+
let (r, _) = ready!(Pin::new(future).poll(cx)).expect("join must success");
351+
self.state = State::Idle(Some(r));
352+
353+
self.poll_next(cx)
354+
}
355+
State::Seek(future) => {
356+
let (r, _) = ready!(Pin::new(future).poll(cx)).expect("join must success");
357+
self.state = State::Idle(Some(r));
358+
359+
self.poll_next(cx)
360+
}
361+
State::Next(future) => {
362+
let (r, res) = ready!(Pin::new(future).poll(cx)).expect("join must success");
363+
self.state = State::Idle(Some(r));
364+
365+
Poll::Ready(res)
366+
}
367+
}
368+
}
369+
}

0 commit comments

Comments
 (0)