Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/collection_manager/sides/operation/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ pub enum CollectionWriteOperation {
UpdateMcpDescription {
mcp_description: Option<String>,
},
UpdateReadApiKey {
#[serde(
deserialize_with = "deserialize_api_key",
serialize_with = "serialize_api_key"
)]
read_api_key: ApiKey,
},
PinRule(PinRuleOperation<DocumentId>),
Shelf(ShelfOperation<DocumentId>),
DocumentStorage(DocumentStorageWriteOperation),
Expand Down Expand Up @@ -395,6 +402,9 @@ impl WriteOperation {
_,
CollectionWriteOperation::UpdateMcpDescription { .. },
) => "update_mcp_description",
WriteOperation::Collection(_, CollectionWriteOperation::UpdateReadApiKey { .. }) => {
"update_read_api_key"
}
WriteOperation::Collection(
_,
CollectionWriteOperation::IndexWriteOperation(_, IndexWriteOperation::Index { .. }),
Expand Down
28 changes: 20 additions & 8 deletions src/collection_manager/sides/read/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
default_locale: Locale,
deleted: bool,

read_api_key: ApiKey,
read_api_key: OramaAsyncLock<ApiKey>,
write_api_key: Option<ApiKey>,
context: ReadSideContext,
offload_config: OffloadFieldConfig,
Expand Down Expand Up @@ -163,7 +163,7 @@
}

