diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7369caf6..38cdaf08 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,8 +30,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Install Rust 1.89 - uses: dtolnay/rust-toolchain@1.89.0 + - name: Install Rust 1.90 + uses: dtolnay/rust-toolchain@1.90.0 with: components: clippy - uses: ./.github/actions/cache-rust-build @@ -56,8 +56,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Install Rust 1.89 - uses: dtolnay/rust-toolchain@1.89.0 + - name: Install Rust 1.90 + uses: dtolnay/rust-toolchain@1.90.0 with: components: clippy - uses: ./.github/actions/cache-rust-build @@ -69,8 +69,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Install Rust 1.89 - uses: dtolnay/rust-toolchain@1.89.0 + - name: Install Rust 1.90 + uses: dtolnay/rust-toolchain@1.90.0 with: components: clippy - name: Set up Git LFS diff --git a/.github/workflows/git-xet-release.yml b/.github/workflows/git-xet-release.yml index 6f1da57d..dc84a02f 100644 --- a/.github/workflows/git-xet-release.yml +++ b/.github/workflows/git-xet-release.yml @@ -27,8 +27,8 @@ jobs: target: aarch64 steps: - uses: actions/checkout@v4 - - name: Install Rust 1.89 - uses: dtolnay/rust-toolchain@1.89.0 + - name: Install Rust 1.90 + uses: dtolnay/rust-toolchain@1.90.0 - uses: ./.github/actions/cache-rust-build - name: Build run: | @@ -50,8 +50,8 @@ jobs: target: aarch64 steps: - uses: actions/checkout@v4 - - name: Install Rust 1.89 - uses: dtolnay/rust-toolchain@1.89.0 + - name: Install Rust 1.90 + uses: dtolnay/rust-toolchain@1.90.0 - uses: ./.github/actions/cache-rust-build - name: Build run: | @@ -85,8 +85,8 @@ jobs: target: x86_64 steps: - uses: actions/checkout@v4 - - name: Install Rust 1.89 - uses: dtolnay/rust-toolchain@1.89.0 + - name: Install Rust 1.90 + uses: dtolnay/rust-toolchain@1.90.0 - uses: ./.github/actions/cache-rust-build - name: Install WiX run: | diff --git a/Cargo.lock b/Cargo.lock index bad66f98..f4f4ac2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1085,6 +1085,18 @@ dependencies = [ "winapi", ] +[[package]] +name = "filetime" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "libredox", + "windows-sys 0.59.0", +] + [[package]] name = "fixedbitset" version = "0.4.2" @@ -1553,8 +1565,11 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "bytes", "cas_client", + "cas_types", "http 1.3.1", + "regex", "reqwest", "reqwest-middleware", "serde", @@ -1995,6 +2010,7 @@ checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ "bitflags 2.9.1", "libc", + "redox_syscall 0.5.12", ] [[package]] @@ -2216,6 +2232,21 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "nfsserve" +version = "0.10.2" +source = "git+https://github.com/huggingface/nfsserve.git#9ba758b0490e52c57a4103f82455fc95fac748df" +dependencies = [ + "anyhow", + "async-trait", + "byteorder", + "filetime", + "num-derive", + "num-traits", + "tokio", + "tracing", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2242,6 +2273,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "num-integer" version = "0.1.46" @@ -4007,6 +4049,7 @@ dependencies = [ "tempfile", "thiserror 2.0.12", "tokio", + "tokio-util", "tracing", "web-time", "xet_runtime", @@ -4605,6 +4648,24 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +[[package]] +name = "xet-mount" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "cas_client", + "cas_types", + "clap", + "data", + "hub_client", + "merklehash", + "nfsserve", + "tokio", + "utils", + "uuid", +] + [[package]] name = "xet_runtime" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 2dcaaa42..7831e1de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "merklehash", "progress_tracking", "utils", + "xet-mount", "xet_runtime", ] diff --git a/cas_client/src/download_utils.rs b/cas_client/src/download_utils.rs index 3cceba95..385b6256 100644 --- a/cas_client/src/download_utils.rs +++ b/cas_client/src/download_utils.rs @@ -21,7 +21,7 @@ use utils::singleflight::Group; use crate::error::{CasClientError, Result}; use crate::http_client::Api; -use crate::output_provider::OutputProvider; +use crate::output_provider::SeekingOutputProvider; use crate::remote_client::{PREFIX_DEFAULT, get_reconstruction_with_endpoint_and_client}; use crate::retry_wrapper::{RetryWrapper, RetryableReqwestError}; @@ -296,7 +296,7 @@ pub(crate) struct ChunkRangeWrite { pub(crate) struct FetchTermDownloadOnceAndWriteEverywhereUsed { pub download: FetchTermDownload, // pub write_offset: u64, // start position of the writer to write to - pub output: OutputProvider, + pub output: SeekingOutputProvider, pub writes: Vec, } diff --git a/cas_client/src/interface.rs b/cas_client/src/interface.rs index 770f36a6..8df05489 100644 --- a/cas_client/src/interface.rs +++ b/cas_client/src/interface.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::sync::Arc; use bytes::Bytes; @@ -9,9 +8,9 @@ use merklehash::MerkleHash; use progress_tracking::item_tracking::SingleItemProgressUpdater; use progress_tracking::upload_tracking::CompletionTracker; -#[cfg(not(target_family = "wasm"))] -use crate::OutputProvider; use crate::error::Result; +#[cfg(not(target_family = "wasm"))] +use crate::{SeekingOutputProvider, SequentialOutput}; /// A Client to the Shard service. The shard service /// provides for @@ -25,24 +24,25 @@ pub trait Client { /// /// The http_client passed in is a non-authenticated client. This is used to directly communicate /// with the backing store (S3) to retrieve xorbs. + /// + /// Content is written in-order to the provided SequentialOutput #[cfg(not(target_family = "wasm"))] - async fn get_file( + async fn get_file_with_sequential_writer( &self, hash: &MerkleHash, byte_range: Option, - output_provider: &OutputProvider, + output_provider: SequentialOutput, progress_updater: Option>, ) -> Result; #[cfg(not(target_family = "wasm"))] - async fn batch_get_file(&self, files: HashMap) -> Result { - let mut n_bytes = 0; - // Provide the basic naive implementation as a default. 
- for (h, w) in files { - n_bytes += self.get_file(&h, None, w, None).await?; - } - Ok(n_bytes) - } + async fn get_file_with_parallel_writer( + &self, + hash: &MerkleHash, + byte_range: Option, + output_provider: SeekingOutputProvider, + progress_updater: Option>, + ) -> Result; async fn get_file_reconstruction_info( &self, diff --git a/cas_client/src/lib.rs b/cas_client/src/lib.rs index 7625516b..91b37725 100644 --- a/cas_client/src/lib.rs +++ b/cas_client/src/lib.rs @@ -6,7 +6,7 @@ pub use interface::Client; #[cfg(not(target_family = "wasm"))] pub use local_client::LocalClient; #[cfg(not(target_family = "wasm"))] -pub use output_provider::{FileProvider, OutputProvider}; +pub use output_provider::*; pub use remote_client::RemoteClient; pub use crate::error::CasClientError; diff --git a/cas_client/src/local_client.rs b/cas_client/src/local_client.rs index b45a47fb..07831c3c 100644 --- a/cas_client/src/local_client.rs +++ b/cas_client/src/local_client.rs @@ -18,12 +18,12 @@ use merklehash::MerkleHash; use progress_tracking::item_tracking::SingleItemProgressUpdater; use progress_tracking::upload_tracking::CompletionTracker; use tempfile::TempDir; +use tokio::io::AsyncWriteExt; use tokio::runtime::Handle; use tracing::{debug, error, info, warn}; -use crate::Client; use crate::error::{CasClientError, Result}; -use crate::output_provider::OutputProvider; +use crate::{Client, SeekingOutputProvider, SequentialOutput}; pub struct LocalClient { tmp_dir: Option, // To hold directory to use for local testing @@ -232,14 +232,14 @@ impl LocalClient { } } -/// LocalClient is responsible for writing/reading Xorbs on local disk. +/// LocalClient is responsible for writing/reading Xorbs on the local disk. #[async_trait] impl Client for LocalClient { - async fn get_file( + async fn get_file_with_sequential_writer( &self, hash: &MerkleHash, byte_range: Option, - output_provider: &OutputProvider, + mut output_provider: SequentialOutput, _progress_updater: Option>, ) -> Result { let Some((file_info, _)) = self @@ -250,7 +250,6 @@ impl Client for LocalClient { else { return Err(CasClientError::FileNotFound(*hash)); }; - let mut writer = output_provider.get_writer_at(0)?; // This is just used for testing, so inefficient is fine. let mut file_vec = Vec::new(); @@ -269,11 +268,23 @@ impl Client for LocalClient { .unwrap_or(file_vec.len()) .min(file_vec.len()); - writer.write_all(&file_vec[start..end])?; + output_provider.write_all(&file_vec[start..end]).await?; Ok((end - start) as u64) } + async fn get_file_with_parallel_writer( + &self, + hash: &MerkleHash, + byte_range: Option, + output_provider: SeekingOutputProvider, + progress_updater: Option>, + ) -> Result { + let sequential = output_provider.try_into()?; + self.get_file_with_sequential_writer(hash, byte_range, sequential, progress_updater) + .await + } + /// Query the shard server for the file reconstruction info. /// Returns the FileInfo for reconstructing the file and the shard ID that /// defines the file info. 
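Note on the Client interface change above: the single get_file entry point (and the default batch_get_file) is replaced by two explicit methods, get_file_with_sequential_writer (in-order writes into an AsyncWrite-backed SequentialOutput) and get_file_with_parallel_writer (offset-addressed writes through a SeekingOutputProvider). The following is a minimal caller-side sketch, not part of the diff, assuming the constructors introduced later in this patch (sequential_output_from_filepath, SeekingOutputProvider::new_file_provider); the function name, output paths, and the use of anyhow for error conversion are hypothetical, and the byte-range and progress-updater arguments (whose generic types are elided in this rendering) are simply passed as None.

use cas_client::{Client, SeekingOutputProvider, SequentialOutput, sequential_output_from_filepath};
use merklehash::MerkleHash;

// Hypothetical helper: fetch the same file once through each of the two new entry points.
async fn fetch_both_ways(client: &impl Client, hash: &MerkleHash) -> anyhow::Result<()> {
    // Sequential path: bytes arrive strictly in order, so the sink only needs to implement
    // AsyncWrite (a pipe, a socket, or a plain file).
    let seq: SequentialOutput = sequential_output_from_filepath("out.sequential.bin")?;
    let n_seq = client.get_file_with_sequential_writer(hash, None, seq, None).await?;

    // Parallel path: the provider hands out writers positioned at arbitrary offsets, so
    // reconstructed terms can be written out as soon as they finish downloading.
    let par = SeekingOutputProvider::new_file_provider("out.parallel.bin".into());
    let n_par = client.get_file_with_parallel_writer(hash, None, par, None).await?;

    // Both variants report the number of bytes written for the same reconstruction.
    assert_eq!(n_seq, n_par);
    Ok(())
}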
diff --git a/cas_client/src/output_provider.rs b/cas_client/src/output_provider.rs index 370cc6c9..54b3c89a 100644 --- a/cas_client/src/output_provider.rs +++ b/cas_client/src/output_provider.rs @@ -1,28 +1,87 @@ -use std::io::{Cursor, Seek, SeekFrom, Write}; -use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::io::{Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; + +use crate::CasClientError; use crate::error::Result; -/// Enum of different output formats to write reconstructed files. +/// Type that represents all acceptable sequential output mechanisms. +/// To convert something that is Write rather than AsyncWrite, use the AsyncWriteFromWrite adapter. +pub type SequentialOutput = Box; + +pub fn sequential_output_from_filepath(filename: impl AsRef) -> Result { + let file = std::fs::OpenOptions::new() + .write(true) + .truncate(false) + .create(true) + .open(&filename)?; + Ok(Box::new(AsyncWriteFromWrite(Some(Box::new(file))))) +} + +/// Enum of different output formats to write reconstructed files +/// where the result writer can be set at a specific position and new handles can be created. #[derive(Debug, Clone)] -pub enum OutputProvider { +pub enum SeekingOutputProvider { File(FileProvider), #[cfg(test)] - Buffer(BufferProvider), + Buffer(buffer_provider::BufferProvider), } -impl OutputProvider { +impl SeekingOutputProvider { + // Shortcut to create a new FileProvider variant from a filename. + pub fn new_file_provider(filename: PathBuf) -> Self { + Self::File(FileProvider::new(filename)) + } + /// Create a new writer to start writing at the indicated start location. pub(crate) fn get_writer_at(&self, start: u64) -> Result> { match self { - OutputProvider::File(fp) => fp.get_writer_at(start), + SeekingOutputProvider::File(fp) => fp.get_writer_at(start), #[cfg(test)] - OutputProvider::Buffer(bp) => bp.get_writer_at(start), + SeekingOutputProvider::Buffer(bp) => bp.get_writer_at(start), } } } +// Adapter used to create an AsyncWrite from a Writer. 
+struct AsyncWriteFromWrite(Option>); + +impl AsyncWrite for AsyncWriteFromWrite { + fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let Some(inner) = self.0.as_mut() else { + return Poll::Ready(Ok(0)); + }; + Poll::Ready(inner.write(buf)) + } + + fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let Some(inner) = self.0.as_mut() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "writer closed, already dropped", + ))); + }; + Poll::Ready(inner.flush()) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let _ = self.0.take(); + Poll::Ready(Ok(())) + } +} + +impl TryFrom for SequentialOutput { + type Error = CasClientError; + + fn try_from(value: SeekingOutputProvider) -> std::result::Result { + let w = value.get_writer_at(0)?; + Ok(Box::new(AsyncWriteFromWrite(Some(w)))) + } +} + /// Provides new Writers to a file located at a particular location #[derive(Debug, Clone)] pub struct FileProvider { @@ -45,44 +104,69 @@ impl FileProvider { } } -#[derive(Debug, Default, Clone)] -pub struct BufferProvider { - pub buf: ThreadSafeBuffer, -} +#[cfg(test)] +pub(crate) mod buffer_provider { + use std::io::{Cursor, Write}; + use std::sync::{Arc, Mutex}; + + use crate::error::Result; + use crate::output_provider::AsyncWriteFromWrite; + use crate::{SeekingOutputProvider, SequentialOutput}; -impl BufferProvider { - pub fn get_writer_at(&self, start: u64) -> crate::error::Result> { - let mut buffer = self.buf.clone(); - buffer.idx = start; - Ok(Box::new(buffer)) + /// BufferProvider may be Seeking or Sequential + /// only used in testing + #[derive(Debug, Clone)] + pub struct BufferProvider { + pub buf: ThreadSafeBuffer, } -} -#[derive(Debug, Default, Clone)] -/// Thread-safe in-memory buffer that implements [Write](Write) trait at some position -/// within an underlying buffer and allows access to inner buffer. -/// Thread-safe in-memory buffer that implements [Write](Write) trait and allows -/// access to inner buffer -pub struct ThreadSafeBuffer { - idx: u64, - inner: Arc>>>, -} -impl ThreadSafeBuffer { - pub fn value(&self) -> Vec { - self.inner.lock().unwrap().get_ref().clone() + impl BufferProvider { + pub fn get_writer_at(&self, start: u64) -> Result> { + let mut buffer = self.buf.clone(); + buffer.idx = start; + Ok(Box::new(buffer)) + } + } + + #[derive(Debug, Default, Clone)] + /// Thread-safe in-memory buffer that implements [Write](Write) trait at some position + /// within an underlying buffer and allows access to the inner buffer. 
+ /// Thread-safe in-memory buffer that implements [Write](Write) trait and allows + /// access to the inner buffer + pub struct ThreadSafeBuffer { + idx: u64, + inner: Arc>>>, + } + + impl ThreadSafeBuffer { + pub fn value(&self) -> Vec { + self.inner.lock().unwrap().get_ref().clone() + } } -} -impl std::io::Write for ThreadSafeBuffer { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut guard = self.inner.lock().map_err(|e| std::io::Error::other(format!("{e}")))?; - guard.set_position(self.idx); - let num_written = guard.write(buf)?; - self.idx = guard.position(); - Ok(num_written) + impl Write for ThreadSafeBuffer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut guard = self.inner.lock().map_err(|e| std::io::Error::other(format!("{e}")))?; + guard.set_position(self.idx); + let num_written = Write::write(guard.get_mut(), buf)?; + self.idx = guard.position(); + Ok(num_written) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } } - fn flush(&mut self) -> std::io::Result<()> { - Ok(()) + impl From for SequentialOutput { + fn from(value: ThreadSafeBuffer) -> Self { + Box::new(AsyncWriteFromWrite(Some(Box::new(value)))) + } + } + + impl From for SeekingOutputProvider { + fn from(value: ThreadSafeBuffer) -> Self { + SeekingOutputProvider::Buffer(BufferProvider { buf: value }) + } } } diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 2a9a55d0..4e156b2a 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::io::Write; use std::mem::take; use std::path::PathBuf; use std::sync::Arc; @@ -21,6 +20,7 @@ use progress_tracking::item_tracking::SingleItemProgressUpdater; use progress_tracking::upload_tracking::CompletionTracker; use reqwest::{Body, Response, StatusCode, Url}; use reqwest_middleware::ClientWithMiddleware; +use tokio::io::AsyncWriteExt; use tokio::sync::{OwnedSemaphorePermit, mpsc}; use tokio::task::{JoinHandle, JoinSet}; use tracing::{debug, info, instrument}; @@ -34,7 +34,7 @@ use crate::download_utils::*; use crate::error::{CasClientError, Result}; use crate::http_client::{Api, ResponseErrorLogger, RetryConfig}; #[cfg(not(target_family = "wasm"))] -use crate::output_provider::OutputProvider; +use crate::output_provider::{SeekingOutputProvider, SequentialOutput}; use crate::retry_wrapper::RetryWrapper; use crate::{Client, http_client}; @@ -49,7 +49,7 @@ utils::configurable_constants! { high_performance: 256, }; - /// Send a report of successful partial upload every 512kb. + /// Send a report of a successful partial upload every 512kb. ref UPLOAD_REPORTING_BLOCK_SIZE : usize = 512 * 1024; /// Env (HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY) to switch to writing terms sequentially to disk. @@ -126,7 +126,7 @@ pub(crate) async fn map_fetch_info_into_download_tasks( chunk_cache: Option>, client: Arc, range_download_single_flight: Arc>, - output_provider: &OutputProvider, + output_provider: &SeekingOutputProvider, ) -> Result> { // the actual segment length. // the file_range end may actually exceed the file total length for the last segment. @@ -291,14 +291,14 @@ impl RemoteClient { // storage uses HDDs. 
#[instrument(skip_all, name = "RemoteClient::reconstruct_file_segmented", fields(file.hash = file_hash.hex() ))] - async fn reconstruct_file_to_writer_segmented( + async fn reconstruct_file_to_writer_segmented_sequential_write( &self, file_hash: &MerkleHash, byte_range: Option, - writer: &OutputProvider, + mut writer: SequentialOutput, progress_updater: Option>, ) -> Result { - // Use an unlimited queue size, as queue size is inherently bounded by degree of concurrency. + // Use an unlimited queue size, as queue size is inherently bounded by a degree of concurrency. let (task_tx, mut task_rx) = mpsc::unbounded_channel::>(); let (running_downloads_tx, mut running_downloads_rx) = mpsc::unbounded_channel::>, OwnedSemaphorePermit)>>>(); @@ -409,13 +409,12 @@ impl RemoteClient { Ok(()) }); - let mut writer = writer.get_writer_at(0)?; let mut total_written = 0; while let Some(result) = running_downloads_rx.recv().await { match result.await { Ok(Ok((mut download_result, permit))) => { let data = take(&mut download_result.payload); - writer.write_all(&data)?; + writer.write_all(&data).await?; // drop permit after data written out so they don't accumulate in memory unbounded drop(permit); @@ -432,7 +431,7 @@ impl RemoteClient { Err(e) => Err(anyhow!("{e:?}"))?, } } - writer.flush()?; + writer.flush().await?; queue_dispatcher.await??; @@ -449,7 +448,7 @@ impl RemoteClient { &self, file_hash: &MerkleHash, byte_range: Option, - writer: &OutputProvider, + writer: &SeekingOutputProvider, progress_updater: Option>, ) -> Result { // Use the unlimited queue, as queue size is inherently bounded by degree of concurrency. @@ -593,6 +592,97 @@ impl RemoteClient { #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] impl Client for RemoteClient { + #[cfg(not(target_family = "wasm"))] + async fn get_file_with_sequential_writer( + &self, + hash: &MerkleHash, + byte_range: Option, + output_provider: SequentialOutput, + progress_updater: Option>, + ) -> Result { + self.reconstruct_file_to_writer_segmented_sequential_write(hash, byte_range, output_provider, progress_updater) + .await + } + + #[cfg(not(target_family = "wasm"))] + async fn get_file_with_parallel_writer( + &self, + hash: &MerkleHash, + byte_range: Option, + output_provider: SeekingOutputProvider, + progress_updater: Option>, + ) -> Result { + self.reconstruct_file_to_writer_segmented_parallel_write(hash, byte_range, &output_provider, progress_updater) + .await + } + + #[instrument(skip_all, name = "RemoteClient::get_file_reconstruction", fields(file.hash = file_hash.hex() + ))] + async fn get_file_reconstruction_info( + &self, + file_hash: &MerkleHash, + ) -> Result)>> { + let url = Url::parse(&format!("{}/reconstructions/{}", self.endpoint, file_hash.hex()))?; + + let api_tag = "cas::get_reconstruction_info"; + let client = self.authenticated_http_client.clone(); + + let response: QueryReconstructionResponse = RetryWrapper::new(api_tag) + .run_and_extract_json(move || client.get(url.clone()).with_extension(Api(api_tag)).send()) + .await?; + + Ok(Some(( + MDBFileInfo { + metadata: FileDataSequenceHeader::new(*file_hash, response.terms.len(), false, false), + segments: response + .terms + .into_iter() + .map(|ce| { + FileDataSequenceEntry::new(ce.hash.into(), ce.unpacked_length, ce.range.start, ce.range.end) + }) + .collect(), + verification: vec![], + metadata_ext: None, + }, + None, + ))) + } + + async fn query_for_global_dedup_shard(&self, prefix: &str, chunk_hash: 
&MerkleHash) -> Result> { + let Some(response) = self.query_dedup_api(prefix, chunk_hash).await? else { + return Ok(None); + }; + + Ok(Some(response.bytes().await?)) + } + + #[instrument(skip_all, name = "RemoteClient::upload_shard", fields(shard.len = shard_data.len()))] + async fn upload_shard(&self, shard_data: Bytes) -> Result { + if self.dry_run { + return Ok(true); + } + + let api_tag = "cas::upload_shard"; + let client = self.authenticated_http_client.clone(); + + let url = Url::parse(&format!("{}/shards", self.endpoint))?; + + let response: UploadShardResponse = RetryWrapper::new(api_tag) + .run_and_extract_json(move || { + client + .post(url.clone()) + .with_extension(Api(api_tag)) + .body(shard_data.clone()) + .send() + }) + .await?; + + match response.result { + UploadShardResponseType::Exists => Ok(false), + UploadShardResponseType::SyncPerformed => Ok(true), + } + } + #[cfg(not(target_family = "wasm"))] #[instrument(skip_all, name = "RemoteClient::upload_xorb", fields(key = Key{prefix : prefix.to_string(), hash : serialized_cas_object.hash}.to_string(), xorb.len = serialized_cas_object.serialized_data.len(), xorb.num_chunks = serialized_cas_object.num_chunks @@ -702,99 +792,6 @@ impl Client for RemoteClient { fn use_shard_footer(&self) -> bool { false } - - #[cfg(not(target_family = "wasm"))] - async fn get_file( - &self, - hash: &MerkleHash, - byte_range: Option, - output_provider: &OutputProvider, - progress_updater: Option>, - ) -> Result { - // If the user has set the `HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY=true` env variable, then we - // should write the file to the output sequentially instead of in parallel. - if *RECONSTRUCT_WRITE_SEQUENTIALLY { - info!("reconstruct terms sequentially"); - self.reconstruct_file_to_writer_segmented(hash, byte_range, output_provider, progress_updater) - .await - } else { - info!("reconstruct terms in parallel"); - self.reconstruct_file_to_writer_segmented_parallel_write( - hash, - byte_range, - output_provider, - progress_updater, - ) - .await - } - } - - #[instrument(skip_all, name = "RemoteClient::get_file_reconstruction", fields(file.hash = file_hash.hex() - ))] - async fn get_file_reconstruction_info( - &self, - file_hash: &MerkleHash, - ) -> Result)>> { - let url = Url::parse(&format!("{}/reconstructions/{}", self.endpoint, file_hash.hex()))?; - - let api_tag = "cas::get_reconstruction_info"; - let client = self.authenticated_http_client.clone(); - - let response: QueryReconstructionResponse = RetryWrapper::new(api_tag) - .run_and_extract_json(move || client.get(url.clone()).with_extension(Api(api_tag)).send()) - .await?; - - Ok(Some(( - MDBFileInfo { - metadata: FileDataSequenceHeader::new(*file_hash, response.terms.len(), false, false), - segments: response - .terms - .into_iter() - .map(|ce| { - FileDataSequenceEntry::new(ce.hash.into(), ce.unpacked_length, ce.range.start, ce.range.end) - }) - .collect(), - verification: vec![], - metadata_ext: None, - }, - None, - ))) - } - - #[instrument(skip_all, name = "RemoteClient::upload_shard", fields(shard.len = shard_data.len()))] - async fn upload_shard(&self, shard_data: Bytes) -> Result { - if self.dry_run { - return Ok(true); - } - - let api_tag = "cas::upload_shard"; - let client = self.authenticated_http_client.clone(); - - let url = Url::parse(&format!("{}/shards", self.endpoint))?; - - let response: UploadShardResponse = RetryWrapper::new(api_tag) - .run_and_extract_json(move || { - client - .post(url.clone()) - .with_extension(Api(api_tag)) - .body(shard_data.clone()) - .send() 
- }) - .await?; - - match response.result { - UploadShardResponseType::Exists => Ok(false), - UploadShardResponseType::SyncPerformed => Ok(true), - } - } - - async fn query_for_global_dedup_shard(&self, prefix: &str, chunk_hash: &MerkleHash) -> Result> { - let Some(response) = self.query_dedup_api(prefix, chunk_hash).await? else { - return Ok(None); - }; - - Ok(Some(response.bytes().await?)) - } } #[cfg(test)] @@ -813,7 +810,7 @@ mod tests { use xet_runtime::XetRuntime; use super::*; - use crate::output_provider::BufferProvider; + use crate::buffer_provider::ThreadSafeBuffer; #[ignore = "requires a running CAS server"] #[traced_test] @@ -1199,33 +1196,36 @@ mod tests { // test reconstruct and sequential write let test = test_case.clone(); let client = RemoteClient::new(endpoint, &None, &None, None, "", false); - let provider = BufferProvider::default(); - let buf = provider.buf.clone(); - let writer = OutputProvider::Buffer(provider); + let buf = ThreadSafeBuffer::default(); + let provider = SequentialOutput::from(buf.clone()); let resp = threadpool.external_run_async_task(async move { client - .reconstruct_file_to_writer_segmented(&test.file_hash, Some(test.file_range), &writer, None) + .reconstruct_file_to_writer_segmented_sequential_write( + &test.file_hash, + Some(test.file_range), + provider, + None, + ) .await })?; assert_eq!(test.expect_error, resp.is_err(), "{:?}", resp.err()); if !test.expect_error { - assert_eq!(test.expected_data.len() as u64, resp.unwrap()); + assert_eq!(test.expected_data.len() as u64, resp?); assert_eq!(test.expected_data, buf.value()); } // test reconstruct and parallel write let test = test_case; let client = RemoteClient::new(endpoint, &None, &None, None, "", false); - let provider = BufferProvider::default(); - let buf = provider.buf.clone(); - let writer = OutputProvider::Buffer(provider); + let buf = ThreadSafeBuffer::default(); + let provider = SeekingOutputProvider::from(buf.clone()); let resp = threadpool.external_run_async_task(async move { client .reconstruct_file_to_writer_segmented_parallel_write( &test.file_hash, Some(test.file_range), - &writer, + &provider, None, ) .await diff --git a/cas_object/src/byte_grouping/bg4.rs b/cas_object/src/byte_grouping/bg4.rs index 2d5dd900..672c226a 100644 --- a/cas_object/src/byte_grouping/bg4.rs +++ b/cas_object/src/byte_grouping/bg4.rs @@ -228,7 +228,7 @@ pub fn bg4_regroup_together_combined_write_8(g: &[u8]) -> Vec { copy_nonoverlapping(&eightbytes as *const u8, d_ptr.add(8 * i), 8); } - if split % 2 != 0 { + if !split.is_multiple_of(2) { let i = split - 1; let fourbytes = [*g0.add(i), *g1.add(i), *g2.add(i), *g3.add(i)]; data[4 * i..4 * i + 4].copy_from_slice(&fourbytes[..]); diff --git a/data/src/bin/example.rs b/data/src/bin/example.rs index f60390dc..64fde1db 100644 --- a/data/src/bin/example.rs +++ b/data/src/bin/example.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use std::sync::{Arc, OnceLock}; use anyhow::Result; -use cas_client::{FileProvider, OutputProvider}; +use cas_client::SeekingOutputProvider; use clap::{Args, Parser, Subcommand}; use data::configurations::*; use data::{FileDownloader, FileUploadSession, XetFileInfo}; @@ -120,13 +120,13 @@ async fn smudge_file(arg: &SmudgeArg) -> Result<()> { None => Box::new(std::io::stdin()), }; - let writer = OutputProvider::File(FileProvider::new(arg.dest.clone())); - smudge(arg.dest.to_string_lossy().into(), reader, &writer).await?; + let writer = SeekingOutputProvider::new_file_provider(arg.dest.clone()); + smudge(arg.dest.to_string_lossy().into(), 
reader, writer).await?; Ok(()) } -async fn smudge(name: Arc, mut reader: impl Read, writer: &OutputProvider) -> Result<()> { +async fn smudge(name: Arc, mut reader: impl Read, writer: SeekingOutputProvider) -> Result<()> { let mut input = String::new(); reader.read_to_string(&mut input)?; diff --git a/data/src/bin/xtool.rs b/data/src/bin/xtool.rs index a5e7f3ce..7cd66bed 100644 --- a/data/src/bin/xtool.rs +++ b/data/src/bin/xtool.rs @@ -11,7 +11,7 @@ use clap::{Args, Parser, Subcommand}; use data::data_client::default_config; use data::migration_tool::hub_client_token_refresher::HubClientTokenRefresher; use data::migration_tool::migrate::migrate_files_impl; -use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo}; +use hub_client::{BearerCredentialHelper, HubClient, HubXetTokenTrait, Operation, RepoInfo}; use merklehash::MerkleHash; use utils::auth::TokenRefresher; use walkdir::WalkDir; @@ -198,7 +198,7 @@ async fn query_reconstruction( hub_client: HubClient, ) -> Result> { let operation = Operation::Download; - let jwt_info = hub_client.get_cas_jwt(operation).await?; + let jwt_info = hub_client.get_xet_token(operation).await?; let token_refresher = Arc::new(HubClientTokenRefresher { operation, client: Arc::new(hub_client), diff --git a/data/src/data_client.rs b/data/src/data_client.rs index 1e67da03..366c1e7d 100644 --- a/data/src/data_client.rs +++ b/data/src/data_client.rs @@ -5,8 +5,10 @@ use std::io::Read; use std::path::{Path, PathBuf}; use std::sync::Arc; -use cas_client::remote_client::PREFIX_DEFAULT; -use cas_client::{CHUNK_CACHE_SIZE_BYTES, CacheConfig, FileProvider, OutputProvider}; +use cas_client::remote_client::{PREFIX_DEFAULT, RECONSTRUCT_WRITE_SEQUENTIALLY}; +use cas_client::{ + CHUNK_CACHE_SIZE_BYTES, CacheConfig, SeekingOutputProvider, SequentialOutput, sequential_output_from_filepath, +}; use cas_object::CompressionScheme; use deduplication::DeduplicationMetrics; use dirs::home_dir; @@ -274,14 +276,30 @@ async fn smudge_file( if let Some(parent_dir) = path.parent() { std::fs::create_dir_all(parent_dir)?; } - let output = OutputProvider::File(FileProvider::new(path)); // Wrap the progress updater in the proper tracking struct. 
let progress_updater = progress_updater.map(ItemProgressUpdater::new); - downloader - .smudge_file_from_hash(&file_info.merkle_hash()?, file_path.into(), &output, None, progress_updater) - .await?; + if *RECONSTRUCT_WRITE_SEQUENTIALLY { + let output: SequentialOutput = sequential_output_from_filepath(file_path)?; + info!("Using sequential writer for smudge"); + downloader + .smudge_file_from_hash_sequential( + &file_info.merkle_hash()?, + file_path.into(), + output, + None, + progress_updater, + ) + .await?; + } else { + let output = SeekingOutputProvider::new_file_provider(path); + info!("Using parallel writer for smudge"); + downloader + .smudge_file_from_hash(&file_info.merkle_hash()?, file_path.into(), output, None, progress_updater) + .await?; + }; + Ok(file_path.to_string()) } diff --git a/data/src/file_downloader.rs b/data/src/file_downloader.rs index 9bc96558..fed190fc 100644 --- a/data/src/file_downloader.rs +++ b/data/src/file_downloader.rs @@ -1,11 +1,11 @@ use std::borrow::Cow; use std::sync::Arc; -use cas_client::{Client, OutputProvider}; +use cas_client::{Client, SeekingOutputProvider, SequentialOutput}; use cas_types::FileRange; use merklehash::MerkleHash; use progress_tracking::item_tracking::ItemProgressUpdater; -use tracing::instrument; +use tracing::{info, instrument}; use ulid::Ulid; use crate::configurations::TranslatorConfig; @@ -16,8 +16,10 @@ use crate::remote_client_interface::create_remote_client; /// Manages the download of files based on a hash or pointer file. /// /// This class handles the clean operations. It's meant to be a single atomic session -/// that succeeds or fails as a unit; i.e. all files get uploaded on finalization, and all shards +/// that succeeds or fails as a unit; i.e., all files get uploaded on finalization, and all shards /// and xorbs needed to reconstruct those files are properly uploaded and registered. +/// Cheaply cloneable +#[derive(Clone)] pub struct FileDownloader { /* ----- Configurations ----- */ config: Arc, @@ -37,19 +39,46 @@ impl FileDownloader { Ok(Self { config, client }) } - #[instrument(skip_all, name = "FileDownloader::smudge_file_from_hash", fields(hash=file_id.hex()))] + #[instrument(skip_all, name = "FileDownloader::smudge_file_from_hash", fields(hash=file_id.hex() + ))] pub async fn smudge_file_from_hash( &self, file_id: &MerkleHash, file_name: Arc, - output: &OutputProvider, + output: SeekingOutputProvider, + range: Option, + progress_updater: Option>, + ) -> Result { + let file_progress_tracker = progress_updater.map(|p| ItemProgressUpdater::item_tracker(&p, file_name, None)); + + let n_bytes = self + .client + .get_file_with_parallel_writer(file_id, range, output, file_progress_tracker) + .await?; + + prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); + + Ok(n_bytes) + } + + #[instrument(skip_all, name = "FileDownloader::smudge_file_from_hash", fields(hash=file_id.hex() + ))] + pub async fn smudge_file_from_hash_sequential( + &self, + file_id: &MerkleHash, + file_name: Arc, + output: SequentialOutput, range: Option, progress_updater: Option>, ) -> Result { let file_progress_tracker = progress_updater.map(|p| ItemProgressUpdater::item_tracker(&p, file_name, None)); // Currently, this works by always directly querying the remote server. 
- let n_bytes = self.client.get_file(file_id, range, output, file_progress_tracker).await?; + info!("Using sequential writer for smudge"); + let n_bytes = self + .client + .get_file_with_sequential_writer(file_id, range, output, file_progress_tracker) + .await?; prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); diff --git a/data/src/file_upload_session.rs b/data/src/file_upload_session.rs index 9b81e7dd..b3e14cd7 100644 --- a/data/src/file_upload_session.rs +++ b/data/src/file_upload_session.rs @@ -627,7 +627,7 @@ mod tests { /// * `output_path`: path to write the hydrated/original file async fn test_smudge_file(cas_path: &Path, pointer_path: &Path, output_path: &Path) { let mut reader = File::open(pointer_path).unwrap(); - let writer = OutputProvider::File(FileProvider::new(output_path.to_path_buf())); + let writer = SeekingOutputProvider::new_file_provider(output_path.to_path_buf()); let mut input = String::new(); reader.read_to_string(&mut input).unwrap(); @@ -642,7 +642,7 @@ mod tests { .smudge_file_from_hash( &xet_file.merkle_hash().expect("File hash is not a valid file hash"), output_path.to_string_lossy().into(), - &writer, + writer, None, None, ) @@ -652,7 +652,7 @@ mod tests { use std::fs::{read, write}; - use cas_client::{FileProvider, OutputProvider}; + use cas_client::SeekingOutputProvider; use tempfile::tempdir; use super::*; diff --git a/data/src/migration_tool/hub_client_token_refresher.rs b/data/src/migration_tool/hub_client_token_refresher.rs index 87843aa4..7c0ad9e4 100644 --- a/data/src/migration_tool/hub_client_token_refresher.rs +++ b/data/src/migration_tool/hub_client_token_refresher.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use hub_client::{HubClient, Operation}; +use hub_client::{HubClient, HubXetTokenTrait, Operation}; use utils::auth::{TokenInfo, TokenRefresher}; use utils::errors::AuthError; @@ -15,7 +15,7 @@ impl TokenRefresher for HubClientTokenRefresher { async fn refresh(&self) -> std::result::Result { let jwt_info = self .client - .get_cas_jwt(self.operation) + .get_xet_token(self.operation) .await .map_err(AuthError::token_refresh_failure)?; diff --git a/data/src/migration_tool/migrate.rs b/data/src/migration_tool/migrate.rs index 13887c97..c8706911 100644 --- a/data/src/migration_tool/migrate.rs +++ b/data/src/migration_tool/migrate.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use anyhow::Result; use cas_object::CompressionScheme; -use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo}; +use hub_client::{BearerCredentialHelper, HubClient, HubXetTokenTrait, Operation, RepoInfo}; use mdb_shard::file_structs::MDBFileInfo; use tracing::{Instrument, Span, info_span, instrument}; use utils::auth::TokenRefresher; @@ -60,7 +60,7 @@ pub async fn migrate_files_impl( dry_run: bool, ) -> Result { let operation = Operation::Upload; - let jwt_info = hub_client.get_cas_jwt(operation).await?; + let jwt_info = hub_client.get_xet_token(operation).await?; let token_refresher = Arc::new(HubClientTokenRefresher { operation, client: Arc::new(hub_client), diff --git a/data/src/test_utils.rs b/data/src/test_utils.rs index 1949d9ff..2d5efe1f 100644 --- a/data/src/test_utils.rs +++ b/data/src/test_utils.rs @@ -3,7 +3,7 @@ use std::io::{Read, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use cas_client::{FileProvider, OutputProvider}; +use cas_client::SeekingOutputProvider; use progress_tracking::TrackingProgressUpdater; use rand::prelude::*; use tempfile::TempDir; @@ -214,7 +214,7 @@ impl LocalHydrateDehydrateTest { let out_filename = 
self.dest_dir.join(entry.file_name()); // Create an output file for writing - let file_out = OutputProvider::File(FileProvider::new(out_filename.clone())); + let file_out = SeekingOutputProvider::new_file_provider(out_filename.clone()); // Pointer file. let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap(); @@ -223,7 +223,7 @@ impl LocalHydrateDehydrateTest { .smudge_file_from_hash( &xf.merkle_hash().unwrap(), out_filename.to_string_lossy().into(), - &file_out, + file_out, None, None, ) diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index 9904213b..7b4d4fd9 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -155,7 +155,7 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.89" +version = "0.1.90" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ @@ -3827,6 +3827,7 @@ dependencies = [ "shellexpand", "thiserror 2.0.15", "tokio", + "tokio-util", "tracing", "web-time", ] diff --git a/hf_xet_thin_wasm/Cargo.lock b/hf_xet_thin_wasm/Cargo.lock index c7308943..8fb1e086 100644 --- a/hf_xet_thin_wasm/Cargo.lock +++ b/hf_xet_thin_wasm/Cargo.lock @@ -1395,6 +1395,19 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "tokio-util" +version = "0.7.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.41" @@ -1472,6 +1485,7 @@ dependencies = [ "shellexpand", "thiserror", "tokio", + "tokio-util", "tracing", "web-time", ] diff --git a/hf_xet_wasm/Cargo.lock b/hf_xet_wasm/Cargo.lock index 19ea5e0a..7b1ff93a 100644 --- a/hf_xet_wasm/Cargo.lock +++ b/hf_xet_wasm/Cargo.lock @@ -108,7 +108,7 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.89" +version = "0.1.90" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ @@ -2749,6 +2749,7 @@ dependencies = [ "shellexpand", "thiserror 2.0.16", "tokio", + "tokio-util", "tracing", "web-time", ] diff --git a/hub_client/Cargo.toml b/hub_client/Cargo.toml index e958310c..fafb3814 100644 --- a/hub_client/Cargo.toml +++ b/hub_client/Cargo.toml @@ -5,10 +5,13 @@ edition = "2024" [dependencies] cas_client = { path = "../cas_client" } +cas_types = { path = "../cas_types" } anyhow = { workspace = true } async-trait = { workspace = true } +bytes = { workspace = true } http = { workspace = true } +regex = { workspace = true } reqwest = { workspace = true } reqwest-middleware = { workspace = true } serde = { workspace = true } @@ -17,4 +20,8 @@ urlencoding = { workspace = true } [dev-dependencies] serde_json = { workspace = true } -tokio = { workspace = true } \ No newline at end of file +tokio = { workspace = true } + +[[example]] +path = "examples/list_files.rs" +name = "list-files" \ No newline at end of file diff --git a/hub_client/examples/list_files.rs b/hub_client/examples/list_files.rs new file mode 100644 index 00000000..a6e192f9 --- /dev/null +++ b/hub_client/examples/list_files.rs @@ -0,0 +1,30 @@ +use hub_client::{BearerCredentialHelper, HFRepoType, HubClient, HubRepositoryTrait, RepoInfo}; + +#[tokio::main] +async fn main() { + let token = std::env::var("HF_TOKEN").unwrap(); + + let client = HubClient::new( + 
"https://huggingface.co", + RepoInfo::new(HFRepoType::Model, "assafvayner/unsafetensors_new".to_string()), + None, + "", + "", + BearerCredentialHelper::new(token.to_string(), "assafvayner"), + ) + .unwrap(); + + let paths = client.list_files("").await.unwrap(); + println!("{:?}", paths); + + println!(); + for entry in &paths { + let Some(dir) = entry.as_directory() else { + continue; + }; + let sub_paths = client.list_files(dir.path.as_str()).await.unwrap(); + println!("sub paths of {}", dir.path); + println!("{:?}", sub_paths); + println!(); + } +} diff --git a/hub_client/src/client.rs b/hub_client/src/client.rs index 9eb7e54b..b9dff26b 100644 --- a/hub_client/src/client.rs +++ b/hub_client/src/client.rs @@ -1,13 +1,14 @@ +pub(crate) mod repo; +pub(crate) mod xet_token; + use std::sync::Arc; use cas_client::exports::ClientWithMiddleware; -use cas_client::{Api, ResponseErrorLogger, RetryConfig, build_http_client}; -use http::header; -use urlencoding::encode; +use cas_client::{RetryConfig, build_http_client}; use crate::auth::CredentialHelper; use crate::errors::*; -use crate::types::{CasJWTInfo, RepoInfo}; +use crate::types::RepoInfo; /// The type of operation to perform, either to upload files or to download files. /// Different operations lead to CAS access token with different authorization levels. @@ -60,53 +61,12 @@ impl HubClient { cred_helper, }) } - - // Get CAS access token from Hub access token. - pub async fn get_cas_jwt(&self, operation: Operation) -> Result { - let endpoint = self.endpoint.as_str(); - let repo_type = self.repo_info.repo_type.as_str(); - let repo_id = self.repo_info.full_name.as_str(); - let token_type = operation.token_type(); - - // The reference may contain "/" but the "xet-[]-token" API only parses "rev" from a single component, - // thus we encode the reference. It defaults to "main" if not specified by caller because the - // API route expects a "rev" component. - let rev = encode(self.reference.as_deref().unwrap_or("main")); - - // Clients can get a xet write token, if - // - the "rev" is a regular branch, with a HF write token; - // - the "rev" is a pr branch, with a HF write or read token; - // - it intends to create a pr and repo is enabled for discussion, with a HF write or read token. 
- let query = if matches!(operation, Operation::Upload) && self.reference.is_none() { - "?create_pr=1" - } else { - "" - }; - - // note that this API doesn't take a Basic auth - let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/{rev}{query}"); - - let req = self - .client - .get(url) - .with_extension(Api("xet-token")) - .header(header::USER_AGENT, &self.user_agent); - let req = self - .cred_helper - .fill_credential(req) - .await - .map_err(HubClientError::CredentialHelper)?; - let response = req.send().await.process_error("xet-write-token")?; - - let info: CasJWTInfo = response.json().await?; - - Ok(info) - } } #[cfg(test)] mod tests { use super::HubClient; + use crate::client::xet_token::HubXetTokenTrait; use crate::errors::Result; use crate::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo}; @@ -126,7 +86,7 @@ mod tests { cred_helper, )?; - let read_info = hub_client.get_cas_jwt(Operation::Upload).await?; + let read_info = hub_client.get_xet_token(Operation::Upload).await?; assert!(read_info.access_token.len() > 0); assert!(read_info.cas_url.len() > 0); @@ -151,7 +111,7 @@ mod tests { cred_helper, )?; - let read_info = hub_client.get_cas_jwt(Operation::Upload).await?; + let read_info = hub_client.get_xet_token(Operation::Upload).await?; assert!(read_info.access_token.len() > 0); assert!(read_info.cas_url.len() > 0); @@ -176,7 +136,7 @@ mod tests { cred_helper, )?; - let read_info = hub_client.get_cas_jwt(Operation::Upload).await?; + let read_info = hub_client.get_xet_token(Operation::Upload).await?; assert!(read_info.access_token.len() > 0); assert!(read_info.cas_url.len() > 0); diff --git a/hub_client/src/client/repo.rs b/hub_client/src/client/repo.rs new file mode 100644 index 00000000..f5b70f3d --- /dev/null +++ b/hub_client/src/client/repo.rs @@ -0,0 +1,178 @@ +use std::sync::LazyLock; + +use bytes::Bytes; +use cas_client::{Api, ResponseErrorLogger}; +use cas_types::{FileRange, HexMerkleHash, HttpRange}; +use http::header; +use regex::Regex; +use reqwest::Response; +use serde::Deserialize; + +use crate::{HubClient, HubClientError}; + +#[async_trait::async_trait] +pub trait HubRepositoryTrait { + async fn list_files(&self, path: &str) -> crate::Result>; + async fn download_resolved_content(&self, path: &str, range: Option) -> crate::Result; +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] // tells Serde to use the "type" field to decide which variant to use +pub enum TreeEntry { + #[serde(rename = "file")] + File(FileEntry), + + #[serde(rename = "directory")] + Directory(DirectoryEntry), +} + +impl TreeEntry { + pub fn is_file(&self) -> bool { + matches!(self, TreeEntry::File(_)) + } + + pub fn is_directory(&self) -> bool { + matches!(self, TreeEntry::Directory(_)) + } + + pub fn as_file(&self) -> Option<&FileEntry> { + if let TreeEntry::File(file) = self { + Some(file) + } else { + None + } + } + + pub fn as_directory(&self) -> Option<&DirectoryEntry> { + if let TreeEntry::Directory(dir) = self { + Some(dir) + } else { + None + } + } + + pub fn path(&self) -> &str { + match self { + TreeEntry::File(file) => file.path.as_str(), + TreeEntry::Directory(dir) => dir.path.as_str(), + } + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FileEntry { + pub oid: String, + pub size: u64, + pub lfs: Option, + pub xet_hash: Option, + pub path: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LfsInfo { + pub oid: String, + pub size: u64, + pub pointer_size: u64, +} + 
+#[derive(Debug, Deserialize)] +pub struct DirectoryEntry { + pub oid: String, + pub size: u64, + pub path: String, +} + +/// Extracts the URL from a Link header of the form: +/// `; rel="next"` +pub fn parse_link_url(response: &Response) -> Option { + let header = response.headers().get("link")?.to_str().ok()?; + // Compile the regex once (you could make it lazy_static if used often) + let re = LazyLock::new(|| Regex::new(r#"<([^>]+)>;\s*rel="next""#).unwrap()); + re.captures(header).and_then(|caps| caps.get(1)).map(|m| m.as_str().to_string()) +} + +#[async_trait::async_trait] +impl HubRepositoryTrait for HubClient { + async fn list_files(&self, path: &str) -> crate::Result> { + let endpoint = self.endpoint.as_str(); + let repo_type = self.repo_info.repo_type.as_str(); + let repo_id = self.repo_info.full_name.as_str(); + let rev = self.reference.as_deref().unwrap_or("main"); + let path = normalize_path(path); + let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/tree/{rev}{path}?limit=1000"); + + let req = self + .client + .get(&url) + .with_extension(Api("tree")) + .header(header::USER_AGENT, &self.user_agent); + + let req = self + .cred_helper + .fill_credential(req) + .await + .map_err(HubClientError::CredentialHelper)?; + let response = req.send().await.process_error("list-files")?; + + let mut link = parse_link_url(&response); + let mut entries: Vec = response.json().await?; + + while let Some(page_url) = link.take() { + let response = self + .client + .get(page_url.as_str()) + .with_extension(Api("tree")) + .send() + .await + .process_error("list-files-pagination")?; + + link = parse_link_url(&response); + let page_entries: Vec = response.json().await?; + entries.extend(page_entries); + } + + Ok(entries) + } + + // TODO: have this interface return a Stream/Reader, want #528 + async fn download_resolved_content(&self, path: &str, range: Option) -> crate::Result { + let endpoint = self.endpoint.as_str(); + let repo_type = self.repo_info.repo_type.as_str_hide_model(); + let repo_type_str = if repo_type.is_empty() { + "" + } else { + &format!("{}s/", repo_type) + }; + let repo_id = self.repo_info.full_name.as_str(); + let rev = self.reference.as_deref().unwrap_or("main"); + + let path = normalize_path(path); + // https://huggingface.co/spaces/google/emoji-gemma/resolve/main/myemoji-gemma-3-270m-it.task?download=true + let url = format!("{endpoint}/{repo_type_str}{repo_id}/resolve/{rev}{path}?download=true"); + let mut req = self + .client + .get(url) + .with_extension(Api("resolve")) + .header(header::USER_AGENT, &self.user_agent); + if let Some(range) = range { + req = req.header(http::header::RANGE, HttpRange::from(range).range_header()); + } + let req = self + .cred_helper + .fill_credential(req) + .await + .map_err(HubClientError::CredentialHelper)?; + let result = req.send().await?.error_for_status()?.bytes().await?; + Ok(result) + } +} + +fn normalize_path(path: &str) -> String { + if path.is_empty() || path.starts_with('/') { + path.to_string() + } else { + format!("/{}", path) + } +} diff --git a/hub_client/src/client/xet_token.rs b/hub_client/src/client/xet_token.rs new file mode 100644 index 00000000..86ec8a49 --- /dev/null +++ b/hub_client/src/client/xet_token.rs @@ -0,0 +1,55 @@ +use cas_client::{Api, ResponseErrorLogger}; +use http::header; +use urlencoding::encode; + +use crate::{CasJWTInfo, HubClient, HubClientError, Operation}; + +#[async_trait::async_trait] +pub trait HubXetTokenTrait { + // Get CAS access token from Hub access token. 
+ async fn get_xet_token(&self, operation: Operation) -> crate::Result; +} + +#[async_trait::async_trait] +impl HubXetTokenTrait for HubClient { + async fn get_xet_token(&self, operation: Operation) -> crate::Result { + let endpoint = self.endpoint.as_str(); + let repo_type = self.repo_info.repo_type.as_str(); + let repo_id = self.repo_info.full_name.as_str(); + let token_type = operation.token_type(); + + // The reference may contain "/" but the "xet-[]-token" API only parses "rev" from a single component, + // thus we encode the reference. It defaults to "main" if not specified by caller because the + // API route expects a "rev" component. + let rev = encode(self.reference.as_deref().unwrap_or("main")); + + // Clients can get a xet write token, if + // - the "rev" is a regular branch, with a HF write token; + // - the "rev" is a pr branch, with a HF write or read token; + // - it intends to create a pr and repo is enabled for discussion, with a HF write or read token. + let query = if matches!(operation, Operation::Upload) && self.reference.is_none() { + "?create_pr=1" + } else { + "" + }; + + // note that this API doesn't take a Basic auth + let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/{rev}{query}"); + + let req = self + .client + .get(url) + .with_extension(Api("xet-token")) + .header(header::USER_AGENT, &self.user_agent); + let req = self + .cred_helper + .fill_credential(req) + .await + .map_err(HubClientError::CredentialHelper)?; + let response = req.send().await.process_error("xet-write-token")?; + + let info: CasJWTInfo = response.json().await?; + + Ok(info) + } +} diff --git a/hub_client/src/lib.rs b/hub_client/src/lib.rs index f6b7015f..ec2e262b 100644 --- a/hub_client/src/lib.rs +++ b/hub_client/src/lib.rs @@ -4,6 +4,8 @@ mod errors; mod types; pub use auth::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper}; +pub use client::repo::*; +pub use client::xet_token::HubXetTokenTrait; pub use client::{HubClient, Operation}; pub use errors::{HubClientError, Result}; pub use types::{CasJWTInfo, HFRepoType, RepoInfo}; diff --git a/hub_client/src/types.rs b/hub_client/src/types.rs index 76c33c22..70526dbd 100644 --- a/hub_client/src/types.rs +++ b/hub_client/src/types.rs @@ -15,7 +15,7 @@ pub struct CasJWTInfo { } // This defines the exact three types of repos served on HF Hub. 
-#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone, Copy)] pub enum HFRepoType { Model, Dataset, @@ -44,6 +44,14 @@ impl HFRepoType { HFRepoType::Space => "space", } } + + pub fn as_str_hide_model(&self) -> &str { + match self { + HFRepoType::Model => "", + HFRepoType::Dataset => "dataset", + HFRepoType::Space => "space", + } + } } impl Display for HFRepoType { @@ -61,6 +69,10 @@ pub struct RepoInfo { } impl RepoInfo { + pub fn new(repo_type: HFRepoType, full_name: String) -> Self { + Self { repo_type, full_name } + } + pub fn try_from(repo_type: &str, repo_id: &str) -> Result { Ok(Self { repo_type: repo_type.parse()?, diff --git a/utils/Cargo.toml b/utils/Cargo.toml index d29a5e8d..8fa3dc8f 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -24,6 +24,9 @@ thiserror = { workspace = true } tokio = { workspace = true, features = ["time", "rt", "macros", "sync"] } tracing = { workspace = true } +[target.'cfg(not(target_family = "wasm"))'.dependencies] +tokio-util = { workspace = true, features = ["io"] } + [target.'cfg(not(target_family = "wasm"))'.dev-dependencies] tempfile = { workspace = true } xet_runtime = { path = "../xet_runtime" } diff --git a/utils/src/lib.rs b/utils/src/lib.rs index d4ea1df9..a7d57aca 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -25,3 +25,6 @@ pub use file_paths::{CwdGuard, EnvVarGuard, normalized_path_from_user_string}; pub mod byte_size; pub use byte_size::ByteSize; + +#[cfg(not(target_family = "wasm"))] +pub mod pipe; diff --git a/utils/src/output_bytes.rs b/utils/src/output_bytes.rs index 8242574c..c15b01d3 100644 --- a/utils/src/output_bytes.rs +++ b/utils/src/output_bytes.rs @@ -17,7 +17,7 @@ pub fn output_bytes(v: u64) -> String { for (div, s) in map { let curr = v as f64 / div as f64; if v / div > 0 { - return if v % div == 0 { + return if v.is_multiple_of(div) { format!("{} {}", v / div, s) } else { format!("{curr:.2} {s}") diff --git a/utils/src/pipe.rs b/utils/src/pipe.rs new file mode 100644 index 00000000..7a1b0b30 --- /dev/null +++ b/utils/src/pipe.rs @@ -0,0 +1,101 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use futures::Stream; +use tokio::io::AsyncWrite; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TrySendError; +use tokio_util::io::StreamReader; + +pub fn pipe(buffer_size: usize) -> (ChannelWriter, ChannelStream) { + let (sender, receiver) = mpsc::channel(buffer_size); + (ChannelWriter::new(sender), ChannelStream::new(receiver)) +} + +/// Adapter that implements AsyncRead from an mpsc Receiver +pub struct ChannelStream(mpsc::Receiver>); + +impl ChannelStream { + fn new(rx: mpsc::Receiver>) -> Self { + Self(rx) + } + + pub fn reader(self) -> ChannelReader { + ChannelReader::new(self) + } +} + +impl Stream for ChannelStream { + type Item = io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_recv(cx) + } +} + +type ChannelReader = StreamReader; + +/// Adapter that implements AsyncWrite from a mpsc Sender +pub struct ChannelWriter(mpsc::Sender>); + +impl ChannelWriter { + fn new(tx: mpsc::Sender>) -> Self { + Self(tx) + } +} + +impl AsyncWrite for ChannelWriter { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let perm = match self.0.try_reserve() { + Ok(p) => p, + Err(TrySendError::Closed(_)) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, "receiver closed"))); + }, + Err(TrySendError::Full(_)) => return Poll::Pending, + }; + + let data = 
Bytes::copy_from_slice(buf); + let len = data.len(); + perm.send(Ok(data)); + + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + // mpsc channels don't buffer in the same way, so flush is a no-op + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + // Dropping the sender will close the channel + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod tests { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + use super::*; + + #[tokio::test] + async fn test_channel_read_write() { + let (mut writer, stream) = pipe(10); + let mut reader = stream.reader(); + + // Write some data + writer.write_all(b"Hello, ").await.unwrap(); + writer.write_all(b"World!").await.unwrap(); + + // Drop writer to signal EOF + drop(writer); + + // Read the data + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + + assert_eq!(buf, b"Hello, World!"); + } +} diff --git a/xet-mount/Cargo.toml b/xet-mount/Cargo.toml new file mode 100644 index 00000000..ac3ac0ef --- /dev/null +++ b/xet-mount/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "xet-mount" +version = "0.1.0" +edition = "2024" + +[[bin]] +name = "xet-mount" +path = "src/main.rs" + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +clap = { workspace = true } +tokio = { workspace = true, features = ["signal"] } +uuid = { workspace = true } + +cas_types = { path = "../cas_types" } +cas_client = { path = "../cas_client" } +data = { path = "../data" } +hub_client = { path = "../hub_client" } +merklehash = { path = "../merklehash" } +utils = { path = "../utils" } + +nfsserve = { git = "https://github.com/huggingface/nfsserve.git" } diff --git a/xet-mount/src/fs.rs b/xet-mount/src/fs.rs new file mode 100644 index 00000000..e8aa3a15 --- /dev/null +++ b/xet-mount/src/fs.rs @@ -0,0 +1,419 @@ +use std::collections::HashMap; +use std::path::Path; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use cas_client::SequentialOutput; +use cas_types::FileRange; +use data::FileDownloader; +use hub_client::{HubClient, HubRepositoryTrait, TreeEntry}; +use merklehash::MerkleHash; +use nfsserve::nfs::{fattr3, fileid3, filename3, ftype3, nfspath3, nfsstat3, nfstime3, sattr3, specdata3}; +use nfsserve::vfs::{DirEntry, NFSFileSystem, ReadDirResult, VFSCapabilities}; +use tokio::io::AsyncReadExt; +use tokio::sync::{OnceCell, RwLock}; + +#[derive(Clone, Debug)] +enum Item { + Directory(Arc), + RegularFile(Arc), + XetFile(Arc), +} + +impl Item { + fn fattr3(&self) -> &fattr3 { + match self { + Item::Directory(dir) => &dir.fattr3, + Item::RegularFile(file) => &file.fattr3, + Item::XetFile(file) => &file.fattr3, + } + } + + fn fileid(&self) -> fileid3 { + self.fattr3().fileid + } + + fn path(&self) -> &str { + match self { + Item::Directory(dir) => dir.path.as_str(), + Item::RegularFile(file) => file.path.as_str(), + Item::XetFile(file) => file.path.as_str(), + } + } + + fn filename(&self) -> &[u8] { + Path::new(self.path()).file_name().map(|s| s.as_encoded_bytes()).unwrap_or(b"/") + } +} + +#[derive(Debug)] +struct RegularFile { + fattr3: fattr3, + path: String, +} + +impl From for Item { + fn from(value: RegularFile) -> Self { + Self::RegularFile(Arc::new(value)) + } +} + +#[derive(Debug)] +struct XetFile { + fattr3: fattr3, + path: String, + hash: MerkleHash, +} + +impl From for Item { + fn from(value: XetFile) -> Self { + Self::XetFile(Arc::new(value)) + } +} + +#[derive(Debug)] +struct Directory { + fattr3: fattr3, 
+    children: OnceCell<Vec<Item>>,
+    path: String,
+}
+
+impl From<Directory> for Item {
+    fn from(value: Directory) -> Self {
+        Self::Directory(Arc::new(value))
+    }
+}
+
+pub struct XetFS {
+    inner: Arc<XetFSInner>,
+    quiet: bool,
+}
+
+struct XetFSInner {
+    everything: RwLock<HashMap<fileid3, Item>>,
+    hub_client: Arc<HubClient>,
+    next_id: AtomicU64,
+    xet_downloader: FileDownloader,
+}
+
+const ROOT_DIR_ID: fileid3 = 0;
+
+impl XetFS {
+    pub fn new(hub_client: Arc<HubClient>, xet_downloader: FileDownloader, quiet: bool) -> Self {
+        Self {
+            inner: Arc::new(XetFSInner::new(hub_client, xet_downloader)),
+            quiet,
+        }
+    }
+}
+
+impl XetFSInner {
+    pub fn new(hub_client: Arc<HubClient>, xet_downloader: FileDownloader) -> Self {
+        let mut everything = HashMap::new();
+        let root_attr = fattr3 {
+            ftype: ftype3::NF3DIR,
+            mode: 0o755,
+            nlink: 1,
+            uid: 0,
+            gid: 0,
+            size: 0,
+            used: 0,
+            rdev: specdata3::default(),
+            fsid: 0,
+            fileid: ROOT_DIR_ID,
+            atime: nfstime3::default(),
+            mtime: nfstime3::default(),
+            ctime: nfstime3::default(),
+        };
+        let root = Directory {
+            fattr3: root_attr,
+            children: OnceCell::new(),
+            path: String::new(),
+        };
+        everything.insert(ROOT_DIR_ID, Item::Directory(Arc::new(root)));
+        Self {
+            everything: RwLock::new(everything),
+            hub_client,
+            xet_downloader,
+            next_id: AtomicU64::new(ROOT_DIR_ID + 1),
+        }
+    }
+
+    fn get_next_id(&self) -> fileid3 {
+        self.next_id.fetch_add(1, Ordering::AcqRel)
+    }
+
+    // Intended to be called from `get_or_try_init` when populating a directory's `children` field.
+    async fn get_children_for_path(&self, path: &str) -> Result<Vec<Item>, nfsstat3> {
+        let entries = self.hub_client.list_files(path).await.map_err(|_| nfsstat3::NFS3ERR_IO)?;
+
+        // Build the children list and register each entry in `everything`
+        let mut children: Vec<Item> = Vec::with_capacity(entries.len());
+        let mut everything_guard = self.everything.write().await;
+        for entry in entries {
+            let fileid = self.get_next_id();
+            let item: Item = match entry {
+                TreeEntry::File(file_entry) => {
+                    let attr = get_fattr3(fileid, file_entry.size, ftype3::NF3REG);
+                    // Decide RegularFile vs XetFile based on xet_hash presence
+                    if let Some(xet_hash) = file_entry.xet_hash {
+                        XetFile {
+                            fattr3: attr,
+                            path: file_entry.path.clone(),
+                            hash: xet_hash.into(),
+                        }
+                        .into()
+                    } else {
+                        RegularFile {
+                            fattr3: attr,
+                            path: file_entry.path.clone(),
+                        }
+                        .into()
+                    }
+                },
+                TreeEntry::Directory(dirent) => {
+                    let attr = get_fattr3(fileid, dirent.size, ftype3::NF3DIR);
+                    Directory {
+                        fattr3: attr,
+                        children: OnceCell::new(),
+                        path: dirent.path.clone(),
+                    }
+                    .into()
+                },
+            };
+            children.push(item.clone());
+            everything_guard.insert(fileid, item);
+        }
+        Ok(children)
+    }
+
+    async fn get_dir(&self, dirid: fileid3) -> Result<Arc<Directory>, nfsstat3> {
+        // Fetch the directory item
+        let maybe_item = { self.everything.read().await.get(&dirid).cloned() };
+        match maybe_item {
+            Some(Item::Directory(dir)) => Ok(dir),
+            Some(_) => Err(nfsstat3::NFS3ERR_NOTDIR),
+            None => Err(nfsstat3::NFS3ERR_NOENT),
+        }
+    }
+
+    async fn download_regular_file(
+        &self,
+        file: Arc<RegularFile>,
+        offset: u64,
+        count: u32,
+        quiet: bool,
+    ) -> Result<(Vec<u8>, bool), nfsstat3> {
+        if !quiet {
+            eprintln!("Downloading regular file: {file:?}");
+        }
+        let file_len = file.fattr3.size;
+        let past_the_end = offset + count as u64 > file_len;
+
+        let data = self
+            .hub_client
+            .download_resolved_content(&file.path, Some(FileRange::new(offset, offset + count as u64)))
+            .await
+            .map_err(|_| nfsstat3::NFS3ERR_IO)?
+            .to_vec();
+
+        Ok((data, past_the_end))
+    }
+
+    async fn download_xet_file(
+        &self,
+        file: Arc<XetFile>,
+        offset: u64,
+        count: u32,
+    ) -> Result<(Vec<u8>, bool), nfsstat3> {
+        let file_len = file.fattr3.size;
+        let past_the_end = offset + count as u64 > file_len;
+
+        let (w, s) = utils::pipe::pipe(10);
+        let sequential_output: SequentialOutput = Box::new(w);
+
+        let downloader = self.xet_downloader.clone();
+        let hash = file.hash;
+        let jh = tokio::spawn(async move {
+            downloader
+                .smudge_file_from_hash_sequential(
+                    &hash,
+                    file.path.clone().into(),
+                    sequential_output,
+                    Some(FileRange::new(offset, offset + count as u64)),
+                    None,
+                )
+                .await
+        });
+        let mut res = Vec::with_capacity(1024.min(count as usize));
+        s.reader().read_to_end(&mut res).await.map_err(|_| nfsstat3::NFS3ERR_IO)?;
+        // The writer has already hit EOF by this point, so this join should return immediately
+        jh.await.map_err(|_| nfsstat3::NFS3ERR_IO)?.map_err(|_| nfsstat3::NFS3ERR_IO)?;
+
+        Ok((res, past_the_end))
+    }
+}
+
+fn get_fattr3(fileid: fileid3, filesize: u64, ftype: ftype3) -> fattr3 {
+    fattr3 {
+        ftype,
+        mode: 0o755,
+        nlink: 1,
+        uid: 0,
+        gid: 0,
+        size: filesize,
+        used: 0,
+        rdev: specdata3::default(),
+        fsid: 0,
+        fileid,
+        atime: nfstime3::default(),
+        mtime: nfstime3::default(),
+        ctime: nfstime3::default(),
+    }
+}
+
+#[async_trait::async_trait]
+impl NFSFileSystem for XetFS {
+    fn capabilities(&self) -> VFSCapabilities {
+        VFSCapabilities::ReadOnly
+    }
+
+    fn root_dir(&self) -> fileid3 {
+        ROOT_DIR_ID
+    }
+
+    async fn lookup(&self, dirid: fileid3, filename: &filename3) -> Result<fileid3, nfsstat3> {
+        let dir = self.inner.get_dir(dirid).await?;
+        let children = dir
+            .children
+            .get_or_try_init(|| self.inner.get_children_for_path(dir.path.as_str()))
+            .await?;
+        for child in children {
+            if child.filename() == filename.0 {
+                return Ok(child.fileid());
+            }
+        }
+        Err(nfsstat3::NFS3ERR_NOENT)
+    }
+
+    async fn getattr(&self, id: fileid3) -> Result<fattr3, nfsstat3> {
+        match self.inner.everything.read().await.get(&id) {
+            Some(item) => Ok(*item.fattr3()),
+            None => Err(nfsstat3::NFS3ERR_NOENT),
+        }
+    }
+
+    async fn setattr(&self, _id: fileid3, _setattr: sattr3) -> Result<fattr3, nfsstat3> {
+        Err(nfsstat3::NFS3ERR_ROFS)
+    }
+
+    async fn read(&self, id: fileid3, offset: u64, count: u32) -> Result<(Vec<u8>, bool), nfsstat3> {
+        if !self.quiet {
+            eprintln!("read: id: {:?}, offset: {:?}, count: {:?}", id, offset, count);
+        }
+        let Some(item) = self.inner.everything.read().await.get(&id).cloned() else {
+            return Err(nfsstat3::NFS3ERR_NOENT);
+        };
+        match item {
+            Item::Directory(_) => Err(nfsstat3::NFS3ERR_ISDIR),
+            Item::RegularFile(file) => self.inner.download_regular_file(file, offset, count, self.quiet).await,
+            Item::XetFile(file) => self.inner.download_xet_file(file, offset, count).await,
+        }
+    }
+
+    async fn write(&self, _id: fileid3, _offset: u64, _data: &[u8]) -> Result<fattr3, nfsstat3> {
+        Err(nfsstat3::NFS3ERR_ROFS)
+    }
+
+    async fn create(
+        &self,
+        _dirid: fileid3,
+        _filename: &filename3,
+        _attr: sattr3,
+    ) -> Result<(fileid3, fattr3), nfsstat3> {
+        Err(nfsstat3::NFS3ERR_ROFS)
+    }
+
+    async fn create_exclusive(&self, _dirid: fileid3, _filename: &filename3) -> Result<fileid3, nfsstat3> {
+        Err(nfsstat3::NFS3ERR_ROFS)
+    }
+
+    async fn mkdir(&self, _dirid: fileid3, _dirname: &filename3) -> Result<(fileid3, fattr3), nfsstat3> {
+        Err(nfsstat3::NFS3ERR_NOTSUPP)
+    }
+
+    async fn remove(&self, _dirid: fileid3, _filename: &filename3) -> Result<(), nfsstat3> {
+        Err(nfsstat3::NFS3ERR_ROFS)
+    }
+
+    async fn rename(
+        &self,
+        _from_dirid: fileid3,
+        _from_filename: &filename3,
+        _to_dirid: fileid3,
+        _to_filename: &filename3,
+    ) -> Result<(), nfsstat3> {
+        Err(nfsstat3::NFS3ERR_ROFS)
+    }
+
+    async fn readdir(
+        &self,
+        dirid: fileid3,
+        start_after: fileid3,
+        max_entries: usize,
+    ) -> Result<ReadDirResult, nfsstat3> {
+        if !self.quiet {
+            eprintln!("readdir: dirid: {:?}, start_after: {:?}, max_entries: {:?}", dirid, start_after, max_entries);
+        }
+        // Fetch the directory item
+        let maybe_item = { self.inner.everything.read().await.get(&dirid).cloned() };
+        let dir_item = match maybe_item {
+            Some(Item::Directory(dir)) => dir,
+            Some(_) => return Err(nfsstat3::NFS3ERR_NOTDIR),
+            None => return Err(nfsstat3::NFS3ERR_NOENT),
+        };
+
+        // If children not cached, load from hub and populate
+        let children = dir_item
+            .children
+            .get_or_try_init(|| self.inner.get_children_for_path(dir_item.path.as_str()))
+            .await?;
+
+        let mut entries = vec![];
+        let mut skipped = 0;
+        // The children list is sorted by file id, so a binary search could locate the
+        // resume point; a linear scan is used for now.
+        for child in children {
+            if child.fileid() <= start_after {
+                skipped += 1;
+                continue;
+            }
+            entries.push(DirEntry {
+                fileid: child.fileid(),
+                name: child.filename().into(),
+                attr: *child.fattr3(),
+            });
+            debug_assert!(entries.len() <= max_entries);
+            if entries.len() == max_entries {
+                break;
+            }
+        }
+        let end = skipped + entries.len() == children.len();
+        let result = ReadDirResult { entries, end };
+        Ok(result)
+    }
+
+    async fn symlink(
+        &self,
+        _dirid: fileid3,
+        _linkname: &filename3,
+        _symlink: &nfspath3,
+        _attr: &sattr3,
+    ) -> Result<(fileid3, fattr3), nfsstat3> {
+        Err(nfsstat3::NFS3ERR_ROFS)
+    }
+
+    async fn readlink(&self, _id: fileid3) -> Result<nfspath3, nfsstat3> {
+        Err(nfsstat3::NFS3ERR_NOTSUPP)
+    }
+}
diff --git a/xet-mount/src/main.rs b/xet-mount/src/main.rs
new file mode 100644
index 00000000..668f2142
--- /dev/null
+++ b/xet-mount/src/main.rs
@@ -0,0 +1,187 @@
+mod fs;
+
+use std::path::PathBuf;
+use std::process::Command;
+use std::sync::Arc;
+
+use clap::Parser;
+use data::FileDownloader;
+use data::data_client::default_config;
+use data::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
+use hub_client::{BearerCredentialHelper, HFRepoType, HubClient, HubXetTokenTrait, Operation, RepoInfo};
+use nfsserve::tcp::{NFSTcp, NFSTcpListener};
+use tokio::signal::unix::SignalKind;
+use uuid::Uuid;
+
+use crate::fs::XetFS;
+
+#[derive(Parser, Debug)]
+#[command(
+    name = "xet-mount",
+    version,
+    about = "Mount a Hugging Face repository to a local directory"
+)]
+struct MountArgs {
+    #[clap(short, long)]
+    repo_id: String,
+    #[clap(long, short = 't', default_value = "model")]
+    repo_type: HFRepoType,
+    #[clap(long, visible_alias = "ref", default_value = "main")]
+    reference: Option<String>,
+    #[clap(long)]
+    token: Option<String>,
+    #[clap(long)]
+    path: PathBuf,
+    #[clap(short, long)]
+    quiet: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let args = MountArgs::parse();
+    eprintln!("{:?}", args);
+
+    let session_id = Uuid::new_v4();
+    let user_agent = format!("xet-mount/{}", env!("CARGO_PKG_VERSION"));
+
+    let Some(token) = args.token.or_else(|| std::env::var("HF_TOKEN").ok()) else {
+        return Err("HF_TOKEN is not set".into());
+    };
+
+    let cred_helper = BearerCredentialHelper::new(token, "");
+
+    let hub_client = Arc::new(HubClient::new(
+        "https://huggingface.co",
+        RepoInfo {
+            repo_type: args.repo_type,
+            full_name: args.repo_id,
+        },
+        args.reference,
+        user_agent.as_str(),
+        session_id.to_string().as_str(),
+        cred_helper,
+    )?);
+    let jwt_info = hub_client.get_xet_token(Operation::Download).await?;
+    let token_refresher = HubClientTokenRefresher {
+        operation: Operation::Download,
+        client: hub_client.clone(),
+    };
+    let config = default_config(
+        jwt_info.cas_url,
+        None,
+        Some((jwt_info.access_token, jwt_info.exp)),
+        Some(Arc::new(token_refresher)),
+    )?;
+    let xet_downloader = FileDownloader::new(Arc::new(config)).await?;
+
+    let xfs = XetFS::new(hub_client, xet_downloader, args.quiet);
+
+    let listener = NFSTcpListener::bind("127.0.0.1:11111", xfs)
+        .await
+        .expect("Failed to bind to port 11111");
+
+    let ip = listener.get_listen_ip().to_string();
+    let hostport = listener.get_listen_port();
+
+    let task_handle = tokio::spawn(async move { listener.handle_forever().await });
+
+    let mount_path = utils::normalized_path_from_user_string(args.path.as_os_str().to_str().expect("invalid path"));
+    let cleanup_dir = !perform_mount(ip, hostport, mount_path.clone()).await?;
+
+    // Set up signal handlers
+    let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
+    let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
+
+    let res = tokio::select! {
+        result = task_handle => result,
+        _ = sigterm.recv() => Ok(Ok(())),
+        _ = sigint.recv() => Ok(Ok(())),
+    };
+    unmount(mount_path.clone(), cleanup_dir).await?;
+    res??;
+    Ok(())
+}
+
+#[cfg(target_os = "macos")]
+const MOUNT_BIN: &str = "/sbin/mount";
+#[cfg(target_os = "macos")]
+const UMOUNT_BIN: &str = "/sbin/umount";
+
+#[cfg(target_os = "linux")]
+const MOUNT_BIN: &str = "/sbin/mount.nfs";
+#[cfg(target_os = "linux")]
+const UMOUNT_BIN: &str = "/usr/bin/umount";
+
+// On success, returns whether the mount directory already existed before this call
+async fn perform_mount(ip: String, hostport: u16, mount_path: PathBuf) -> Result<bool, anyhow::Error> {
+    eprintln!("Performing mount... on {mount_path:?}");
+
+    let previously_existed = std::fs::exists(&mount_path)?;
+    if !previously_existed {
+        std::fs::create_dir_all(&mount_path)?;
+    }
+    let mut cmd = Command::new(MOUNT_BIN);
+    #[cfg(target_os = "macos")]
+    {
+        cmd.args(["-t", "nfs"]);
+        cmd.args([
+            "-o",
+            &format!("rdonly,nolocks,vers=3,tcp,rsize=1048576,actimeo=120,port={hostport},mountport={hostport}"),
+        ]);
+    }
+    #[cfg(target_os = "linux")]
+    cmd.args([
+        "-o",
+        &format!("ro,vers=3,tcp,mountport={hostport},port={hostport},rsize=1048576,actimeo=120,user,noacl,nolock"),
+    ]);
+
+    cmd.arg(format!("{}:/", &ip)).arg(mount_path.clone());
+
+    if !cmd.status().is_ok_and(|e| e.success()) {
+        let mut cmd = Command::new("sudo");
+        cmd.arg(MOUNT_BIN);
+        #[cfg(target_os = "macos")]
+        {
+            cmd.args(["-t", "nfs"]);
+            cmd.args([
+                "-o",
+                &format!("rdonly,nolocks,vers=3,tcp,rsize=1048576,actimeo=120,port={hostport},mountport={hostport}"),
+            ]);
+        }
+        #[cfg(target_os = "linux")]
+        cmd.args([
+            "-o",
+            &format!("ro,vers=3,tcp,mountport={hostport},port={hostport},rsize=1048576,actimeo=120,user,noacl,nolock"),
+        ]);
+
+        cmd.arg(format!("{}:/", &ip)).arg(mount_path);
+        eprintln!("{cmd:?}");
+        cmd.status()?;
+    }
+
+    eprintln!("Mounted.");
+
+    Ok(previously_existed)
+}
+
+async fn unmount(mount_path: PathBuf, delete_path: bool) -> Result<(), anyhow::Error> {
+    eprintln!("Unmounting...");
+
+    let mut cmd = Command::new(UMOUNT_BIN);
+    cmd.arg("-f");
+    cmd.arg(mount_path.clone());
+    if !cmd.status().is_ok_and(|e| e.success()) {
+        let mut cmd = Command::new("sudo");
+        cmd.arg(UMOUNT_BIN);
+        cmd.arg("-f");
+        cmd.arg(mount_path.clone());
+        cmd.status()?;
+    }
+
+    eprintln!("Unmounted.");
+    if delete_path {
+        std::fs::remove_dir_all(&mount_path)?;
+    }
+
+    Ok(())
+}
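
The Xet read path above (`download_xet_file`) relies on the same producer/consumer contract that the `pipe` unit test earlier in this patch exercises: the downloader writes into the `AsyncWrite` half on a spawned task while the NFS handler drains the `reader()` side until EOF. A minimal standalone sketch of that contract, assuming only the in-repo `utils::pipe::pipe` API shown in the test; the payload, pipe capacity, and error handling here are purely illustrative:

```rust
use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Bounded pipe: the writer half implements AsyncWrite, the stream half
    // hands out an AsyncRead via `reader()` (see the unit test in utils::pipe).
    let (mut writer, stream) = utils::pipe::pipe(10);

    // Producer: stands in for smudge_file_from_hash_sequential writing chunks in order.
    let producer = tokio::spawn(async move {
        for chunk in [&b"hello "[..], &b"world"[..]] {
            writer.write_all(chunk).await?;
        }
        // Dropping the writer at the end of this task signals EOF to the reader.
        Ok::<_, std::io::Error>(())
    });

    // Consumer: stands in for the NFS read handler collecting the ranged bytes.
    let mut reader = stream.reader();
    let mut buf = Vec::new();
    reader.read_to_end(&mut buf).await?;
    producer.await.expect("producer task panicked")?;

    assert_eq!(buf, b"hello world");
    Ok(())
}
```

Dropping the writer is what terminates `read_to_end`, which is why `download_xet_file` only joins the spawned download task after the reader side has been fully drained.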