@@ -2,6 +2,11 @@ use reqwest::{Client, StatusCode};
22use serde_json:: { json, Value } ;
33use tensorzero_core:: {
44 clickhouse:: test_helpers:: { select_feedback_clickhouse, select_feedback_tags_clickhouse} ,
5+ config_parser:: {
6+ Config , MetricConfig , MetricConfigLevel , MetricConfigOptimize , MetricConfigType ,
7+ } ,
8+ endpoints:: feedback:: { feedback, Params } ,
9+ gateway_util:: AppStateData ,
510 inference:: types:: { ContentBlockChatOutput , JsonInferenceOutput , Role , Text , TextKind } ,
611} ;
712use tokio:: time:: { sleep, Duration } ;
@@ -170,6 +175,39 @@ async fn e2e_test_comment_feedback_with_payload(inference_payload: serde_json::V
170175 assert_eq ! ( retrieved_value, "bad job!" ) ;
171176}
172177
178+ #[ tokio:: test( flavor = "multi_thread" ) ]
179+ async fn e2e_test_comment_feedback_validation_disabled ( ) {
180+ let mut config = Config :: default ( ) ;
181+ let clickhouse = get_clickhouse ( ) . await ;
182+ config. gateway . unstable_disable_feedback_target_validation = true ;
183+ let state = AppStateData :: new_with_clickhouse_and_http_client (
184+ config. into ( ) ,
185+ clickhouse. clone ( ) ,
186+ reqwest:: Client :: new ( ) ,
187+ ) ;
188+ let inference_id = Uuid :: now_v7 ( ) ;
189+ let params = Params {
190+ inference_id : Some ( inference_id) ,
191+ metric_name : "comment" . to_string ( ) ,
192+ value : json ! ( "foo bar" ) ,
193+ ..Default :: default ( )
194+ } ;
195+ let val = feedback ( state, params) . await . unwrap ( ) ;
196+ tokio:: time:: sleep ( Duration :: from_millis ( 500 ) ) . await ;
197+
198+ // Check that this was correctly written to ClickHouse
199+ let query = format ! (
200+ "SELECT * FROM CommentFeedback WHERE target_id='{inference_id}' FORMAT JsonEachRow"
201+ ) ;
202+ let response = clickhouse
203+ . run_query_synchronous_no_params ( query)
204+ . await
205+ . unwrap ( ) ;
206+ let result: Value = serde_json:: from_str ( & response. response ) . unwrap ( ) ;
207+ let clickhouse_feedback_id = Uuid :: parse_str ( result[ "id" ] . as_str ( ) . unwrap ( ) ) . unwrap ( ) ;
208+ assert_eq ! ( val. feedback_id, clickhouse_feedback_id) ;
209+ }
210+
173211#[ tokio:: test]
174212async fn e2e_test_demonstration_feedback_normal_function ( ) {
175213 e2e_test_demonstration_feedback_with_payload ( serde_json:: json!( {
@@ -1160,6 +1198,47 @@ async fn e2e_test_float_feedback_with_payload(inference_payload: serde_json::Val
11601198 assert_eq ! ( metric_name, "brevity_score" ) ;
11611199}
11621200
1201+ #[ tokio:: test( flavor = "multi_thread" ) ]
1202+ async fn e2e_test_float_feedback_validation_disabled ( ) {
1203+ let mut config = Config :: default ( ) ;
1204+ let metric_config = MetricConfig {
1205+ r#type : MetricConfigType :: Float ,
1206+ optimize : MetricConfigOptimize :: Max ,
1207+ level : MetricConfigLevel :: Inference ,
1208+ } ;
1209+ config
1210+ . metrics
1211+ . insert ( "user_score" . to_string ( ) , metric_config) ;
1212+ let clickhouse = get_clickhouse ( ) . await ;
1213+ config. gateway . unstable_disable_feedback_target_validation = true ;
1214+ let state = AppStateData :: new_with_clickhouse_and_http_client (
1215+ config. into ( ) ,
1216+ clickhouse. clone ( ) ,
1217+ reqwest:: Client :: new ( ) ,
1218+ ) ;
1219+ let inference_id = Uuid :: now_v7 ( ) ;
1220+ let params = Params {
1221+ inference_id : Some ( inference_id) ,
1222+ metric_name : "user_score" . to_string ( ) ,
1223+ value : json ! ( 3.1 ) ,
1224+ ..Default :: default ( )
1225+ } ;
1226+ let val = feedback ( state, params) . await . unwrap ( ) ;
1227+ tokio:: time:: sleep ( Duration :: from_millis ( 500 ) ) . await ;
1228+
1229+ // Check that this was correctly written to ClickHouse
1230+ let query = format ! (
1231+ "SELECT * FROM FloatMetricFeedback WHERE target_id='{inference_id}' FORMAT JsonEachRow"
1232+ ) ;
1233+ let response = clickhouse
1234+ . run_query_synchronous_no_params ( query)
1235+ . await
1236+ . unwrap ( ) ;
1237+ let result: Value = serde_json:: from_str ( & response. response ) . unwrap ( ) ;
1238+ let clickhouse_feedback_id = Uuid :: parse_str ( result[ "id" ] . as_str ( ) . unwrap ( ) ) . unwrap ( ) ;
1239+ assert_eq ! ( val. feedback_id, clickhouse_feedback_id) ;
1240+ }
1241+
11631242#[ tokio:: test]
11641243async fn e2e_test_boolean_feedback_normal_function ( ) {
11651244 e2e_test_boolean_feedback_with_payload ( serde_json:: json!( {
@@ -1353,6 +1432,47 @@ async fn e2e_test_boolean_feedback_with_payload(inference_payload: serde_json::V
13531432 assert_eq ! ( metric_name, "goal_achieved" ) ;
13541433}
13551434
1435+ #[ tokio:: test( flavor = "multi_thread" ) ]
1436+ async fn e2e_test_boolean_feedback_validation_disabled ( ) {
1437+ let mut config = Config :: default ( ) ;
1438+ let metric_config = MetricConfig {
1439+ r#type : MetricConfigType :: Boolean ,
1440+ optimize : MetricConfigOptimize :: Max ,
1441+ level : MetricConfigLevel :: Inference ,
1442+ } ;
1443+ config
1444+ . metrics
1445+ . insert ( "task_success" . to_string ( ) , metric_config) ;
1446+ let clickhouse = get_clickhouse ( ) . await ;
1447+ config. gateway . unstable_disable_feedback_target_validation = true ;
1448+ let state = AppStateData :: new_with_clickhouse_and_http_client (
1449+ config. into ( ) ,
1450+ clickhouse. clone ( ) ,
1451+ reqwest:: Client :: new ( ) ,
1452+ ) ;
1453+ let inference_id = Uuid :: now_v7 ( ) ;
1454+ let params = Params {
1455+ inference_id : Some ( inference_id) ,
1456+ metric_name : "task_success" . to_string ( ) ,
1457+ value : json ! ( true ) ,
1458+ ..Default :: default ( )
1459+ } ;
1460+ let val = feedback ( state, params) . await . unwrap ( ) ;
1461+ tokio:: time:: sleep ( Duration :: from_millis ( 500 ) ) . await ;
1462+
1463+ // Check that this was correctly written to ClickHouse
1464+ let query = format ! (
1465+ "SELECT * FROM BooleanMetricFeedback WHERE target_id='{inference_id}' FORMAT JsonEachRow"
1466+ ) ;
1467+ let response = clickhouse
1468+ . run_query_synchronous_no_params ( query)
1469+ . await
1470+ . unwrap ( ) ;
1471+ let result: Value = serde_json:: from_str ( & response. response ) . unwrap ( ) ;
1472+ let clickhouse_feedback_id = Uuid :: parse_str ( result[ "id" ] . as_str ( ) . unwrap ( ) ) . unwrap ( ) ;
1473+ assert_eq ! ( val. feedback_id, clickhouse_feedback_id) ;
1474+ }
1475+
13561476#[ tokio:: test( flavor = "multi_thread" ) ]
13571477#[ traced_test]
13581478async fn test_fast_inference_then_feedback ( ) {
0 commit comments