diff --git a/Cargo.toml b/Cargo.toml index 938336a91..a01ee9220 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,3 +79,8 @@ tree-sitter-yaml = "0.7.0" globset = "0.4.16" unicase = "2.8.1" +google-drive3 = "6.0.0" +hyper-util = "0.1.10" +hyper-rustls = { version = "0.27.5", features = ["ring"] } +yup-oauth2 = "12.1.0" +rustls = { version = "0.23.25", features = ["ring"] } diff --git a/python/cocoindex/sources.py b/python/cocoindex/sources.py index a443f3202..decaeca1b 100644 --- a/python/cocoindex/sources.py +++ b/python/cocoindex/sources.py @@ -16,3 +16,13 @@ class LocalFile(op.SourceSpec): # If provided, files matching these patterns will be excluded. # See https://docs.rs/globset/latest/globset/index.html#syntax for the syntax of the patterns. excluded_patterns: list[str] | None = None + + +class GoogleDrive(op.SourceSpec): + """Import data from Google Drive.""" + + _op_category = op.OpCategory.SOURCE + + service_account_credential_path: str + root_folder_id: str + binary: bool = False diff --git a/src/ops/registration.rs b/src/ops/registration.rs index 1a258fc94..5d7678b97 100644 --- a/src/ops/registration.rs +++ b/src/ops/registration.rs @@ -7,8 +7,11 @@ use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard}; fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> { sources::local_file::Factory.register(registry)?; + sources::google_drive::Factory.register(registry)?; + functions::split_recursively::Factory.register(registry)?; functions::extract_by_llm::Factory.register(registry)?; + Arc::new(storages::postgres::Factory::default()).register(registry)?; Ok(()) diff --git a/src/ops/sources/google_drive.rs b/src/ops/sources/google_drive.rs new file mode 100644 index 000000000..7417679e1 --- /dev/null +++ b/src/ops/sources/google_drive.rs @@ -0,0 +1,139 @@ +use std::sync::Arc; + +use google_drive3::{ + api::Scope, + yup_oauth2::{read_service_account_key, ServiceAccountAuthenticator}, + DriveHub, +}; +use hyper_rustls::HttpsConnector; +use hyper_util::client::legacy::connect::HttpConnector; + +use crate::ops::sdk::*; + +#[derive(Debug, Deserialize)] +pub struct Spec { + service_account_credential_path: String, + binary: bool, + root_folder_id: String, +} + +struct Executor { + drive_hub: DriveHub>, + binary: bool, + root_folder_id: String, +} + +impl Executor { + async fn new(spec: Spec) -> Result { + // let user_secret = read_authorized_user_secret(spec.service_account_credential_path).await?; + let service_account_key = + read_service_account_key(spec.service_account_credential_path).await?; + let auth = ServiceAccountAuthenticator::builder(service_account_key) + .build() + .await?; + let client = + hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + .build( + hyper_rustls::HttpsConnectorBuilder::new() + .with_provider_and_native_roots(rustls::crypto::ring::default_provider())? + .https_only() + .enable_http1() + .build(), + ); + let drive_hub = DriveHub::new(client, auth); + Ok(Self { + drive_hub, + binary: spec.binary, + root_folder_id: spec.root_folder_id, + }) + } +} + +fn escape_string(s: &str) -> String { + let mut escaped = String::with_capacity(s.len()); + for c in s.chars() { + match c { + '\'' | '\\' => escaped.push('\\'), + _ => {} + } + escaped.push(c); + } + escaped +} + +#[async_trait] +impl SourceExecutor for Executor { + async fn list_keys(&self) -> Result> { + let query = format!("'{}' in parents", escape_string(&self.root_folder_id)); + let mut next_page_token: Option = None; + let mut result = Vec::new(); + loop { + let mut list_call = self + .drive_hub + .files() + .list() + .q(&query) + .add_scope(Scope::Readonly); + if let Some(next_page_token) = &next_page_token { + list_call = list_call.page_token(next_page_token); + } + let (resp, files) = list_call.doit().await?; + if let Some(files) = files.files { + for file in files { + if let Some(name) = file.name { + result.push(KeyValue::Str(Arc::from(name))); + } + } + } + next_page_token = files.next_page_token; + if next_page_token.is_none() { + break; + } + } + Ok(result) + } + + async fn get_value(&self, key: &KeyValue) -> Result> { + unimplemented!() + } +} + +pub struct Factory; + +#[async_trait] +impl SourceFactoryBase for Factory { + type Spec = Spec; + + fn name(&self) -> &str { + "GoogleDrive" + } + + fn get_output_schema( + &self, + spec: &Spec, + _context: &FlowInstanceContext, + ) -> Result { + Ok(make_output_type(CollectionSchema::new( + CollectionKind::Table, + vec![ + FieldSchema::new("filename", make_output_type(BasicValueType::Str)), + FieldSchema::new( + "content", + make_output_type(if spec.binary { + BasicValueType::Bytes + } else { + BasicValueType::Str + }), + ), + ], + ))) + } + + async fn build_executor( + self: Arc, + spec: Spec, + _context: Arc, + ) -> Result> { + Ok(Box::new(Executor::new(spec).await?)) + } +} diff --git a/src/ops/sources/mod.rs b/src/ops/sources/mod.rs index 6a32d435b..d5b45e352 100644 --- a/src/ops/sources/mod.rs +++ b/src/ops/sources/mod.rs @@ -1 +1,2 @@ +pub mod google_drive; pub mod local_file;