impl CollectionReader {
pub async fn empty(

Check warning on line 166 in src/collection_manager/sides/read/collection.rs

View workflow job for this annotation

GitHub Actions / build

this function has too many arguments (10/7)
data_dir: PathBuf,
collection_id: CollectionId,
description: Option<String>,
Expand Down Expand Up @@ -196,7 +196,7 @@
default_locale,
deleted: false,

read_api_key,
read_api_key: OramaAsyncLock::new("collection_read_api_key", read_api_key),
write_api_key,

context,
Expand Down Expand Up @@ -361,7 +361,7 @@
default_locale: dump.default_locale,
deleted: false,

read_api_key: dump.read_api_key,
read_api_key: OramaAsyncLock::new("collection_read_api_key", dump.read_api_key),
write_api_key: dump.write_api_key,

context,
Expand Down Expand Up @@ -546,7 +546,7 @@
description: self.description.clone(),
mcp_description: self.mcp_description.read("commit").await.clone(),
default_locale: self.default_locale,
read_api_key: self.read_api_key,
read_api_key: **self.read_api_key.read("commit").await,
write_api_key: self.write_api_key,
index_ids,
temp_index_ids,
Expand Down Expand Up @@ -672,16 +672,16 @@
.load(std::sync::atomic::Ordering::Relaxed)
}

#[inline]
#[allow(clippy::result_large_err)]
pub fn check_read_api_key(
pub async fn check_read_api_key(
&self,
api_key: &ReadApiKey,
master_api_key: Option<ApiKey>,
) -> Result<(), ReadError> {
let read_api_key = self.read_api_key.read("check_read_api_key").await;
match api_key {
ReadApiKey::ApiKey(api_key) => {
if *api_key == self.read_api_key {
if *api_key == **read_api_key {
return Ok(());
}
if let Some(write_api_key) = self.write_api_key {
Expand All @@ -697,7 +697,7 @@
}
ReadApiKey::Claims(claims) => {
// For JWT claims, verify the orak matches this collection's read API key
if claims.orak == self.read_api_key {
if claims.orak == **read_api_key {
return Ok(());
}
}
Expand Down Expand Up @@ -760,6 +760,15 @@
Ok(())
}

/// Updates the read API key for this collection.
pub async fn update_read_api_key(&self, new_key: ApiKey) -> Result<()> {
let mut read_api_key_lock = self.read_api_key.write("update_read_api_key").await;
**read_api_key_lock = new_key;
drop(read_api_key_lock);

Ok(())
}

pub async fn nlp_search(
&self,
read_side: State<Arc<ReadSide>>,
Expand Down Expand Up @@ -1107,6 +1116,9 @@
CollectionWriteOperation::UpdateMcpDescription { mcp_description } => {
self.update_mcp_description(mcp_description).await?;
}
CollectionWriteOperation::UpdateReadApiKey { read_api_key } => {
self.update_read_api_key(read_api_key).await?;
}
CollectionWriteOperation::PinRule(op) => {
println!("Applying pin rule operation: {op:?}");
let mut pin_rules_lock = self.pin_rules_reader.write("update_pin_rule").await;
Expand Down
45 changes: 34 additions & 11 deletions src/collection_manager/sides/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ impl ReadSide {
.get_collection(collection_id)
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;
collection.check_read_api_key(read_api_key, self.master_api_key)?;
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await?;

collection.stats(req).await
}
Expand All @@ -428,7 +430,9 @@ impl ReadSide {
.get_collection(collection_id)
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;
collection.check_read_api_key(read_api_key, self.master_api_key)?;
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await?;

collection.batch_get_documents(doc_id_strs).await
}
Expand All @@ -444,7 +448,9 @@ impl ReadSide {
.get_collection(collection_id)
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;
collection.check_read_api_key(read_api_key, self.master_api_key)?;
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await?;

let fields = collection.get_filterable_fields(with_keys).await?;

Expand Down Expand Up @@ -571,7 +577,9 @@ impl ReadSide {
.get_collection(collection_id)
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;
collection.check_read_api_key(read_api_key, self.master_api_key)?;
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await?;

// Extract extra claims from JWT token if present, otherwise use None for plain API key
let claims: Option<HashMap<String, Value>> = match read_api_key {
Expand Down Expand Up @@ -628,7 +636,9 @@ impl ReadSide {
.get_collection(collection_id)
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;
collection.check_read_api_key(&read_api_key, self.master_api_key)?;
collection
.check_read_api_key(&read_api_key, self.master_api_key)
.await?;

let collection_stats = self
.collection_stats(
Expand Down Expand Up @@ -665,7 +675,9 @@ impl ReadSide {
.get_collection(collection_id)
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;
collection.check_read_api_key(read_api_key, self.master_api_key)?;
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await?;

let collection_stats = self
.collection_stats(
Expand Down Expand Up @@ -697,7 +709,9 @@ impl ReadSide {
None => return Err(ReadError::NotFound(collection_id)),
Some(collection) => collection,
};
collection.check_read_api_key(read_api_key, self.master_api_key)?;
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await?;

Ok(collection)
}
Expand Down Expand Up @@ -784,7 +798,9 @@ impl ReadSide {
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;

collection.check_read_api_key(read_api_key, self.master_api_key)
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await
}

pub fn is_gpu_overloaded(&self) -> bool {
Expand Down Expand Up @@ -851,7 +867,9 @@ impl ReadSide {
.get_collection(collection_id)
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;
collection.check_read_api_key(read_api_key, self.master_api_key)?;
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await?;

let known_prompt: KnownPrompts = system_prompt_id
.as_str()
Expand All @@ -874,7 +892,9 @@ impl ReadSide {
.get_collection(collection_id)
.await
.ok_or_else(|| ReadError::NotFound(collection_id))?;
collection.check_read_api_key(read_api_key, self.master_api_key)?;
collection
.check_read_api_key(read_api_key, self.master_api_key)
.await?;

match self
.training_sets
Expand All @@ -899,7 +919,10 @@ impl ReadSide {
None => return Some(Err(ReadError::NotFound(collection_id))),
};

if let Err(e) = collection.check_read_api_key(read_api_key, self.master_api_key) {
if let Err(e) = collection
.check_read_api_key(read_api_key, self.master_api_key)
.await
{
return Some(Err(e));
}

Expand Down
8 changes: 8 additions & 0 deletions src/collection_manager/sides/write/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,14 @@ impl CollectionWriter {
Ok(())
}

/// Generates a new random read API key, updates it in-memory, and returns it.
pub fn regenerate_read_api_key(&mut self) -> ApiKey {
let new_key = ApiKey::try_new(cuid2::create_id())
.expect("cuid2 IDs are always valid API keys (under 64 chars)");
self.read_api_key = new_key;
new_key
}

pub async fn as_dto(&self) -> DescribeCollectionResponse {
let mut indexes_desc = vec![];
let mut document_count = 0_usize;
Expand Down
28 changes: 27 additions & 1 deletion src/collection_manager/sides/write/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::lock::{OramaAsyncLock, OramaAsyncLockReadGuard};
use crate::metrics::commit::COMMIT_CALCULATION_TIME;
use crate::metrics::CollectionCommitLabels;
use crate::python::embeddings::Model;
use crate::types::CollectionId;
use crate::types::{ApiKey, CollectionId};
use crate::types::{CreateCollection, DescribeCollectionResponse, LanguageDTO};
use oramacore_lib::fs::{create_if_not_exists, BufferedFile};
use oramacore_lib::nlp::locales::Locale;
Expand Down Expand Up @@ -183,6 +183,32 @@ impl CollectionsWriter {
Ok(())
}

/// Regenerates the read API key for a collection and sends the update to the read side.
pub async fn regenerate_read_api_key(
&self,
collection_id: CollectionId,
sender: OperationSender,
) -> Result<ApiKey, WriteError> {
let mut collections = self.collections.write("regenerate_read_api_key").await;
let collection = collections
.get_mut(&collection_id)
.ok_or(WriteError::CollectionNotFound(collection_id))?;

let new_key = collection.regenerate_read_api_key();

sender
.send(WriteOperation::Collection(
collection_id,
CollectionWriteOperation::UpdateReadApiKey {
read_api_key: new_key,
},
))
.await
.context("Cannot send update read API key operation")?;

Ok(new_key)
}

pub async fn list(&self) -> Vec<DescribeCollectionResponse> {
let collections = self.collections.read("list").await;

Expand Down
20 changes: 20 additions & 0 deletions src/collection_manager/sides/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,26 @@ impl WriteSide {
res
}

/// Regenerates the read API key for a collection. Requires write access.
pub async fn regenerate_read_api_key(
&self,
write_api_key: WriteApiKey,
collection_id: CollectionId,
) -> Result<ApiKey, WriteError> {
// Verify the collection exists and we have write access
let _collection = self.get_collection(collection_id, write_api_key).await?;
drop(_collection);

self.write_operation_counter.fetch_add(1, Ordering::Relaxed);
let res = self
.collections
.regenerate_read_api_key(collection_id, self.op_sender.clone())
.await;
self.write_operation_counter.fetch_sub(1, Ordering::Relaxed);

res
}

pub async fn create_index(
&self,
write_api_key: WriteApiKey,
Expand Down
1 change: 1 addition & 0 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mod omc_test;
mod openai_chat;
mod pin_rules;
mod quick_fulltext_benchmark;
mod regenerate_read_api_key;
mod replace_doc_on_insert;
mod replace_index;
mod shelves;
Expand Down
Loading
Loading