|
9 | 9 | //! - Core AI capabilities traits ([`CompletionFeatures`], [`EmbeddingFeatures`]). |
10 | 10 |
|
11 | 11 | use candid::Principal; |
| 12 | +use serde::ser::{SerializeMap, SerializeSeq, Serializer}; |
12 | 13 | use serde::{Deserialize, Serialize}; |
13 | 14 | use serde_json::{Map, json}; |
14 | 15 | use std::collections::BTreeMap; |
@@ -375,13 +376,113 @@ pub struct FunctionDefinition { |
375 | 376 | pub description: String, |
376 | 377 |
|
377 | 378 | /// JSON schema defining the function's parameters. |
| 379 | + #[serde(serialize_with = "serialize_openapi_schema_ordered")] |
378 | 380 | pub parameters: Json, |
379 | 381 |
|
380 | 382 | /// Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the parameters field. Only a subset of JSON Schema is supported when strict is true. |
381 | 383 | #[serde(skip_serializing_if = "Option::is_none")] |
382 | 384 | pub strict: Option<bool>, |
383 | 385 | } |
384 | 386 |
|
| 387 | +pub fn serialize_optional_openapi_schema_ordered<S>( |
| 388 | + value: &Option<Json>, |
| 389 | + serializer: S, |
| 390 | +) -> Result<S::Ok, S::Error> |
| 391 | +where |
| 392 | + S: Serializer, |
| 393 | +{ |
| 394 | + match value { |
| 395 | + None => serializer.serialize_none(), |
| 396 | + Some(v) => serialize_openapi_schema_ordered(v, serializer), |
| 397 | + } |
| 398 | +} |
| 399 | + |
| 400 | +pub fn serialize_openapi_schema_ordered<S>(value: &Json, serializer: S) -> Result<S::Ok, S::Error> |
| 401 | +where |
| 402 | + S: Serializer, |
| 403 | +{ |
| 404 | + struct Ordered<'a>(&'a Json); |
| 405 | + |
| 406 | + impl<'a> serde::Serialize for Ordered<'a> { |
| 407 | + fn serialize<S2>(&self, serializer: S2) -> Result<S2::Ok, S2::Error> |
| 408 | + where |
| 409 | + S2: Serializer, |
| 410 | + { |
| 411 | + serialize_openapi_schema_ordered(self.0, serializer) |
| 412 | + } |
| 413 | + } |
| 414 | + |
| 415 | + match value { |
| 416 | + Json::Null => serializer.serialize_none(), |
| 417 | + Json::Bool(b) => serializer.serialize_bool(*b), |
| 418 | + Json::Number(n) => n.serialize(serializer), |
| 419 | + Json::String(s) => serializer.serialize_str(s), |
| 420 | + Json::Array(items) => { |
| 421 | + let mut seq = serializer.serialize_seq(Some(items.len()))?; |
| 422 | + for item in items { |
| 423 | + seq.serialize_element(&Ordered(item))?; |
| 424 | + } |
| 425 | + seq.end() |
| 426 | + } |
| 427 | + Json::Object(map) => { |
| 428 | + // Gemini preview models can be sensitive to schema key order. |
| 429 | + // Emit a deterministic, schema-friendly ordering recursively: |
| 430 | + // 1) common schema keys in a fixed order |
| 431 | + // 2) remaining keys in lexical order |
| 432 | + |
| 433 | + // https://ai.google.dev/api/caching#FunctionDeclaration |
| 434 | + const FIXED_ORDER: [&str; 23] = [ |
| 435 | + "name", |
| 436 | + "type", |
| 437 | + "format", |
| 438 | + "title", |
| 439 | + "description", |
| 440 | + "nullable", |
| 441 | + "enum", |
| 442 | + "maxItems", |
| 443 | + "minItems", |
| 444 | + "properties", |
| 445 | + "required", |
| 446 | + "minProperties", |
| 447 | + "maxProperties", |
| 448 | + "minLength", |
| 449 | + "maxLength", |
| 450 | + "pattern", |
| 451 | + "example", |
| 452 | + "anyOf", |
| 453 | + "propertyOrdering", |
| 454 | + "default", |
| 455 | + "items", |
| 456 | + "minimum", |
| 457 | + "maximum", |
| 458 | + ]; |
| 459 | + |
| 460 | + let mut out = serializer.serialize_map(Some(map.len()))?; |
| 461 | + |
| 462 | + for key in FIXED_ORDER { |
| 463 | + if let Some(v) = map.get(key) { |
| 464 | + out.serialize_entry(key, &Ordered(v))?; |
| 465 | + } |
| 466 | + } |
| 467 | + |
| 468 | + let mut rest_keys: Vec<&str> = map |
| 469 | + .keys() |
| 470 | + .map(|k| k.as_str()) |
| 471 | + .filter(|k| !FIXED_ORDER.contains(k)) |
| 472 | + .collect(); |
| 473 | + rest_keys.sort_unstable(); |
| 474 | + |
| 475 | + for key in rest_keys { |
| 476 | + if let Some(v) = map.get(key) { |
| 477 | + out.serialize_entry(key, &Ordered(v))?; |
| 478 | + } |
| 479 | + } |
| 480 | + |
| 481 | + out.end() |
| 482 | + } |
| 483 | + } |
| 484 | +} |
| 485 | + |
385 | 486 | impl FunctionDefinition { |
386 | 487 | /// Modifies the function name with a prefix. |
387 | 488 | pub fn name_with_prefix(mut self, prefix: &str) -> Self { |
@@ -582,15 +683,36 @@ mod tests { |
582 | 683 | .into(); |
583 | 684 | // println!("{}", documents); |
584 | 685 |
|
585 | | - assert_eq!( |
586 | | - documents.to_string(), |
587 | | - "<documents>\n{\"content\":\"Test document 1.\",\"metadata\":{\"_id\":1}}\n{\"content\":\"Test document 2.\",\"metadata\":{\"_id\":2,\"a\":\"b\",\"key\":\"value\"}}\n</documents>" |
588 | | - ); |
| 686 | + let s = documents.to_string(); |
| 687 | + let lines: Vec<&str> = s.lines().collect(); |
| 688 | + assert_eq!(lines[0], "<documents>"); |
| 689 | + assert_eq!(lines[3], "</documents>"); |
| 690 | + |
| 691 | + let doc1: Json = serde_json::from_str(lines[1]).unwrap(); |
| 692 | + assert_eq!(doc1.get("content").unwrap(), "Test document 1."); |
| 693 | + assert_eq!(doc1.get("metadata").unwrap().get("_id").unwrap(), 1); |
| 694 | + |
| 695 | + let doc2: Json = serde_json::from_str(lines[2]).unwrap(); |
| 696 | + assert_eq!(doc2.get("content").unwrap(), "Test document 2."); |
| 697 | + assert_eq!(doc2.get("metadata").unwrap().get("_id").unwrap(), 2); |
| 698 | + assert_eq!(doc2.get("metadata").unwrap().get("key").unwrap(), "value"); |
| 699 | + assert_eq!(doc2.get("metadata").unwrap().get("a").unwrap(), "b"); |
| 700 | + |
589 | 701 | let documents = documents.with_tag("my_docs".to_string()); |
590 | | - assert_eq!( |
591 | | - documents.to_string(), |
592 | | - "<my_docs>\n{\"content\":\"Test document 1.\",\"metadata\":{\"_id\":1}}\n{\"content\":\"Test document 2.\",\"metadata\":{\"_id\":2,\"a\":\"b\",\"key\":\"value\"}}\n</my_docs>" |
593 | | - ); |
| 702 | + let s = documents.to_string(); |
| 703 | + let lines: Vec<&str> = s.lines().collect(); |
| 704 | + assert_eq!(lines[0], "<my_docs>"); |
| 705 | + assert_eq!(lines[3], "</my_docs>"); |
| 706 | + |
| 707 | + let doc1: Json = serde_json::from_str(lines[1]).unwrap(); |
| 708 | + assert_eq!(doc1.get("content").unwrap(), "Test document 1."); |
| 709 | + assert_eq!(doc1.get("metadata").unwrap().get("_id").unwrap(), 1); |
| 710 | + |
| 711 | + let doc2: Json = serde_json::from_str(lines[2]).unwrap(); |
| 712 | + assert_eq!(doc2.get("content").unwrap(), "Test document 2."); |
| 713 | + assert_eq!(doc2.get("metadata").unwrap().get("_id").unwrap(), 2); |
| 714 | + assert_eq!(doc2.get("metadata").unwrap().get("key").unwrap(), "value"); |
| 715 | + assert_eq!(doc2.get("metadata").unwrap().get("a").unwrap(), "b"); |
594 | 716 | } |
595 | 717 |
|
596 | 718 | #[test] |
@@ -815,4 +937,70 @@ mod tests { |
815 | 937 | &serde_json::json!({"x": true}) |
816 | 938 | ); |
817 | 939 | } |
| 940 | + |
| 941 | + #[test] |
| 942 | + fn test_function_definition_parameters_openapi_order() { |
| 943 | + let def = FunctionDefinition { |
| 944 | + name: "trigger_paywall".into(), |
| 945 | + description: "Trigger payment".into(), |
| 946 | + parameters: serde_json::json!({ |
| 947 | + // Intentionally not in the preferred order. |
| 948 | + "properties": { |
| 949 | + "hook_text": { |
| 950 | + "description": "hook", |
| 951 | + "type": "string" |
| 952 | + }, |
| 953 | + "reason": { |
| 954 | + "description": "reason", |
| 955 | + "type": "string" |
| 956 | + } |
| 957 | + }, |
| 958 | + "description": "top", |
| 959 | + "required": ["reason", "hook_text"], |
| 960 | + "type": "object" |
| 961 | + }), |
| 962 | + strict: None, |
| 963 | + }; |
| 964 | + |
| 965 | + let s = serde_json::to_string(&def).unwrap(); |
| 966 | + let start = |
| 967 | + s.find("\"parameters\":{").expect("parameters should exist") + "\"parameters\":{".len(); |
| 968 | + let sub = &s[start..]; |
| 969 | + let i_type = sub.find("\"type\"").unwrap(); |
| 970 | + let i_props = sub.find("\"properties\"").unwrap(); |
| 971 | + let i_req = sub.find("\"required\"").unwrap(); |
| 972 | + let i_desc = sub.find("\"description\"").unwrap(); |
| 973 | + assert!(i_type < i_props); |
| 974 | + assert!(i_props > i_desc); |
| 975 | + assert!(i_props < i_req); |
| 976 | + } |
| 977 | + |
| 978 | + #[test] |
| 979 | + fn test_function_definition_nested_schema_order_is_deterministic() { |
| 980 | + let def = FunctionDefinition { |
| 981 | + name: "trigger_paywall".into(), |
| 982 | + description: "Trigger payment".into(), |
| 983 | + parameters: serde_json::json!({ |
| 984 | + "type": "object", |
| 985 | + "properties": { |
| 986 | + "hook_text": { |
| 987 | + // Intentionally reverse order. |
| 988 | + "description": "hook", |
| 989 | + "type": "string" |
| 990 | + } |
| 991 | + }, |
| 992 | + "required": ["hook_text"], |
| 993 | + }), |
| 994 | + strict: None, |
| 995 | + }; |
| 996 | + |
| 997 | + let s = serde_json::to_string(&def).unwrap(); |
| 998 | + // Ensure nested property schema emits type before description. |
| 999 | + let needle = "\"hook_text\":{"; |
| 1000 | + let start = s.find(needle).unwrap() + needle.len(); |
| 1001 | + let sub = &s[start..]; |
| 1002 | + let i_type = sub.find("\"type\"").unwrap(); |
| 1003 | + let i_desc = sub.find("\"description\"").unwrap(); |
| 1004 | + assert!(i_type < i_desc); |
| 1005 | + } |
818 | 1006 | } |
0 commit comments