Skip to content

Commit e8d435d

Browse files
authored
pass bucket to every storage call (#786)
* pass bucket to every storage call
* pass bucket in queue
1 parent d59be2e commit e8d435d

File tree

8 files changed

+55
-68
lines changed

8 files changed

+55
-68
lines changed

app-server/src/api/v1/datasets.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{env, sync::Arc};
1+
use std::sync::Arc;
22

33
use actix_web::{HttpResponse, get, post, web};
44
use futures_util::StreamExt;
@@ -32,7 +32,6 @@ async fn get_datapoints(
3232
let clickhouse = clickhouse.into_inner().as_ref().clone();
3333
let query = params.into_inner();
3434

35-
// Still get dataset metadata from PostgreSQL
3635
let dataset_id =
3736
db::datasets::get_dataset_id_by_name(&db.pool, &query.name, project_id).await?;
3837

@@ -52,10 +51,8 @@ async fn get_datapoints(
5251
)
5352
.await?;
5453

55-
// Get total count from ClickHouse
5654
let total_count = ch_datapoints::count_datapoints(clickhouse, project_id, dataset_id).await?;
5755

58-
// Convert CHDatapoints to Datapoints
5956
let datapoints: Vec<Datapoint> = ch_datapoints
6057
.into_iter()
6158
.map(|ch_dp| ch_dp.into())
@@ -108,7 +105,6 @@ async fn create_datapoints(
108105
})));
109106
}
110107

111-
// Get dataset metadata from PostgreSQL
112108
let dataset_id =
113109
db::datasets::get_dataset_id_by_name(&db.pool, &request.dataset_name, project_id).await?;
114110

@@ -131,13 +127,11 @@ async fn create_datapoints(
131127
})
132128
.collect();
133129

134-
// Convert to ClickHouse datapoints
135130
let ch_datapoints: Vec<ch_datapoints::CHDatapoint> = datapoints
136131
.iter()
137132
.map(|dp| ch_datapoints::CHDatapoint::from_datapoint(dp, project_id))
138133
.collect();
139134

140-
// Insert into ClickHouse
141135
ch_datapoints::insert_datapoints(clickhouse, ch_datapoints).await?;
142136

143137
Ok(HttpResponse::Created().json(serde_json::json!({
@@ -160,7 +154,6 @@ async fn get_parquet(
160154
let project_id = project_api_key.project_id;
161155
let db = db.into_inner();
162156

163-
// Get parquet paths from database
164157
let parquet_path =
165158
db::datasets::get_parquet_path(&db.pool, project_id, dataset_id, &name).await?;
166159

@@ -170,15 +163,14 @@ async fn get_parquet(
170163
})));
171164
};
172165

173-
// Get object metadata to determine file size
174-
let content_length = storage
175-
.get_size(&parquet_path, &env::var("S3_EXPORTS_BUCKET").ok())
176-
.await?;
166+
let Ok(bucket) = std::env::var("S3_EXPORTS_BUCKET") else {
167+
return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
168+
"error": "exports storage is not configured"
169+
})));
170+
};
171+
let content_length = storage.get_size(&bucket, &parquet_path).await?;
177172

178-
// Stream the file from S3
179-
let get_response = storage
180-
.get_stream(&parquet_path, &env::var("S3_EXPORTS_BUCKET").ok())
181-
.await?;
173+
let get_response = storage.get_stream(&bucket, &parquet_path).await?;
182174

183175
let filename = parquet_path.split('/').last().unwrap_or(&name);
184176

