Skip to content

Commit e2de963

Browse files
authored
add config.gateway.unstable_disable_feedback_target_validation (tensorzero#2944)
* added a config.gateway.unstable_disable_feedback_target_validation flag, implementation, and tests * slowed down polling for feedback get_function name * changed margin for jitter, added comment * fixed bindings * fixed type in Feedback table story
1 parent d49eaad commit e2de963

File tree

7 files changed

+162
-13
lines changed

7 files changed

+162
-13
lines changed

clients/rust/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,7 @@ impl Client {
410410
.await
411411
.map_err(err_to_http)
412412
})
413-
.await?
414-
.0)
413+
.await?)
415414
}
416415
}
417416
}

internal/tensorzero-node/lib/bindings/GatewayConfig.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ export type GatewayConfig = {
1111
export: ExportConfig;
1212
base_path: string | null;
1313
unstable_error_json: boolean;
14+
unstable_disable_feedback_target_validation: boolean;
1415
};

tensorzero-core/src/config_parser/gateway.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ pub struct UninitializedGatewayConfig {
2525
// If set, all of the HTTP endpoints will have this path prepended.
2626
// E.g. a base path of `/custom/prefix` will cause the inference endpoint to become `/custom/prefix/inference`.
2727
pub base_path: Option<String>,
28+
// If set to `true`, disables validation on feedback queries (read from ClickHouse to check that the target is valid)
29+
#[serde(default)]
30+
pub unstable_disable_feedback_target_validation: bool,
2831
/// If enabled, adds an 'error_json' field alongside the human-readable 'error' field
2932
/// in HTTP error responses. This contains a JSON-serialized version of the error.
3033
/// While 'error_json' will always be valid JSON when present, the exact contents is unstable,
@@ -63,6 +66,8 @@ impl UninitializedGatewayConfig {
6366
export: self.export,
6467
base_path: self.base_path,
6568
unstable_error_json: self.unstable_error_json,
69+
unstable_disable_feedback_target_validation: self
70+
.unstable_disable_feedback_target_validation,
6671
})
6772
}
6873
}
@@ -80,6 +85,7 @@ pub struct GatewayConfig {
8085
// E.g. a base path of `/custom/prefix` will cause the inference endpoint to become `/custom/prefix/inference`.
8186
pub base_path: Option<String>,
8287
pub unstable_error_json: bool,
88+
pub unstable_disable_feedback_target_validation: bool,
8389
}
8490

8591
fn serialize_optional_socket_addr<S>(

tensorzero-core/src/endpoints/feedback.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@ use super::validate_tags;
3232
///
3333
/// This is the amount of time we want to wait after the target was supposed to have been written
3434
/// before we decide that the target was actually not written because we can't find it in the database.
35-
const FEEDBACK_COOLDOWN_PERIOD: Duration = Duration::from_secs(5);
35+
/// This should really be read at 5000ms but since there might be some jitter we want to make sure there's
36+
/// a read at ~5s
37+
const FEEDBACK_COOLDOWN_PERIOD: Duration = Duration::from_millis(6000);
3638
/// Since we can't be sure that an inference actually completed when the ID says it was
3739
/// (the ID is generated at the start of the inference), we wait a minimum amount of time
3840
/// before we decide that the target was actually not written because we can't find it in the database.
39-
const FEEDBACK_MINIMUM_WAIT_TIME: Duration = Duration::from_millis(1200);
41+
const FEEDBACK_MINIMUM_WAIT_TIME: Duration = Duration::from_millis(1000);
42+
/// We also poll in the intermediate time so that we can return as soon as we find a target entry.
43+
const FEEDBACK_TARGET_POLL_INTERVAL: Duration = Duration::from_millis(2000);
4044

4145
/// The expected payload is a JSON object with the following fields:
4246
#[derive(Debug, Default, Serialize, Deserialize)]
@@ -94,7 +98,7 @@ pub async fn feedback_handler(
9498
State(app_state): AppState,
9599
StructuredJson(params): StructuredJson<Params>,
96100
) -> Result<Json<FeedbackResponse>, Error> {
97-
feedback(app_state, params).await
101+
Ok(Json(feedback(app_state, params).await?))
98102
}
99103

100104
// Helper function to avoid requiring axum types in the client
@@ -105,7 +109,7 @@ pub async fn feedback(
105109
..
106110
}: AppStateData,
107111
params: Params,
108-
) -> Result<Json<FeedbackResponse>, Error> {
112+
) -> Result<FeedbackResponse, Error> {
109113
validate_tags(&params.tags, params.internal)?;
110114
validate_feedback_specific_tags(&params.tags)?;
111115
// Get the metric config or return an error if it doesn't exist
@@ -139,6 +143,7 @@ pub async fn feedback(
139143
feedback_metadata.level,
140144
feedback_id,
141145
dryrun,
146+
config.gateway.unstable_disable_feedback_target_validation,
142147
)
143148
.await?;
144149
}
@@ -161,6 +166,7 @@ pub async fn feedback(
161166
feedback_metadata.target_id,
162167
feedback_id,
163168
dryrun,
169+
config.gateway.unstable_disable_feedback_target_validation,
164170
)
165171
.await?;
166172
}
@@ -172,12 +178,13 @@ pub async fn feedback(
172178
feedback_metadata.target_id,
173179
feedback_id,
174180
dryrun,
181+
config.gateway.unstable_disable_feedback_target_validation,
175182
)
176183
.await?;
177184
}
178185
}
179186

