@@ -24,9 +24,12 @@ package externalfunctions
24
24
25
25
import (
26
26
"bytes"
27
+ "encoding/base64"
27
28
"encoding/json"
28
29
"fmt"
30
+ "io"
29
31
"net/http"
32
+ "net/url"
30
33
"regexp"
31
34
"sort"
32
35
"strconv"
@@ -897,6 +900,99 @@ func AecGetContextFromRetrieverModule(
897
900
return context
898
901
}
899
902
903
+ // GetContextFromDataPlugin retrieves context from a data plugin
904
+ //
905
+ // Tags:
906
+ // - @displayName: Get Context from Data Plugin
907
+ //
908
+ // Parameters:
909
+ // - userQuery: the user query
910
+ // - apiUrl: the API URL of the data plugin
911
+ // - username: the username for authentication at the data plugin
912
+ // - password: the password for authentication at the data plugin
913
+ // - topK: the number of results to be returned
914
+ //
915
+ // Returns:
916
+ // - context: the context retrieved from the data plugin
917
+ func GetContextFromDataPlugin (userQuery string , apiUrl string , username string , password string , topK int ) (context []sharedtypes.AnsysGPTRetrieverModuleChunk ) {
918
+ // Encode the query and max_doc in base64
919
+ encodedQuery := base64 .StdEncoding .EncodeToString ([]byte (userQuery ))
920
+ encodedMaxDoc := base64 .StdEncoding .EncodeToString ([]byte (strconv .Itoa (topK )))
921
+
922
+ // Prepare form data
923
+ formData := url.Values {}
924
+ formData .Set ("q" , encodedQuery )
925
+ formData .Set ("max_doc" , encodedMaxDoc )
926
+
927
+ // Create the request
928
+ req , err := http .NewRequest ("POST" , apiUrl , bytes .NewBufferString (formData .Encode ()))
929
+ if err != nil {
930
+ panic (fmt .Errorf ("error creating request: %v" , err ))
931
+ }
932
+
933
+ // Set content type for form data
934
+ req .Header .Set ("Content-Type" , "application/x-www-form-urlencoded" )
935
+
936
+ // Set basic authentication header
937
+ auth := username + ":" + password
938
+ encodedAuth := base64 .StdEncoding .EncodeToString ([]byte (auth ))
939
+ req .Header .Set ("Authorization" , "Basic " + encodedAuth )
940
+
941
+ // Make the request
942
+ client := & http.Client {}
943
+ resp , err := client .Do (req )
944
+ if err != nil {
945
+ panic (fmt .Errorf ("error making request: %v" , err ))
946
+ }
947
+ defer resp .Body .Close ()
948
+
949
+ // Read the response body
950
+ body , err := io .ReadAll (resp .Body )
951
+ if err != nil {
952
+ panic (fmt .Errorf ("error reading response body: %v" , err ))
953
+ }
954
+
955
+ // Check status code
956
+ if resp .StatusCode != 200 {
957
+ panic (fmt .Errorf ("error response from data plugin: %v, body: %s" , resp .Status , string (body )))
958
+ }
959
+
960
+ // The response is base64 encoded
961
+ base64EncodedResponse := string (body )
962
+
963
+ // Decode the base64 response
964
+ decodedResponse , err := base64 .StdEncoding .DecodeString (base64EncodedResponse )
965
+ if err != nil {
966
+ panic (fmt .Errorf ("error decoding base64 response: %v" , err ))
967
+ }
968
+
969
+ // Unmarshal the response
970
+ response := map [string ]sharedtypes.AnsysGPTRetrieverModuleChunk {}
971
+ err = json .Unmarshal (decodedResponse , & response )
972
+ if err != nil {
973
+ panic (fmt .Errorf ("error unmarshalling response: %v" , err ))
974
+ }
975
+ logging .Log .Debugf (& logging.ContextMap {}, "Received response from retriever module: %v" , response )
976
+
977
+ // Extract the context from the response
978
+ context = make ([]sharedtypes.AnsysGPTRetrieverModuleChunk , len (response ))
979
+ for chunkNum , chunk := range response {
980
+ // Extract int from chunkNum
981
+ _ , chunkNumstring , found := strings .Cut (chunkNum , "chunk " )
982
+ if ! found {
983
+ panic (fmt .Errorf ("error extracting chunk number from '%v'" , chunkNum ))
984
+ }
985
+ chunkNumInt , err := strconv .Atoi (chunkNumstring )
986
+ if err != nil {
987
+ panic (fmt .Errorf ("error converting chunk number to int: %v" , err ))
988
+ }
989
+ // Store the chunk in the context slice
990
+ context [chunkNumInt - 1 ] = chunk
991
+ }
992
+
993
+ return context
994
+ }
995
+
900
996
// AecPerformLLMFinalRequest performs a final request to LLM
901
997
//
902
998
// Tags:
@@ -927,7 +1023,8 @@ func AecPerformLLMFinalRequest(systemTemplate string,
927
1023
tokenCountModelName string ,
928
1024
isStream bool ,
929
1025
userEmail string ,
930
- jwtToken string ) (message string , stream * chan string ) {
1026
+ jwtToken string ,
1027
+ dontSendTokenCount bool ) (message string , stream * chan string ) {
931
1028
932
1029
logging .Log .Debugf (& logging.ContextMap {}, "Performing LLM final request" )
933
1030
@@ -1018,8 +1115,14 @@ func AecPerformLLMFinalRequest(systemTemplate string,
1018
1115
}
1019
1116
totalInputTokenCount := previousInputTokenCount + inputTokenCount
1020
1117
1118
+ // check if token count should be sent
1119
+ sendTokenCount := false
1120
+ if ! dontSendTokenCount {
1121
+ sendTokenCount = true
1122
+ }
1123
+
1021
1124
// Start a goroutine to transfer the data from the response channel to the stream channel.
1022
- go transferDatafromResponseToStreamChannel (& responseChannel , & streamChannel , false , true , tokenCountEndpoint , totalInputTokenCount , previousOutputTokenCount , tokenCountModelName , jwtToken , userEmail , true , contextString )
1125
+ go transferDatafromResponseToStreamChannel (& responseChannel , & streamChannel , false , sendTokenCount , tokenCountEndpoint , totalInputTokenCount , previousOutputTokenCount , tokenCountModelName , jwtToken , userEmail , true , contextString )
1023
1126
1024
1127
return "" , & streamChannel
1025
1128
}
0 commit comments