
Commit f8b2bc3

Port countInferencesForFunction and countInferencesForVariant (tensorzero#5069)
* Port countInferencesForFunction and countInferencesForVariant
* Make frontend call the count inferences API
1 parent 74a28dd commit f8b2bc3

16 files changed: +485 -100 lines changed


gateway/src/routes/internal.rs

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@ pub fn build_internal_non_otel_enabled_routes() -> Router<AppStateData> {
             "/internal/functions/{function_name}/variant_sampling_probabilities",
             get(endpoints::variant_probabilities::get_variant_sampling_probabilities_by_function_handler),
         )
+        .route(
+            "/internal/functions/{function_name}/inference-stats",
+            get(endpoints::internal::inference_stats::get_inference_stats_handler),
+        )
         .route(
             "/internal/ui-config",
             get(endpoints::ui::get_config::ui_config_handler),

internal/tensorzero-node/lib/bindings/InferenceStatsResponse.ts

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
+
+/**
+ * Response containing inference statistics
+ */
+export type InferenceStatsResponse = {
+  /**
+   * The count of inferences for the function (and optionally variant)
+   */
+  inference_count: bigint;
+};
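
For orientation, here is a minimal TypeScript sketch of how the frontend might call the new /internal/functions/{function_name}/inference-stats route and read inference_count. The gateway base URL and the fetchInferenceCount helper name are assumptions for illustration; only the route path, the optional variant_name query parameter, and the response shape come from this commit.

// Minimal sketch of a frontend helper for the new endpoint. The base URL and
// the helper name are assumptions; only the route, the optional `variant_name`
// query parameter, and the response shape come from this commit.
const GATEWAY_BASE_URL = "http://localhost:3000"; // assumed gateway address

export async function fetchInferenceCount(
  functionName: string,
  variantName?: string,
): Promise<number> {
  const url = new URL(
    `/internal/functions/${encodeURIComponent(functionName)}/inference-stats`,
    GATEWAY_BASE_URL,
  );
  if (variantName !== undefined) {
    url.searchParams.set("variant_name", variantName);
  }

  const response = await fetch(url);
  if (!response.ok) {
    // The handler rejects unknown functions and unknown variants.
    throw new Error(`inference-stats request failed: ${response.status}`);
  }

  // The generated binding types inference_count as bigint (u64 in Rust);
  // response.json() yields a plain number, which is enough for display.
  const body = (await response.json()) as { inference_count: number };
  return body.inference_count;
}

For example, fetchInferenceCount("write_haiku", "initial_prompt_gpt4o_mini") would return the count for that variant, while omitting the second argument counts all inferences for the function.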

internal/tensorzero-node/lib/bindings/index.ts

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ export * from "./InferenceOutputSource";
 export * from "./InferenceParams";
 export * from "./InferenceResponse";
 export * from "./InferenceResponseToolCall";
+export * from "./InferenceStatsResponse";
 export * from "./Input";
 export * from "./InputMessage";
 export * from "./InputMessageContent";

tensorzero-core/src/db/clickhouse/inference_stats.rs

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+//! ClickHouse queries for inference statistics.
+
+use std::collections::HashMap;
+
+use super::ClickHouseConnectionInfo;
+use super::select_queries::parse_count;
+use crate::error::Error;
+use crate::function::FunctionConfigType;
+
+/// Parameters for counting inferences for a function.
+#[derive(Debug)]
+pub struct CountInferencesParams<'a> {
+    pub function_name: &'a str,
+    pub function_type: FunctionConfigType,
+    pub variant_name: Option<&'a str>,
+}
+
+impl ClickHouseConnectionInfo {
+    /// Counts the number of inferences for a function, optionally filtered by variant.
+    pub async fn count_inferences_for_function(
+        &self,
+        params: CountInferencesParams<'_>,
+    ) -> Result<u64, Error> {
+        let (query, params) = build_count_inferences_query(&params);
+        let response = self.run_query_synchronous(query, &params).await?;
+        parse_count(&response.response)
+    }
+}
+
+/// Builds the SQL query for counting inferences.
+fn build_count_inferences_query<'a>(
+    params: &'a CountInferencesParams<'a>,
+) -> (String, HashMap<&'a str, &'a str>) {
+    let mut query_params = HashMap::new();
+    query_params.insert("function_name", params.function_name);
+
+    let table_name = match params.function_type {
+        FunctionConfigType::Chat => "ChatInference",
+        FunctionConfigType::Json => "JsonInference",
+    };
+
+    let query = match params.variant_name {
+        Some(variant_name) => {
+            query_params.insert("variant_name", variant_name);
+            format!(
+                "SELECT COUNT() AS count
+                FROM {table_name}
+                WHERE function_name = {{function_name:String}}
+                AND variant_name = {{variant_name:String}}
+                FORMAT JSONEachRow"
+            )
+        }
+        None => {
+            format!(
+                "SELECT COUNT() AS count
+                FROM {table_name}
+                WHERE function_name = {{function_name:String}}
+                FORMAT JSONEachRow"
+            )
+        }
+    };
+
+    (query, query_params)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::db::clickhouse::query_builder::test_util::{
+        assert_query_contains, assert_query_does_not_contain,
+    };
+
+    #[test]
+    fn test_build_count_inferences_query_chat_no_variant() {
+        let params = CountInferencesParams {
+            function_name: "write_haiku",
+            function_type: FunctionConfigType::Chat,
+            variant_name: None,
+        };
+        let (query, query_params) = build_count_inferences_query(&params);
+        assert_query_contains(&query, "FROM ChatInference");
+        assert_query_contains(&query, "function_name = {function_name:String}");
+        assert_query_does_not_contain(&query, "variant_name");
+        assert_eq!(query_params.len(), 1);
+        assert_eq!(query_params.get("function_name"), Some(&"write_haiku"));
+    }
+
+    #[test]
+    fn test_build_count_inferences_query_json_no_variant() {
+        let params = CountInferencesParams {
+            function_name: "extract_entities",
+            function_type: FunctionConfigType::Json,
+            variant_name: None,
+        };
+        let (query, query_params) = build_count_inferences_query(&params);
+        assert_query_contains(&query, "FROM JsonInference");
+        assert_query_contains(&query, "function_name = {function_name:String}");
+        assert_query_does_not_contain(&query, "variant_name");
+        assert_eq!(query_params.len(), 1);
+        assert_eq!(query_params.get("function_name"), Some(&"extract_entities"));
+    }
+
+    #[test]
+    fn test_build_count_inferences_query_chat_with_variant() {
+        let params = CountInferencesParams {
+            function_name: "write_haiku",
+            function_type: FunctionConfigType::Chat,
+            variant_name: Some("initial_prompt_gpt4o_mini"),
+        };
+        let (query, query_params) = build_count_inferences_query(&params);
+        assert_query_contains(&query, "FROM ChatInference");
+        assert_query_contains(&query, "function_name = {function_name:String}");
+        assert_query_contains(&query, "variant_name = {variant_name:String}");
+        assert_eq!(query_params.len(), 2);
+        assert_eq!(query_params.get("function_name"), Some(&"write_haiku"));
+        assert_eq!(
+            query_params.get("variant_name"),
+            Some(&"initial_prompt_gpt4o_mini")
+        );
+    }
+
+    #[test]
+    fn test_build_count_inferences_query_json_with_variant() {
+        let params = CountInferencesParams {
+            function_name: "extract_entities",
+            function_type: FunctionConfigType::Json,
+            variant_name: Some("gpt4o_initial_prompt"),
+        };
+        let (query, query_params) = build_count_inferences_query(&params);
+        assert_query_contains(&query, "FROM JsonInference");
+        assert_query_contains(&query, "function_name = {function_name:String}");
+        assert_query_contains(&query, "variant_name = {variant_name:String}");
+        assert_eq!(query_params.len(), 2);
+        assert_eq!(query_params.get("function_name"), Some(&"extract_entities"));
+        assert_eq!(
+            query_params.get("variant_name"),
+            Some(&"gpt4o_initial_prompt")
+        );
+    }
+}

tensorzero-core/src/db/clickhouse/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ pub mod clickhouse_client; // Public because tests will use clickhouse_client::F
 pub mod dataset_queries;
 pub mod feedback;
 pub mod inference_queries;
+pub mod inference_stats;
 pub mod migration_manager;
 pub mod query_builder;
 mod select_queries;

tensorzero-core/src/endpoints/internal/inference_stats.rs

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
+//! Inference statistics endpoint for getting inference counts.
+
+use axum::extract::{Path, Query, State};
+use axum::{Json, debug_handler};
+use serde::{Deserialize, Serialize};
+use tracing::instrument;
+
+use crate::db::clickhouse::inference_stats::CountInferencesParams;
+use crate::error::{Error, ErrorDetails};
+use crate::utils::gateway::{AppState, AppStateData};
+
+/// Query parameters for the inference stats endpoint
+#[derive(Debug, Deserialize)]
+pub struct InferenceStatsQueryParams {
+    /// Optional variant name to filter by
+    pub variant_name: Option<String>,
+}
+
+/// Response containing inference statistics
+#[derive(Debug, Serialize, Deserialize, ts_rs::TS)]
+#[ts(export)]
+pub struct InferenceStatsResponse {
+    /// The count of inferences for the function (and optionally variant)
+    pub inference_count: u64,
+}
+
+/// HTTP handler for the inference stats endpoint
+#[debug_handler(state = AppStateData)]
+#[instrument(
+    name = "get_inference_stats_handler",
+    skip_all,
+    fields(
+        function_name = %function_name,
+    )
+)]
+pub async fn get_inference_stats_handler(
+    State(app_state): AppState,
+    Path(function_name): Path<String>,
+    Query(params): Query<InferenceStatsQueryParams>,
+) -> Result<Json<InferenceStatsResponse>, Error> {
+    Ok(Json(
+        get_inference_stats(app_state, &function_name, params).await?,
+    ))
+}
+
+/// Core business logic for getting inference statistics
+async fn get_inference_stats(
+    AppStateData {
+        config,
+        clickhouse_connection_info,
+        ..
+    }: AppStateData,
+    function_name: &str,
+    params: InferenceStatsQueryParams,
+) -> Result<InferenceStatsResponse, Error> {
+    // Get the function config to determine the function type
+    let function = config.get_function(function_name)?;
+
+    // If variant_name is provided, validate that it exists
+    if let Some(ref variant_name) = params.variant_name
+        && !function.variants().contains_key(variant_name)
+    {
+        return Err(ErrorDetails::UnknownVariant {
+            name: variant_name.clone(),
+        }
+        .into());
+    }
+
+    let count_params = CountInferencesParams {
+        function_name,
+        function_type: function.config_type(),
+        variant_name: params.variant_name.as_deref(),
+    };
+
+    let inference_count = clickhouse_connection_info
+        .count_inferences_for_function(count_params)
+        .await?;
+
+    Ok(InferenceStatsResponse { inference_count })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::config::{Config, ConfigFileGlob};
+    use crate::testing::get_unit_test_gateway_handle;
+    use std::io::Write;
+    use std::sync::Arc;
+    use tempfile::NamedTempFile;
+
+    #[tokio::test]
+    async fn test_get_inference_stats_function_not_found() {
+        let config = Arc::new(Config::default());
+        let gateway_handle = get_unit_test_gateway_handle(config);
+
+        let params = InferenceStatsQueryParams { variant_name: None };
+
+        let result = get_inference_stats(
+            gateway_handle.app_state.clone(),
+            "nonexistent_function",
+            params,
+        )
+        .await;
+
+        assert!(result.is_err());
+        let err = result.unwrap_err();
+        assert!(err.to_string().contains("nonexistent_function"));
+    }
+
+    #[tokio::test]
+    async fn test_get_inference_stats_variant_not_found() {
+        // Create a config with a function but without the requested variant
+        let config_str = r#"
+            [functions.test_function]
+            type = "chat"
+
+            [functions.test_function.variants.variant_a]
+            type = "chat_completion"
+            model = "openai::gpt-4"
+        "#;
+
+        let mut temp_file = NamedTempFile::new().unwrap();
+        temp_file.write_all(config_str.as_bytes()).unwrap();
+
+        let config = Config::load_from_path_optional_verify_credentials(
+            &ConfigFileGlob::new_from_path(temp_file.path()).unwrap(),
+            false,
+        )
+        .await
+        .unwrap()
+        .into_config_without_writing_for_tests();
+
+        let gateway_handle = get_unit_test_gateway_handle(Arc::new(config));
+
+        let params = InferenceStatsQueryParams {
+            variant_name: Some("nonexistent_variant".to_string()),
+        };
+
+        let result =
+            get_inference_stats(gateway_handle.app_state.clone(), "test_function", params).await;
+
+        assert!(result.is_err());
+        let err = result.unwrap_err();
+        assert!(err.to_string().contains("nonexistent_variant"));
+    }
+}

tensorzero-core/src/endpoints/internal/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+pub mod inference_stats;

tensorzero-core/src/endpoints/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ pub mod embeddings;
 pub mod fallback;
 pub mod feedback;
 pub mod inference;
+pub mod internal;
 pub mod object_storage;
 pub mod openai_compatible;
 pub mod shared_types;