180-
Ok(Json(FeedbackResponse { feedback_id }))
187+
Ok(FeedbackResponse { feedback_id })
181188
}
182189

183190
#[derive(Debug)]
@@ -248,10 +255,13 @@ async fn write_comment(
248255
level: &MetricConfigLevel,
249256
feedback_id: Uuid,
250257
dryrun: bool,
258+
disable_validation: bool,
251259
) -> Result<(), Error> {
252260
let Params { value, tags, .. } = params;
253261
// Verify that the function name exists.
254-
let _ = throttled_get_function_name(&connection_info, level, &target_id).await?;
262+
if !disable_validation {
263+
let _ = throttled_get_function_name(&connection_info, level, &target_id).await?;
264+
}
255265
let value = value.as_str().ok_or_else(|| ErrorDetails::InvalidRequest {
256266
message: "Feedback value for a comment must be a string".to_string(),
257267
})?;
@@ -318,6 +328,7 @@ async fn write_float(
318328
target_id: Uuid,
319329
feedback_id: Uuid,
320330
dryrun: bool,
331+
disable_validation: bool,
321332
) -> Result<(), Error> {
322333
let Params {
323334
metric_name,
@@ -327,8 +338,11 @@ async fn write_float(
327338
} = params;
328339
let metric_config: &crate::config_parser::MetricConfig =
329340
config.get_metric_or_err(metric_name)?;
330-
// Verify that the function name exists.
331-
let _ = throttled_get_function_name(&connection_info, &metric_config.level, &target_id).await?;
341+
if !disable_validation {
342+
// Verify that the function name exists.
343+
let _ =
344+
throttled_get_function_name(&connection_info, &metric_config.level, &target_id).await?;
345+
}
332346

333347
let value = value.as_f64().ok_or_else(|| {
334348
Error::new(ErrorDetails::InvalidRequest {
@@ -353,6 +367,7 @@ async fn write_boolean(
353367
target_id: Uuid,
354368
feedback_id: Uuid,
355369
dryrun: bool,
370+
disable_validation: bool,
356371
) -> Result<(), Error> {
357372
let Params {
358373
metric_name,
@@ -361,8 +376,11 @@ async fn write_boolean(
361376
..
362377
} = params;
363378
let metric_config = config.get_metric_or_err(metric_name)?;
364-
// Verify that the function name exists.
365-
let _ = throttled_get_function_name(&connection_info, &metric_config.level, &target_id).await?;
379+
if !disable_validation {
380+
// Verify that the function name exists.
381+
let _ =
382+
throttled_get_function_name(&connection_info, &metric_config.level, &target_id).await?;
383+
}
366384
let value = value.as_bool().ok_or_else(|| {
367385
Error::new(ErrorDetails::InvalidRequest {
368386
message: format!("Feedback value for metric `{metric_name}` must be a boolean"),
@@ -420,7 +438,7 @@ async fn throttled_get_function_name(
420438
}
421439
}
422440
}
423-
tokio::time::sleep(Duration::from_millis(500)).await;
441+
tokio::time::sleep(FEEDBACK_TARGET_POLL_INTERVAL).await;
424442
}
425443
}
426444

tensorzero-core/src/gateway_util.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ mod tests {
278278
export: Default::default(),
279279
base_path: None,
280280
unstable_error_json: false,
281+
unstable_disable_feedback_target_validation: false,
281282
};
282283

283284
let config = Box::leak(Box::new(Config {
@@ -332,6 +333,7 @@ mod tests {
332333
export: Default::default(),
333334
base_path: None,
334335
unstable_error_json: false,
336+
unstable_disable_feedback_target_validation: false,
335337
};
336338

337339
let config = Box::leak(Box::new(Config {
@@ -356,6 +358,7 @@ mod tests {
356358
export: Default::default(),
357359
base_path: None,
358360
unstable_error_json: false,
361+
unstable_disable_feedback_target_validation: false,
359362
};
360363
let config = Box::leak(Box::new(Config {
361364
gateway: gateway_config,
@@ -382,6 +385,7 @@ mod tests {
382385
export: Default::default(),
383386
base_path: None,
384387
unstable_error_json: false,
388+
unstable_disable_feedback_target_validation: false,
385389
};
386390
let config = Config {
387391
gateway: gateway_config,

tensorzero-core/tests/e2e/feedback.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ use reqwest::{Client, StatusCode};
22
use serde_json::{json, Value};
33
use tensorzero_core::{
44
clickhouse::test_helpers::{select_feedback_clickhouse, select_feedback_tags_clickhouse},
5+
config_parser::{
6+
Config, MetricConfig, MetricConfigLevel, MetricConfigOptimize, MetricConfigType,
7+
},
8+
endpoints::feedback::{feedback, Params},
9+
gateway_util::AppStateData,
510
inference::types::{ContentBlockChatOutput, JsonInferenceOutput, Role, Text, TextKind},
611
};
712
use tokio::time::{sleep, Duration};
@@ -170,6 +175,39 @@ async fn e2e_test_comment_feedback_with_payload(inference_payload: serde_json::V
170175
assert_eq!(retrieved_value, "bad job!");
171176
}
172177

178+
#[tokio::test(flavor = "multi_thread")]
179+
async fn e2e_test_comment_feedback_validation_disabled() {
180+
let mut config = Config::default();
181+
let clickhouse = get_clickhouse().await;
182+
config.gateway.unstable_disable_feedback_target_validation = true;
183+
let state = AppStateData::new_with_clickhouse_and_http_client(
184+
config.into(),
185+
clickhouse.clone(),
186+
reqwest::Client::new(),
187+
);
188+
let inference_id = Uuid::now_v7();
189+
let params = Params {
190+
inference_id: Some(inference_id),
191+
metric_name: "comment".to_string(),
192+
value: json!("foo bar"),
193+
..Default::default()
194+
};
195+
let val = feedback(state, params).await.unwrap();
196+
tokio::time::sleep(Duration::from_millis(500)).await;
197+
198+
// Check that this was correctly written to ClickHouse
199+
let query = format!(
200+
"SELECT * FROM CommentFeedback WHERE target_id='{inference_id}' FORMAT JsonEachRow"
201+
);
202+
let response = clickhouse
203+
.run_query_synchronous_no_params(query)
204+
.await
205+
.unwrap();
206+
let result: Value = serde_json::from_str(&response.response).unwrap();
207+
let clickhouse_feedback_id = Uuid::parse_str(result["id"].as_str().unwrap()).unwrap();
208+
assert_eq!(val.feedback_id, clickhouse_feedback_id);
209+
}
210+
173211
#[tokio::test]
174212
async fn e2e_test_demonstration_feedback_normal_function() {
175213
e2e_test_demonstration_feedback_with_payload(serde_json::json!({
@@ -1160,6 +1198,47 @@ async fn e2e_test_float_feedback_with_payload(inference_payload: serde_json::Val
11601198
assert_eq!(metric_name, "brevity_score");
11611199
}
11621200

1201+
#[tokio::test(flavor = "multi_thread")]
1202+
async fn e2e_test_float_feedback_validation_disabled() {
1203+
let mut config = Config::default();
1204+
let metric_config = MetricConfig {
1205+
r#type: MetricConfigType::Float,
1206+
optimize: MetricConfigOptimize::Max,
1207+
level: MetricConfigLevel::Inference,
1208+
};
1209+
config
1210+
.metrics
1211+
.insert("user_score".to_string(), metric_config);
1212+
let clickhouse = get_clickhouse().await;
1213+
config.gateway.unstable_disable_feedback_target_validation = true;
1214+
let state = AppStateData::new_with_clickhouse_and_http_client(
1215+
config.into(),
1216+
clickhouse.clone(),
1217+
reqwest::Client::new(),
1218+
);
1219+
let inference_id = Uuid::now_v7();
1220+
let params = Params {
1221+
inference_id: Some(inference_id),
1222+
metric_name: "user_score".to_string(),
1223+
value: json!(3.1),
1224+
..Default::default()
1225+
};
1226+
let val = feedback(state, params).await.unwrap();
1227+
tokio::time::sleep(Duration::from_millis(500)).await;
1228+
1229+
// Check that this was correctly written to ClickHouse
1230+
let query = format!(
1231+
"SELECT * FROM FloatMetricFeedback WHERE target_id='{inference_id}' FORMAT JsonEachRow"
1232+
);
1233+
let response = clickhouse
1234+
.run_query_synchronous_no_params(query)
1235+
.await
1236+
.unwrap();
1237+
let result: Value = serde_json::from_str(&response.response).unwrap();
1238+
let clickhouse_feedback_id = Uuid::parse_str(result["id"].as_str().unwrap()).unwrap();
1239+
assert_eq!(val.feedback_id, clickhouse_feedback_id);
1240+
}
1241+
11631242
#[tokio::test]
11641243
async fn e2e_test_boolean_feedback_normal_function() {
11651244
e2e_test_boolean_feedback_with_payload(serde_json::json!({
@@ -1353,6 +1432,47 @@ async fn e2e_test_boolean_feedback_with_payload(inference_payload: serde_json::V
13531432
assert_eq!(metric_name, "goal_achieved");
13541433
}
13551434

1435+
#[tokio::test(flavor = "multi_thread")]
1436+
async fn e2e_test_boolean_feedback_validation_disabled() {
1437+
let mut config = Config::default();
1438+
let metric_config = MetricConfig {
1439+
r#type: MetricConfigType::Boolean,
1440+
optimize: MetricConfigOptimize::Max,
1441+
level: MetricConfigLevel::Inference,
1442+
};
1443+
config
1444+
.metrics
1445+
.insert("task_success".to_string(), metric_config);
1446+
let clickhouse = get_clickhouse().await;
1447+
config.gateway.unstable_disable_feedback_target_validation = true;
1448+
let state = AppStateData::new_with_clickhouse_and_http_client(
1449+
config.into(),
1450+
clickhouse.clone(),
1451+
reqwest::Client::new(),
1452+
);
1453+
let inference_id = Uuid::now_v7();
1454+
let params = Params {
1455+
inference_id: Some(inference_id),
1456+
metric_name: "task_success".to_string(),
1457+
value: json!(true),
1458+
..Default::default()
1459+
};
1460+
let val = feedback(state, params).await.unwrap();
1461+
tokio::time::sleep(Duration::from_millis(500)).await;
1462+
1463+
// Check that this was correctly written to ClickHouse
1464+
let query = format!(
1465+
"SELECT * FROM BooleanMetricFeedback WHERE target_id='{inference_id}' FORMAT JsonEachRow"
1466+
);
1467+
let response = clickhouse
1468+
.run_query_synchronous_no_params(query)
1469+
.await
1470+
.unwrap();
1471+
let result: Value = serde_json::from_str(&response.response).unwrap();
1472+
let clickhouse_feedback_id = Uuid::parse_str(result["id"].as_str().unwrap()).unwrap();
1473+
assert_eq!(val.feedback_id, clickhouse_feedback_id);
1474+
}
1475+
13561476
#[tokio::test(flavor = "multi_thread")]
13571477
#[traced_test]
13581478
async fn test_fast_inference_then_feedback() {

ui/app/components/feedback/FeedbackTable.stories.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ const config: Config = {
3131
bind_address: "localhost:8080",
3232
base_path: "/",
3333
unstable_error_json: false,
34+
unstable_disable_feedback_target_validation: false,
3435
},
3536
object_store_info: { kind: { type: "disabled" } },
3637
provider_types: {

0 commit comments

Comments
 (0)