Skip to content

Commit 6411577

Browse files
authored
feat: implement Google Drive scope validation for source uploads (#77)
* feat: implement Google Drive scope validation for source uploads - Added functionality to validate Google Drive access token scopes when adding sources, ensuring the token includes the required `drive.file` or broader `drive` scope. - Introduced `ensure_drive_scope` method in the `NblmClient` to check for necessary permissions before processing Drive-related uploads. - Updated CLI and SDK documentation to reflect the new validation requirements for Drive sources. - Enhanced tests to cover scenarios for both valid and invalid Drive access tokens, ensuring robust error handling and user feedback. * chore: bump version to 0.2.0 for nblm-cli, nblm-core, nblm-python, and Python package - Updated version numbers in Cargo.toml files for nblm-cli, nblm-core, and nblm-python to 0.2.0. - Adjusted the Python package version in pyproject.toml and the corresponding lock file. - Enhanced bump-version script to automate version updates and ensure consistency across packages.
1 parent adf58ba commit 6411577

File tree

16 files changed

+449
-23
lines changed

16 files changed

+449
-23
lines changed

crates/nblm-cli/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "nblm-cli"
3-
version = "0.1.5"
3+
version = "0.2.0"
44
edition = "2021"
55
license = "MIT"
66
description = "Command-line interface for NotebookLM Enterprise API"
@@ -24,7 +24,7 @@ tokio = { version = "1.48.0", features = ["macros", "rt-multi-thread"] }
2424
async-trait = "0.1.83"
2525
tracing = "0.1.41"
2626
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "fmt", "json"] }
27-
nblm-core = { version = "0.1.5", path = "../nblm-core" }
27+
nblm-core = { version = "0.2.0", path = "../nblm-core" }
2828
humantime = "2.3.0"
2929
url = "2.5.7"
3030
mime_guess = "2.0.5"

crates/nblm-cli/src/app.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,14 @@ impl NblmApp {
6262
}
6363

6464
pub async fn run(self) -> Result<()> {
65-
let json_mode = self.cli.global.json;
66-
match self.cli.command {
67-
Command::Notebooks(cmd) => notebooks::run(cmd, &self.client, json_mode).await,
68-
Command::Sources(cmd) => sources::run(cmd, &self.client, json_mode).await,
69-
Command::Audio(cmd) => audio::run(cmd, &self.client, json_mode).await,
70-
Command::Share(cmd) => share::run(cmd, &self.client, json_mode).await,
65+
let NblmApp { cli, client } = self;
66+
67+
let json_mode = cli.global.json;
68+
match cli.command {
69+
Command::Notebooks(cmd) => notebooks::run(cmd, &client, json_mode).await,
70+
Command::Sources(cmd) => sources::run(cmd, &client, json_mode).await,
71+
Command::Audio(cmd) => audio::run(cmd, &client, json_mode).await,
72+
Command::Share(cmd) => share::run(cmd, &client, json_mode).await,
7173
Command::Doctor(cmd) => doctor::run(cmd).await,
7274
}
7375
}

crates/nblm-core/Cargo.toml

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "nblm-core"
3-
version = "0.1.5"
3+
version = "0.2.0"
44
edition = "2021"
55
license = "MIT"
66
description = "Core library for NotebookLM Enterprise API client"
@@ -17,11 +17,19 @@ doctest = false
1717
[dependencies]
1818
anyhow = "1.0.100"
1919
async-trait = "0.1.89"
20-
reqwest = { version = "0.12.24", default-features = false, features = ["json", "rustls-tls"] }
20+
reqwest = { version = "0.12.24", default-features = false, features = [
21+
"json",
22+
"rustls-tls",
23+
] }
2124
serde = { version = "1.0.228", features = ["derive"] }
2225
serde_json = "1.0.145"
2326
thiserror = "2.0.17"
24-
tokio = { version = "1.48.0", features = ["macros", "rt-multi-thread", "process", "time"] }
27+
tokio = { version = "1.48.0", features = [
28+
"macros",
29+
"rt-multi-thread",
30+
"process",
31+
"time",
32+
] }
2533
url = "2.5.7"
2634
backon = "1.6.0"
2735
tracing = "0.1.41"
@@ -30,3 +38,7 @@ bytes = "1.7.1"
3038

3139
[features]
3240
default = []
41+
42+
[dev-dependencies]
43+
wiremock = "0.6.5"
44+
serial_test = "3.2.0"

crates/nblm-core/src/auth.rs

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::env;
22

33
use async_trait::async_trait;
4+
use reqwest::Client;
5+
use serde::Deserialize;
46
use tokio::process::Command;
57

68
use crate::error::{Error, Result};
@@ -40,6 +42,81 @@ pub trait TokenProvider: Send + Sync {
4042
}
4143
}
4244

