@@ -11,19 +11,43 @@ const mockCreate = vitest.fn()
1111
1212vitest . mock ( "openai" , ( ) => {
1313 const mockConstructor = vitest . fn ( )
14- return {
15- __esModule : true ,
16- default : mockConstructor . mockImplementation ( ( ) => ( {
17- chat : {
18- completions : {
19- create : mockCreate . mockImplementation ( async ( options ) => {
20- if ( ! options . stream ) {
21- return {
22- id : "test-completion" ,
14+ const mockImplementation = ( ) => ( {
15+ chat : {
16+ completions : {
17+ create : mockCreate . mockImplementation ( async ( options ) => {
18+ if ( ! options . stream ) {
19+ return {
20+ id : "test-completion" ,
21+ choices : [
22+ {
23+ message : { role : "assistant" , content : "Test response" , refusal : null } ,
24+ finish_reason : "stop" ,
25+ index : 0 ,
26+ } ,
27+ ] ,
28+ usage : {
29+ prompt_tokens : 10 ,
30+ completion_tokens : 5 ,
31+ total_tokens : 15 ,
32+ } ,
33+ }
34+ }
35+
36+ return {
37+ [ Symbol . asyncIterator ] : async function * ( ) {
38+ yield {
39+ choices : [
40+ {
41+ delta : { content : "Test response" } ,
42+ index : 0 ,
43+ } ,
44+ ] ,
45+ usage : null ,
46+ }
47+ yield {
2348 choices : [
2449 {
25- message : { role : "assistant" , content : "Test response" , refusal : null } ,
26- finish_reason : "stop" ,
50+ delta : { } ,
2751 index : 0 ,
2852 } ,
2953 ] ,
@@ -33,38 +57,16 @@ vitest.mock("openai", () => {
3357 total_tokens : 15 ,
3458 } ,
3559 }
36- }
37-
38- return {
39- [ Symbol . asyncIterator ] : async function * ( ) {
40- yield {
41- choices : [
42- {
43- delta : { content : "Test response" } ,
44- index : 0 ,
45- } ,
46- ] ,
47- usage : null ,
48- }
49- yield {
50- choices : [
51- {
52- delta : { } ,
53- index : 0 ,
54- } ,
55- ] ,
56- usage : {
57- prompt_tokens : 10 ,
58- completion_tokens : 5 ,
59- total_tokens : 15 ,
60- } ,
61- }
62- } ,
63- }
64- } ) ,
65- } ,
60+ } ,
61+ }
62+ } ) ,
6663 } ,
67- } ) ) ,
64+ } ,
65+ } )
66+ return {
67+ __esModule : true ,
68+ default : mockConstructor . mockImplementation ( mockImplementation ) ,
69+ AzureOpenAI : mockConstructor . mockImplementation ( mockImplementation ) ,
6870 }
6971} )
7072
@@ -775,4 +777,223 @@ describe("OpenAiHandler", () => {
775777 )
776778 } )
777779 } )
780+
781+ describe ( "Azure AI Search" , ( ) => {
782+ const azureSearchOptions = {
783+ ...mockOptions ,
784+ openAiUseAzure : true ,
785+ azureAiSearchEnabled : true ,
786+ azureAiSearchEndpoint : "https://test-search.search.windows.net/" ,
787+ azureAiSearchIndexName : "test-index" ,
788+ azureAiSearchApiKey : "test-search-api-key" ,
789+ azureAiSearchSemanticConfiguration : "azureml-default" ,
790+ azureAiSearchQueryType : "vector_simple_hybrid" ,
791+ azureAiSearchEmbeddingEndpoint :
792+ "https://test-embedding.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-07-01-preview" ,
793+ azureAiSearchEmbeddingApiKey : "test-embedding-api-key" ,
794+ azureAiSearchTopNDocuments : 5 ,
795+ azureAiSearchStrictness : 3 ,
796+ }
797+
798+ it ( "should include data_sources when Azure AI Search is enabled" , async ( ) => {
799+ const azureSearchHandler = new OpenAiHandler ( azureSearchOptions )
800+ const systemPrompt = "You are a helpful assistant."
801+ const messages : Anthropic . Messages . MessageParam [ ] = [
802+ {
803+ role : "user" ,
804+ content : "Hello!" ,
805+ } ,
806+ ]
807+
808+ const stream = azureSearchHandler . createMessage ( systemPrompt , messages )
809+ // Consume the stream to trigger the API call
810+ for await ( const _chunk of stream ) {
811+ }
812+
813+ expect ( mockCreate ) . toHaveBeenCalled ( )
814+ const callArgs = mockCreate . mock . calls [ 0 ] [ 0 ]
815+ expect ( callArgs ) . toHaveProperty ( "data_sources" )
816+ expect ( callArgs . data_sources ) . toHaveLength ( 1 )
817+
818+ const dataSource = callArgs . data_sources [ 0 ]
819+ expect ( dataSource . type ) . toBe ( "azure_search" )
820+ expect ( dataSource . parameters ) . toMatchObject ( {
821+ endpoint : azureSearchOptions . azureAiSearchEndpoint ,
822+ index_name : azureSearchOptions . azureAiSearchIndexName ,
823+ semantic_configuration : azureSearchOptions . azureAiSearchSemanticConfiguration ,
824+ query_type : azureSearchOptions . azureAiSearchQueryType ,
825+ in_scope : true ,
826+ role_information : "You are an AI assistant that helps people find information." ,
827+ strictness : azureSearchOptions . azureAiSearchStrictness ,
828+ top_n_documents : azureSearchOptions . azureAiSearchTopNDocuments ,
829+ authentication : {
830+ type : "api_key" ,
831+ key : azureSearchOptions . azureAiSearchApiKey ,
832+ } ,
833+ embedding_dependency : {
834+ type : "endpoint" ,
835+ endpoint : azureSearchOptions . azureAiSearchEmbeddingEndpoint ,
836+ authentication : {
837+ type : "api_key" ,
838+ key : azureSearchOptions . azureAiSearchEmbeddingApiKey ,
839+ } ,
840+ } ,
841+ fields_mapping : {
842+ content_fields : [ "content" ] ,
843+ filepath_field : "filepath" ,
844+ title_field : "title" ,
845+ url_field : "url" ,
846+ content_fields_separator : "\n" ,
847+ vector_fields : [ "contentVector" ] ,
848+ } ,
849+ } )
850+ } )
851+
852+ it ( "should not include data_sources when Azure AI Search is disabled" , async ( ) => {
853+ const noSearchHandler = new OpenAiHandler ( {
854+ ...azureSearchOptions ,
855+ azureAiSearchEnabled : false ,
856+ } )
857+ const systemPrompt = "You are a helpful assistant."
858+ const messages : Anthropic . Messages . MessageParam [ ] = [
859+ {
860+ role : "user" ,
861+ content : "Hello!" ,
862+ } ,
863+ ]
864+
865+ const stream = noSearchHandler . createMessage ( systemPrompt , messages )
866+ // Consume the stream to trigger the API call
867+ for await ( const _chunk of stream ) {
868+ }
869+
870+ expect ( mockCreate ) . toHaveBeenCalled ( )
871+ const callArgs = mockCreate . mock . calls [ 0 ] [ 0 ]
872+ expect ( callArgs ) . not . toHaveProperty ( "data_sources" )
873+ } )
874+
875+ it ( "should not include data_sources when not using Azure OpenAI" , async ( ) => {
876+ const nonAzureHandler = new OpenAiHandler ( {
877+ ...azureSearchOptions ,
878+ openAiUseAzure : false ,
879+ } )
880+ const systemPrompt = "You are a helpful assistant."
881+ const messages : Anthropic . Messages . MessageParam [ ] = [
882+ {
883+ role : "user" ,
884+ content : "Hello!" ,
885+ } ,
886+ ]
887+
888+ const stream = nonAzureHandler . createMessage ( systemPrompt , messages )
889+ // Consume the stream to trigger the API call
890+ for await ( const _chunk of stream ) {
891+ }
892+
893+ expect ( mockCreate ) . toHaveBeenCalled ( )
894+ const callArgs = mockCreate . mock . calls [ 0 ] [ 0 ]
895+ expect ( callArgs ) . not . toHaveProperty ( "data_sources" )
896+ } )
897+
898+ it ( "should handle Azure AI Search without embedding configuration" , async ( ) => {
899+ const searchWithoutEmbeddingHandler = new OpenAiHandler ( {
900+ ...azureSearchOptions ,
901+ azureAiSearchEmbeddingEndpoint : undefined ,
902+ azureAiSearchEmbeddingApiKey : undefined ,
903+ } )
904+ const systemPrompt = "You are a helpful assistant."
905+ const messages : Anthropic . Messages . MessageParam [ ] = [
906+ {
907+ role : "user" ,
908+ content : "Hello!" ,
909+ } ,
910+ ]
911+
912+ const stream = searchWithoutEmbeddingHandler . createMessage ( systemPrompt , messages )
913+ // Consume the stream to trigger the API call
914+ for await ( const _chunk of stream ) {
915+ }
916+
917+ expect ( mockCreate ) . toHaveBeenCalled ( )
918+ const callArgs = mockCreate . mock . calls [ 0 ] [ 0 ]
919+ expect ( callArgs ) . toHaveProperty ( "data_sources" )
920+
921+ const dataSource = callArgs . data_sources [ 0 ]
922+ expect ( dataSource . parameters ) . not . toHaveProperty ( "embedding_dependency" )
923+ } )
924+
925+ it ( "should not include fields_mapping for non-vector query types" , async ( ) => {
926+ const simpleSearchHandler = new OpenAiHandler ( {
927+ ...azureSearchOptions ,
928+ azureAiSearchQueryType : "simple" ,
929+ } )
930+ const systemPrompt = "You are a helpful assistant."
931+ const messages : Anthropic . Messages . MessageParam [ ] = [
932+ {
933+ role : "user" ,
934+ content : "Hello!" ,
935+ } ,
936+ ]
937+
938+ const stream = simpleSearchHandler . createMessage ( systemPrompt , messages )
939+ // Consume the stream to trigger the API call
940+ for await ( const _chunk of stream ) {
941+ }
942+
943+ expect ( mockCreate ) . toHaveBeenCalled ( )
944+ const callArgs = mockCreate . mock . calls [ 0 ] [ 0 ]
945+ expect ( callArgs ) . toHaveProperty ( "data_sources" )
946+
947+ const dataSource = callArgs . data_sources [ 0 ]
948+ expect ( dataSource . parameters ) . not . toHaveProperty ( "fields_mapping" )
949+ } )
950+
951+ it ( "should include data_sources in non-streaming mode" , async ( ) => {
952+ const nonStreamingHandler = new OpenAiHandler ( {
953+ ...azureSearchOptions ,
954+ openAiStreamingEnabled : false ,
955+ } )
956+ const systemPrompt = "You are a helpful assistant."
957+ const messages : Anthropic . Messages . MessageParam [ ] = [
958+ {
959+ role : "user" ,
960+ content : "Hello!" ,
961+ } ,
962+ ]
963+
964+ const stream = nonStreamingHandler . createMessage ( systemPrompt , messages )
965+ // Consume the stream to trigger the API call
966+ for await ( const _chunk of stream ) {
967+ }
968+
969+ expect ( mockCreate ) . toHaveBeenCalled ( )
970+ const callArgs = mockCreate . mock . calls [ 0 ] [ 0 ]
971+ expect ( callArgs ) . toHaveProperty ( "data_sources" )
972+ expect ( callArgs . data_sources ) . toHaveLength ( 1 )
973+ expect ( callArgs . data_sources [ 0 ] . type ) . toBe ( "azure_search" )
974+ } )
975+
976+ it ( "should not include data_sources when endpoint or index name is missing" , async ( ) => {
977+ const incompleteHandler = new OpenAiHandler ( {
978+ ...azureSearchOptions ,
979+ azureAiSearchEndpoint : undefined ,
980+ } )
981+ const systemPrompt = "You are a helpful assistant."
982+ const messages : Anthropic . Messages . MessageParam [ ] = [
983+ {
984+ role : "user" ,
985+ content : "Hello!" ,
986+ } ,
987+ ]
988+
989+ const stream = incompleteHandler . createMessage ( systemPrompt , messages )
990+ // Consume the stream to trigger the API call
991+ for await ( const _chunk of stream ) {
992+ }
993+
994+ expect ( mockCreate ) . toHaveBeenCalled ( )
995+ const callArgs = mockCreate . mock . calls [ 0 ] [ 0 ]
996+ expect ( callArgs ) . not . toHaveProperty ( "data_sources" )
997+ } )
998+ } )
778999} )
0 commit comments