Skip to content

Commit cd64baa

Browse files
authored
separating output providers, sequential output providers (#528)
This PR refactors how the catch-all "OutputProvider" is passed to the download mechanism. It splits the download system into two supported modes, "Seeking" and "Sequential": - Seeking, i.e. opening the file multiple times and seeking to a location each time -- this is the standard writing mechanism hf_xet uses today. - Sequential, i.e. opening a file once and writing data in order -- this will be used in a set of upcoming PRs/features that use the parallel-download/sequential-write mechanism to support writing to stdout and to an in-memory channel buffer. To support an in-memory channel with backpressure, Channel{Writer, Stream, Reader} are introduced (re-introduced?) in utils. This could be particularly useful for the mount functionality.
1 parent 2fc772e commit cd64baa

File tree

18 files changed

+484
-218
lines changed

18 files changed

+484
-218
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cas_client/src/download_utils.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use utils::singleflight::Group;
2121

2222
use crate::error::{CasClientError, Result};
2323
use crate::http_client::Api;
24-
use crate::output_provider::OutputProvider;
24+
use crate::output_provider::SeekingOutputProvider;
2525
use crate::remote_client::{PREFIX_DEFAULT, get_reconstruction_with_endpoint_and_client};
2626
use crate::retry_wrapper::{RetryWrapper, RetryableReqwestError};
2727

@@ -296,7 +296,7 @@ pub(crate) struct ChunkRangeWrite {
296296
pub(crate) struct FetchTermDownloadOnceAndWriteEverywhereUsed {
297297
pub download: FetchTermDownload,
298298
// pub write_offset: u64, // start position of the writer to write to
299-
pub output: OutputProvider,
299+
pub output: SeekingOutputProvider,
300300
pub writes: Vec<ChunkRangeWrite>,
301301
}
302302

cas_client/src/interface.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::collections::HashMap;
21
use std::sync::Arc;
32

43
use bytes::Bytes;
@@ -9,9 +8,9 @@ use merklehash::MerkleHash;
98
use progress_tracking::item_tracking::SingleItemProgressUpdater;
109
use progress_tracking::upload_tracking::CompletionTracker;
1110

12-
#[cfg(not(target_family = "wasm"))]
13-
use crate::OutputProvider;
1411
use crate::error::Result;
12+
#[cfg(not(target_family = "wasm"))]
13+
use crate::{SeekingOutputProvider, SequentialOutput};
1514

1615
/// A Client to the Shard service. The shard service
1716
/// provides for
@@ -25,24 +24,25 @@ pub trait Client {
2524
///
2625
/// The http_client passed in is a non-authenticated client. This is used to directly communicate
2726
/// with the backing store (S3) to retrieve xorbs.
27+
///
28+
/// Content is written in-order to the provided SequentialOutput
2829
#[cfg(not(target_family = "wasm"))]
29-
async fn get_file(
30+
async fn get_file_with_sequential_writer(
3031
&self,
3132
hash: &MerkleHash,
3233
byte_range: Option<FileRange>,
33-
output_provider: &OutputProvider,
34+
output_provider: SequentialOutput,
3435
progress_updater: Option<Arc<SingleItemProgressUpdater>>,
3536
) -> Result<u64>;
3637

3738
#[cfg(not(target_family = "wasm"))]
38-
async fn batch_get_file(&self, files: HashMap<MerkleHash, &OutputProvider>) -> Result<u64> {
39-
let mut n_bytes = 0;
40-
// Provide the basic naive implementation as a default.
41-
for (h, w) in files {
42-
n_bytes += self.get_file(&h, None, w, None).await?;
43-
}
44-
Ok(n_bytes)
45-
}
39+
async fn get_file_with_parallel_writer(
40+
&self,
41+
hash: &MerkleHash,
42+
byte_range: Option<FileRange>,
43+
output_provider: SeekingOutputProvider,
44+
progress_updater: Option<Arc<SingleItemProgressUpdater>>,
45+
) -> Result<u64>;
4646

4747
async fn get_file_reconstruction_info(
4848
&self,

cas_client/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pub use interface::Client;
66
#[cfg(not(target_family = "wasm"))]
77
pub use local_client::LocalClient;
88
#[cfg(not(target_family = "wasm"))]
9-
pub use output_provider::{FileProvider, OutputProvider};
9+
pub use output_provider::*;
1010
pub use remote_client::RemoteClient;
1111

1212
pub use crate::error::CasClientError;

cas_client/src/local_client.rs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ use merklehash::MerkleHash;
1818
use progress_tracking::item_tracking::SingleItemProgressUpdater;
1919
use progress_tracking::upload_tracking::CompletionTracker;
2020
use tempfile::TempDir;
21+
use tokio::io::AsyncWriteExt;
2122
use tokio::runtime::Handle;
2223
use tracing::{debug, error, info, warn};
2324

24-
use crate::Client;
2525
use crate::error::{CasClientError, Result};
26-
use crate::output_provider::OutputProvider;
26+
use crate::{Client, SeekingOutputProvider, SequentialOutput};
2727

2828
pub struct LocalClient {
2929
tmp_dir: Option<TempDir>, // To hold directory to use for local testing
@@ -232,14 +232,14 @@ impl LocalClient {
232232
}
233233
}
234234

235-
/// LocalClient is responsible for writing/reading Xorbs on local disk.
235+
/// LocalClient is responsible for writing/reading Xorbs on the local disk.
236236
#[async_trait]
237237
impl Client for LocalClient {
238-
async fn get_file(
238+
async fn get_file_with_sequential_writer(
239239
&self,
240240
hash: &MerkleHash,
241241
byte_range: Option<FileRange>,
242-
output_provider: &OutputProvider,
242+
mut output_provider: SequentialOutput,
243243
_progress_updater: Option<Arc<SingleItemProgressUpdater>>,
244244
) -> Result<u64> {
245245
let Some((file_info, _)) = self
@@ -250,7 +250,6 @@ impl Client for LocalClient {
250250
else {
251251
return Err(CasClientError::FileNotFound(*hash));
252252
};
253-
let mut writer = output_provider.get_writer_at(0)?;
254253

255254
// This is just used for testing, so inefficient is fine.
256255
let mut file_vec = Vec::new();
@@ -269,11 +268,23 @@ impl Client for LocalClient {
269268
.unwrap_or(file_vec.len())
270269
.min(file_vec.len());
271270

272-
writer.write_all(&file_vec[start..end])?;
271+
output_provider.write_all(&file_vec[start..end]).await?;
273272

274273
Ok((end - start) as u64)
275274
}
276275

276+
async fn get_file_with_parallel_writer(
277+
&self,
278+
hash: &MerkleHash,
279+
byte_range: Option<FileRange>,
280+
output_provider: SeekingOutputProvider,
281+
progress_updater: Option<Arc<SingleItemProgressUpdater>>,
282+
) -> Result<u64> {
283+
let sequential = output_provider.try_into()?;
284+
self.get_file_with_sequential_writer(hash, byte_range, sequential, progress_updater)
285+
.await
286+
}
287+
277288
/// Query the shard server for the file reconstruction info.
278289
/// Returns the FileInfo for reconstructing the file and the shard ID that
279290
/// defines the file info.

cas_client/src/output_provider.rs

Lines changed: 125 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,87 @@
1-
use std::io::{Cursor, Seek, SeekFrom, Write};
2-
use std::path::PathBuf;
3-
use std::sync::{Arc, Mutex};
1+
use std::io::{Seek, SeekFrom, Write};
2+
use std::path::{Path, PathBuf};
3+
use std::pin::Pin;
4+
use std::task::{Context, Poll};
45

6+
use tokio::io::AsyncWrite;
7+
8+
use crate::CasClientError;
59
use crate::error::Result;
610

7-
/// Enum of different output formats to write reconstructed files.
11+
/// type that represents all acceptable sequential output mechanisms
12+
/// To convert something that is Write rather than AsyncWrite uses the AsyncWriteFromWrite adapter
13+
pub type SequentialOutput = Box<dyn AsyncWrite + Send + Unpin>;
14+
15+
pub fn sequential_output_from_filepath(filename: impl AsRef<Path>) -> Result<SequentialOutput> {
16+
let file = std::fs::OpenOptions::new()
17+
.write(true)
18+
.truncate(false)
19+
.create(true)
20+
.open(&filename)?;
21+
Ok(Box::new(AsyncWriteFromWrite(Some(Box::new(file)))))
22+
}
23+
24+
/// Enum of different output formats to write reconstructed files
25+
/// where the result writer can be set at a specific position and new handles can be created
826
#[derive(Debug, Clone)]
9-
pub enum OutputProvider {
27+
pub enum SeekingOutputProvider {
1028
File(FileProvider),
1129
#[cfg(test)]
12-
Buffer(BufferProvider),
30+
Buffer(buffer_provider::BufferProvider),
1331
}
1432

15-
impl OutputProvider {
33+
impl SeekingOutputProvider {
34+
// shortcut to create a new FileProvider variant from filename
35+
pub fn new_file_provider(filename: PathBuf) -> Self {
36+
Self::File(FileProvider::new(filename))
37+
}
38+
1639
/// Create a new writer to start writing at the indicated start location.
1740
pub(crate) fn get_writer_at(&self, start: u64) -> Result<Box<dyn Write + Send>> {
1841
match self {
19-
OutputProvider::File(fp) => fp.get_writer_at(start),
42+
SeekingOutputProvider::File(fp) => fp.get_writer_at(start),
2043
#[cfg(test)]
21-
OutputProvider::Buffer(bp) => bp.get_writer_at(start),
44+
SeekingOutputProvider::Buffer(bp) => bp.get_writer_at(start),
2245
}
2346
}
2447
}
2548

49+
// Adapter used to create an AsyncWrite from a Writer.
50+
struct AsyncWriteFromWrite(Option<Box<dyn Write + Send>>);
51+
52+
impl AsyncWrite for AsyncWriteFromWrite {
53+
fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
54+
let Some(inner) = self.0.as_mut() else {
55+
return Poll::Ready(Ok(0));
56+
};
57+
Poll::Ready(inner.write(buf))
58+
}
59+
60+
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
61+
let Some(inner) = self.0.as_mut() else {
62+
return Poll::Ready(Err(std::io::Error::new(
63+
std::io::ErrorKind::BrokenPipe,
64+
"writer closed, already dropped",
65+
)));
66+
};
67+
Poll::Ready(inner.flush())
68+
}
69+
70+
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
71+
let _ = self.0.take();
72+
Poll::Ready(Ok(()))
73+
}
74+
}
75+
76+
impl TryFrom<SeekingOutputProvider> for SequentialOutput {
77+
type Error = CasClientError;
78+
79+
fn try_from(value: SeekingOutputProvider) -> std::result::Result<Self, Self::Error> {
80+
let w = value.get_writer_at(0)?;
81+
Ok(Box::new(AsyncWriteFromWrite(Some(w))))
82+
}
83+
}
84+
2685
/// Provides new Writers to a file located at a particular location
2786
#[derive(Debug, Clone)]
2887
pub struct FileProvider {
@@ -45,44 +104,69 @@ impl FileProvider {
45104
}
46105
}
47106

48-
#[derive(Debug, Default, Clone)]
49-
pub struct BufferProvider {
50-
pub buf: ThreadSafeBuffer,
51-
}
107+
#[cfg(test)]
108+
pub(crate) mod buffer_provider {
109+
use std::io::{Cursor, Write};
110+
use std::sync::{Arc, Mutex};
111+
112+
use crate::error::Result;
113+
use crate::output_provider::AsyncWriteFromWrite;
114+
use crate::{SeekingOutputProvider, SequentialOutput};
52115

53-
impl BufferProvider {
54-
pub fn get_writer_at(&self, start: u64) -> crate::error::Result<Box<dyn std::io::Write + Send>> {
55-
let mut buffer = self.buf.clone();
56-
buffer.idx = start;
57-
Ok(Box::new(buffer))
116+
/// BufferProvider may be Seeking or Sequential
117+
/// only used in testing
118+
#[derive(Debug, Clone)]
119+
pub struct BufferProvider {
120+
pub buf: ThreadSafeBuffer,
58121
}
59-
}
60122

61-
#[derive(Debug, Default, Clone)]
62-
/// Thread-safe in-memory buffer that implements [Write](Write) trait at some position
63-
/// within an underlying buffer and allows access to inner buffer.
64-
/// Thread-safe in-memory buffer that implements [Write](Write) trait and allows
65-
/// access to inner buffer
66-
pub struct ThreadSafeBuffer {
67-
idx: u64,
68-
inner: Arc<Mutex<Cursor<Vec<u8>>>>,
69-
}
70-
impl ThreadSafeBuffer {
71-
pub fn value(&self) -> Vec<u8> {
72-
self.inner.lock().unwrap().get_ref().clone()
123+
impl BufferProvider {
124+
pub fn get_writer_at(&self, start: u64) -> Result<Box<dyn Write + Send>> {
125+
let mut buffer = self.buf.clone();
126+
buffer.idx = start;
127+
Ok(Box::new(buffer))
128+
}
129+
}
130+
131+
#[derive(Debug, Default, Clone)]
132+
/// Thread-safe in-memory buffer that implements [Write](Write) trait at some position
133+
/// within an underlying buffer and allows access to the inner buffer.
134+
/// Thread-safe in-memory buffer that implements [Write](Write) trait and allows
135+
/// access to the inner buffer
136+
pub struct ThreadSafeBuffer {
137+
idx: u64,
138+
inner: Arc<Mutex<Cursor<Vec<u8>>>>,
139+
}
140+
141+
impl ThreadSafeBuffer {
142+
pub fn value(&self) -> Vec<u8> {
143+
self.inner.lock().unwrap().get_ref().clone()
144+
}
73145
}
74-
}
75146

76-
impl std::io::Write for ThreadSafeBuffer {
77-
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
78-
let mut guard = self.inner.lock().map_err(|e| std::io::Error::other(format!("{e}")))?;
79-
guard.set_position(self.idx);
80-
let num_written = guard.write(buf)?;
81-
self.idx = guard.position();
82-
Ok(num_written)
147+
impl Write for ThreadSafeBuffer {
148+
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
149+
let mut guard = self.inner.lock().map_err(|e| std::io::Error::other(format!("{e}")))?;
150+
guard.set_position(self.idx);
151+
let num_written = Write::write(guard.get_mut(), buf)?;
152+
self.idx = guard.position();
153+
Ok(num_written)
154+
}
155+
156+
fn flush(&mut self) -> std::io::Result<()> {
157+
Ok(())
158+
}
83159
}
84160

85-
fn flush(&mut self) -> std::io::Result<()> {
86-
Ok(())
161+
impl From<ThreadSafeBuffer> for SequentialOutput {
162+
fn from(value: ThreadSafeBuffer) -> Self {
163+
Box::new(AsyncWriteFromWrite(Some(Box::new(value))))
164+
}
165+
}
166+
167+
impl From<ThreadSafeBuffer> for SeekingOutputProvider {
168+
fn from(value: ThreadSafeBuffer) -> Self {
169+
SeekingOutputProvider::Buffer(BufferProvider { buf: value })
170+
}
87171
}
88172
}

0 commit comments

Comments
 (0)