45+
const TOKENINFO_ENDPOINT: &str = "https://www.googleapis.com/oauth2/v3/tokeninfo";
46+
const DRIVE_SCOPE: &str = "https://www.googleapis.com/auth/drive";
47+
const DRIVE_FILE_SCOPE: &str = "https://www.googleapis.com/auth/drive.file";
48+
49+
#[derive(Debug, Deserialize)]
50+
struct TokenInfoResponse {
51+
scope: Option<String>,
52+
}
53+
54+
pub async fn ensure_drive_scope(provider: &dyn TokenProvider) -> Result<()> {
55+
let client = Client::new();
56+
let endpoint =
57+
std::env::var("NBLM_TOKENINFO_ENDPOINT").unwrap_or_else(|_| TOKENINFO_ENDPOINT.to_string());
58+
ensure_drive_scope_internal(provider, &client, &endpoint).await
59+
}
60+
61+
async fn ensure_drive_scope_internal(
62+
provider: &dyn TokenProvider,
63+
client: &Client,
64+
endpoint: &str,
65+
) -> Result<()> {
66+
let access_token = provider.access_token().await?;
67+
68+
let response = client
69+
.get(endpoint)
70+
.query(&[("access_token", access_token.as_str())])
71+
.send()
72+
.await
73+
.map_err(|err| {
74+
Error::TokenProvider(format!("failed to validate Google Drive token: {err}"))
75+
})?;
76+
77+
if !response.status().is_success() {
78+
let status = response.status();
79+
let body = response
80+
.text()
81+
.await
82+
.unwrap_or_else(|_| String::from("<failed to read body>"));
83+
return Err(Error::TokenProvider(format!(
84+
"failed to validate Google Drive token (status {}): {}",
85+
status.as_u16(),
86+
body.trim()
87+
)));
88+
}
89+
90+
let info: TokenInfoResponse = response
91+
.json()
92+
.await
93+
.map_err(|err| Error::TokenProvider(format!("invalid tokeninfo response: {err}")))?;
94+
95+
let scopes = info.scope.unwrap_or_default();
96+
if scope_grants_drive_access(&scopes) {
97+
Ok(())
98+
} else {
99+
Err(Error::TokenProvider(
100+
"Google Drive access token is missing the required drive.file scope. Run `gcloud auth login --enable-gdrive-access` and retry.".to_string(),
101+
))
102+
}
103+
}
104+
105+
fn scope_grants_drive_access(scopes: &str) -> bool {
106+
scopes
107+
.split_whitespace()
108+
.any(|scope| scope == DRIVE_FILE_SCOPE || scope == DRIVE_SCOPE)
109+
}
110+
111+
#[cfg(test)]
112+
pub(crate) async fn ensure_drive_scope_with_endpoint(
113+
provider: &dyn TokenProvider,
114+
client: &Client,
115+
endpoint: &str,
116+
) -> Result<()> {
117+
ensure_drive_scope_internal(provider, client, endpoint).await
118+
}
119+
43120
#[derive(Debug, Default, Clone)]
44121
pub struct GcloudTokenProvider {
45122
binary: String,
@@ -137,6 +214,8 @@ impl TokenProvider for StaticTokenProvider {
137214
#[cfg(test)]
138215
mod tests {
139216
use super::*;
217+
use wiremock::matchers::{method, path, query_param};
218+
use wiremock::{Mock, MockServer, ResponseTemplate};
140219

141220
#[tokio::test]
142221
async fn static_token_provider_returns_token() {
@@ -199,4 +278,93 @@ mod tests {
199278
let provider = StaticTokenProvider::new("token");
200279
assert_eq!(provider.kind(), ProviderKind::StaticToken);
201280
}
281+
282+
fn expect_scope_result(scopes: &str, expected: bool) {
283+
assert_eq!(scope_grants_drive_access(scopes), expected);
284+
}
285+
286+
#[test]
287+
fn scope_grants_drive_access_detects_required_scopes() {
288+
expect_scope_result(DRIVE_FILE_SCOPE, true);
289+
expect_scope_result(DRIVE_SCOPE, true);
290+
expect_scope_result(
291+
"https://www.googleapis.com/auth/spreadsheets.readonly",
292+
false,
293+
);
294+
expect_scope_result(
295+
&format!("{DRIVE_FILE_SCOPE} https://www.googleapis.com/auth/calendar"),
296+
true,
297+
);
298+
}
299+
300+
#[tokio::test]
301+
async fn ensure_drive_scope_accepts_valid_scope() {
302+
let server = MockServer::start().await;
303+
Mock::given(method("GET"))
304+
.and(path("/oauth2/v3/tokeninfo"))
305+
.and(query_param("access_token", "valid-token"))
306+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
307+
"scope": DRIVE_FILE_SCOPE
308+
})))
309+
.mount(&server)
310+
.await;
311+
312+
let provider = StaticTokenProvider::new("valid-token");
313+
let client = reqwest::Client::new();
314+
let endpoint = format!("{}/oauth2/v3/tokeninfo", server.uri());
315+
let result = ensure_drive_scope_with_endpoint(&provider, &client, &endpoint).await;
316+
assert!(result.is_ok());
317+
}
318+
319+
#[tokio::test]
320+
async fn ensure_drive_scope_rejects_missing_scope() {
321+
let server = MockServer::start().await;
322+
Mock::given(method("GET"))
323+
.and(path("/oauth2/v3/tokeninfo"))
324+
.and(query_param("access_token", "no-scope"))
325+
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
326+
"scope": "https://www.googleapis.com/auth/spreadsheets.readonly"
327+
})))
328+
.mount(&server)
329+
.await;
330+
331+
let provider = StaticTokenProvider::new("no-scope");
332+
let client = reqwest::Client::new();
333+
let endpoint = format!("{}/oauth2/v3/tokeninfo", server.uri());
334+
let err = ensure_drive_scope_with_endpoint(&provider, &client, &endpoint)
335+
.await
336+
.unwrap_err();
337+
338+
match err {
339+
Error::TokenProvider(message) => {
340+
assert!(message.contains("drive.file scope"));
341+
}
342+
_ => panic!("expected TokenProvider error"),
343+
}
344+
}
345+
346+
#[tokio::test]
347+
async fn ensure_drive_scope_converts_http_failures() {
348+
let server = MockServer::start().await;
349+
Mock::given(method("GET"))
350+
.and(path("/oauth2/v3/tokeninfo"))
351+
.and(query_param("access_token", "bad-token"))
352+
.respond_with(ResponseTemplate::new(400).set_body_string("invalid_token"))
353+
.mount(&server)
354+
.await;
355+
356+
let provider = StaticTokenProvider::new("bad-token");
357+
let client = reqwest::Client::new();
358+
let endpoint = format!("{}/oauth2/v3/tokeninfo", server.uri());
359+
let err = ensure_drive_scope_with_endpoint(&provider, &client, &endpoint)
360+
.await
361+
.unwrap_err();
362+
363+
match err {
364+
Error::TokenProvider(message) => {
365+
assert!(message.contains("status 400"));
366+
}
367+
_ => panic!("expected TokenProvider error"),
368+
}
369+
}
202370
}

0 commit comments

Comments
 (0)