app-server/src/api/v1/payloads.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ pub async fn get_payload(
3131
let key = format!("project/{}/{}", project_id, payload_id);
3232

3333
// Get the payload stream from storage
34-
let mut stream = match storage.as_ref().get_stream(&key, &None).await {
34+
let Ok(bucket) = std::env::var("S3_TRACE_PAYLOADS_BUCKET") else {
35+
return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
36+
"error": "payloads storage is not configured"
37+
})));
38+
};
39+
let mut stream = match storage.as_ref().get_stream(&bucket, &key).await {
3540
Ok(stream) => stream,
3641
Err(e) => {
3742
log::error!("Failed to retrieve payload from storage: {:?}", e);

app-server/src/language_model/chat_message.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ impl ChatMessageContentPart {
436436
&self,
437437
project_id: &Uuid,
438438
storage: Arc<Storage>,
439+
bucket: &str,
439440
) -> Result<ChatMessageContentPart> {
440441
match self {
441442
ChatMessageContentPart::Image(image) => {
@@ -460,7 +461,7 @@ impl ChatMessageContentPart {
460461
// Leave intact in case of error
461462
return Ok(self.clone());
462463
}
463-
let url = storage.store(data, &key).await?;
464+
let url = storage.store(&bucket, &key, data).await?;
464465
Ok(ChatMessageContentPart::ImageUrl(ChatMessageImageUrl {
465466
url,
466467
detail: Some(format!("media_type:{};base64", media_type)),
@@ -484,7 +485,7 @@ impl ChatMessageContentPart {
484485
// Leave intact in case of error
485486
return Ok(self.clone());
486487
}
487-
let url = storage.store(data, &key).await?;
488+
let url = storage.store(&bucket, &key, data).await?;
488489
Ok(ChatMessageContentPart::DocumentUrl(
489490
ChatMessageDocumentUrl {
490491
media_type: document.source.media_type.clone(),
@@ -505,7 +506,7 @@ impl ChatMessageContentPart {
505506
// Leave intact in case of error
506507
return Ok(self.clone());
507508
}
508-
let url = storage.store(data, &key).await?;
509+
let url = storage.store(&bucket, &key, data).await?;
509510
Ok(ChatMessageContentPart::ImageUrl(ChatMessageImageUrl {
510511
url,
511512
detail: image_url.detail.clone(),
@@ -526,7 +527,7 @@ impl ChatMessageContentPart {
526527
// Leave intact in case of error
527528
return Ok(self.clone());
528529
}
529-
let url = storage.store(image.image.clone(), &key).await?;
530+
let url = storage.store(&bucket, &key, image.image.clone()).await?;
530531
Ok(ChatMessageContentPart::ImageUrl(ChatMessageImageUrl {
531532
url,
532533
detail: image

app-server/src/main.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,6 @@ fn main() -> anyhow::Result<()> {
428428
let s3_client = aws_sdk_s3::Client::new(&aws_sdk_config);
429429
let s3_storage = storage::s3::S3Storage::new(
430430
s3_client,
431-
env::var("S3_TRACE_PAYLOADS_BUCKET")
432-
.expect("S3_TRACE_PAYLOADS_BUCKET must be set"),
433431
mq_for_http.clone(),
434432
);
435433
Arc::new(s3_storage.into())

app-server/src/storage/mock.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,21 @@ pub struct MockStorage;
99
impl super::StorageTrait for MockStorage {
1010
type StorageBytesStream =
1111
Pin<Box<dyn futures_util::stream::Stream<Item = bytes::Bytes> + Send + 'static>>;
12-
async fn store(&self, _data: Vec<u8>, _key: &str) -> Result<String> {
12+
async fn store(&self, _bucket: &str, _key: &str, _data: Vec<u8>) -> Result<String> {
1313
Ok("mock".to_string())
1414
}
1515

16-
async fn store_direct(&self, _data: Vec<u8>, _key: &str) -> Result<String> {
16+
async fn store_direct(&self, _bucket: &str, _key: &str, _data: Vec<u8>) -> Result<String> {
1717
Ok("mock".to_string())
1818
}
1919

20-
async fn get_stream(
21-
&self,
22-
_key: &str,
23-
_bucket: &Option<String>,
24-
) -> Result<Self::StorageBytesStream> {
20+
async fn get_stream(&self, _bucket: &str, _key: &str) -> Result<Self::StorageBytesStream> {
2521
Ok(Box::pin(futures_util::stream::once(async move {
2622
bytes::Bytes::new()
2723
})))
2824
}
2925

30-
async fn get_size(&self, _key: &str, _bucket: &Option<String>) -> Result<u64> {
26+
async fn get_size(&self, _bucket: &str, _key: &str) -> Result<u64> {
3127
Ok(0)
3228
}
3329
}

app-server/src/storage/mod.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub const PAYLOADS_ROUTING_KEY: &str = "payloads_routing_key";
2222
pub struct QueuePayloadMessage {
2323
pub key: String,
2424
pub data: Vec<u8>,
25+
pub bucket: String,
2526
}
2627

2728
use mock::MockStorage;
@@ -37,14 +38,10 @@ pub enum Storage {
3738
#[enum_delegate::register]
3839
pub trait StorageTrait {
3940
type StorageBytesStream: futures_util::stream::Stream<Item = bytes::Bytes>;
40-
async fn store(&self, data: Vec<u8>, key: &str) -> Result<String>;
41-
async fn store_direct(&self, data: Vec<u8>, key: &str) -> Result<String>;
42-
async fn get_stream(
43-
&self,
44-
key: &str,
45-
bucket: &Option<String>,
46-
) -> Result<Self::StorageBytesStream>;
47-
async fn get_size(&self, key: &str, bucket: &Option<String>) -> Result<u64>;
41+
async fn store(&self, bucket: &str, key: &str, data: Vec<u8>) -> Result<String>;
42+
async fn store_direct(&self, bucket: &str, key: &str, data: Vec<u8>) -> Result<String>;
43+
async fn get_stream(&self, bucket: &str, key: &str) -> Result<Self::StorageBytesStream>;
44+
async fn get_size(&self, bucket: &str, key: &str) -> Result<u64>;
4845
}
4946

5047
pub fn create_key(project_id: &Uuid, file_extension: &Option<String>) -> String {
@@ -118,7 +115,7 @@ async fn inner_process_payloads(storage: Arc<Storage>, queue: Arc<MessageQueue>)
118115
};
119116

120117
let store_payload = || async {
121-
storage.store_direct(message.data.clone(), &message.key).await.map_err(|e| {
118+
storage.store_direct(&message.bucket, &message.key, message.data.clone()).await.map_err(|e| {
122119
log::error!("Failed attempt to store payload. Will retry according to backoff policy. Error: {:?}", e);
123120
backoff::Error::transient(e)
124121
})

app-server/src/storage/s3.rs

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,12 @@ use crate::{
1111
#[derive(Clone)]
1212
pub struct S3Storage {
1313
client: Client,
14-
bucket: String,
1514
queue: Arc<MessageQueue>,
1615
}
1716

1817
impl S3Storage {
19-
pub fn new(client: Client, bucket: String, queue: Arc<MessageQueue>) -> Self {
20-
Self {
21-
client,
22-
bucket,
23-
queue,
24-
}
18+
pub fn new(client: Client, queue: Arc<MessageQueue>) -> Self {
19+
Self { client, queue }
2520
}
2621

2722
fn get_url(&self, key: &str) -> String {
@@ -38,11 +33,12 @@ impl S3Storage {
3833
impl super::StorageTrait for S3Storage {
3934
type StorageBytesStream =
4035
Pin<Box<dyn futures_util::stream::Stream<Item = bytes::Bytes> + Send + 'static>>;
41-
async fn store(&self, data: Vec<u8>, key: &str) -> Result<String> {
36+
async fn store(&self, bucket: &str, key: &str, data: Vec<u8>) -> Result<String> {
4237
// Push to queue instead of storing directly
4338
let message = QueuePayloadMessage {
4439
key: key.to_string(),
4540
data,
41+
bucket: bucket.to_string(),
4642
};
4743

4844
self.queue
@@ -57,11 +53,11 @@ impl super::StorageTrait for S3Storage {
5753
Ok(self.get_url(key))
5854
}
5955

60-
async fn store_direct(&self, data: Vec<u8>, key: &str) -> Result<String> {
56+
async fn store_direct(&self, bucket: &str, key: &str, data: Vec<u8>) -> Result<String> {
6157
// Direct storage method used by the payload worker
6258
self.client
6359
.put_object()
64-
.bucket(&self.bucket)
60+
.bucket(bucket)
6561
.key(key)
6662
.body(data.into())
6763
.send()
@@ -70,15 +66,11 @@ impl super::StorageTrait for S3Storage {
7066
Ok(self.get_url(key))
7167
}
7268

73-
async fn get_stream(
74-
&self,
75-
key: &str,
76-
bucket: &Option<String>,
77-
) -> Result<Self::StorageBytesStream> {
69+
async fn get_stream(&self, bucket: &str, key: &str) -> Result<Self::StorageBytesStream> {
7870
let response = self
7971
.client
8072
.get_object()
81-
.bucket(bucket.as_ref().unwrap_or(&self.bucket))
73+
.bucket(bucket)
8274
.key(key)
8375
.send()
8476
.await?;
@@ -92,11 +84,11 @@ impl super::StorageTrait for S3Storage {
9284
)))
9385
}
9486

95-
async fn get_size(&self, key: &str, bucket: &Option<String>) -> Result<u64> {
87+
async fn get_size(&self, bucket: &str, key: &str) -> Result<u64> {
9688
let response = self
9789
.client
9890
.head_object()
99-
.bucket(bucket.as_ref().unwrap_or(&self.bucket))
91+
.bucket(bucket)
10092
.key(key)
10193
.send()
10294
.await?;

app-server/src/traces/spans.rs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,10 @@ impl Span {
758758
.ok()
759759
.and_then(|s: String| s.parse::<usize>().ok())
760760
.unwrap_or(DEFAULT_PAYLOAD_SIZE_THRESHOLD);
761+
let Ok(bucket) = std::env::var("S3_TRACE_PAYLOADS_BUCKET") else {
762+
log::error!("S3_TRACE_PAYLOADS_BUCKET is not set");
763+
return Err(anyhow::anyhow!("S3_TRACE_PAYLOADS_BUCKET is not set"));
764+
};
761765
if let Some(input) = self.input.clone() {
762766
let span_input = serde_json::from_value::<Vec<ChatMessage>>(input);
763767
if let Ok(span_input) = span_input {
@@ -766,14 +770,16 @@ impl Span {
766770
if let ChatMessageContent::ContentPartList(parts) = message.content {
767771
let mut new_parts = Vec::new();
768772
for part in parts {
769-
let stored_part =
770-
match part.store_media(project_id, storage.clone()).await {
771-
Ok(stored_part) => stored_part,
772-
Err(e) => {
773-
log::error!("Error storing media: {e}");
774-
part
775-
}
776-
};
773+
let stored_part = match part
774+
.store_media(project_id, storage.clone(), &bucket)
775+
.await
776+
{
777+
Ok(stored_part) => stored_part,
778+
Err(e) => {
779+
log::error!("Error storing media: {e}");
780+
part
781+
}
782+
};
777783
new_parts.push(stored_part);
778784
}
779785
message.content = ChatMessageContent::ContentPartList(new_parts);
@@ -801,7 +807,7 @@ impl Span {
801807
data.len()
802808
);
803809
} else {
804-
let url = storage.store(data, &key).await?;
810+
let url = storage.store(&bucket, &key, data).await?;
805811
self.input_url = Some(url);
806812
self.input = Some(serde_json::Value::String(preview));
807813
}
@@ -828,7 +834,7 @@ impl Span {
828834
data.len()
829835
);
830836
} else {
831-
let url = storage.store(data, &key).await?;
837+
let url = storage.store(&bucket, &key, data).await?;
832838
self.output_url = Some(url);
833839
self.output = Some(serde_json::Value::String(
834840
output_str.chars().take(100).collect(),

0 commit comments

Comments
 (0)