+import oci
 import asyncio
 import websockets
 import json
-import oci
 from throttler import throttle
 from pypdf import PdfReader
 from io import BytesIO
 from typing import Any, Dict, List
 import re
 from types import SimpleNamespace
 
-# TODO: Please update config profile name and use the compartmentId that has policies grant permissions for using Generative AI Service
-compartment_id = "<compartment_ocid>"
-CONFIG_PROFILE = "DEFAULT"
-config = oci.config.from_file('~/.oci/config', CONFIG_PROFILE)
-
-# Service endpoint
-endpoint = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
-generative_ai_inference_client = (
-    oci.generative_ai_inference.GenerativeAiInferenceClient(
-        config=config,
-        service_endpoint=endpoint,
-        retry_strategy=oci.retry.NoneRetryStrategy(),
-        timeout=(10, 240),
-    )
-)
-
-@throttle(rate_limit=15, period=65.0)
-async def generate_ai_response(prompts):
-    prompt = ""
-    llm_inference_request = (
-        oci.generative_ai_inference.models.CohereLlmInferenceRequest()
-    )
-    llm_inference_request.prompt = prompts
-    llm_inference_request.max_tokens = 1000
-    llm_inference_request.temperature = 0.75
-    llm_inference_request.top_p = 0.7
-    llm_inference_request.frequency_penalty = 1.0
+with open('config.json') as f:
+    config = json.load(f)
 
-    generate_text_detail = oci.generative_ai_inference.models.GenerateTextDetails()
-    generate_text_detail.serving_mode = oci.generative_ai_inference.models.DedicatedServingMode(endpoint_id="ocid1.generativeaiendpoint.oc1.us-chicago-1.amaaaaaaeras5xiavrsefrftfupp42lnniddgjnxuwbv5jypl64i7ktan65a")
+# Load configuration parameters
+compartment_id = config['compartment_id']
+CONFIG_PROFILE = config['config_profile']
+endpoint = config['service_endpoint']
+model_type = config['model_type']
+model_id = config[f'{model_type}_model_id']
 
-    generate_text_detail.compartment_id = compartment_id
-    generate_text_detail.inference_request = llm_inference_request
-
-    if "<compartment_ocid>" in compartment_id:
-        print("ERROR:Please update your compartment id in target python file")
-        quit()
+config = oci.config.from_file('~/.oci/config', CONFIG_PROFILE)
 
-    generate_text_response = generative_ai_inference_client.generate_text(generate_text_detail)
-    # Print result
-    print("**************************Generate Texts Result**************************")
-    print(vars(generate_text_response))
+generative_ai_inference_client = oci.generative_ai_inference.GenerativeAiInferenceClient(
+    config=config,
+    service_endpoint=endpoint,
+    retry_strategy=oci.retry.NoneRetryStrategy(),
+    timeout=(10, 240)
+)
 
-    return generate_text_response
+chat_detail = oci.generative_ai_inference.models.ChatDetails()  # reused by generate_ai_response below
 
+# Define a function to generate an AI response
 @throttle(rate_limit=15, period=65.0)
-async def generate_ai_summary(summary_txt, prompt):
-    # You can also load the summary text from a file, or as a parameter in main
-    #with open('files/summarize_data.txt', 'r') as file:
-    #    text_to_summarize = file.read()
-
-    summarize_text_detail = oci.generative_ai_inference.models.SummarizeTextDetails()
-    summarize_text_detail.serving_mode = oci.generative_ai_inference.models.OnDemandServingMode(model_id="cohere.command")
-    summarize_text_detail.compartment_id = compartment_id
-    #summarize_text_detail.input = text_to_summarize
-    summarize_text_detail.input = summary_txt
-    summarize_text_detail.additional_command = prompt
-    summarize_text_detail.extractiveness = "AUTO"  # HIGH, LOW
-    summarize_text_detail.format = "AUTO"  # brackets, paragraph
-    summarize_text_detail.length = "LONG"  # high, AUTO
-    summarize_text_detail.temperature = .25  # [0,1]
-
+async def generate_ai_response(prompts):
+    # Determine the request type based on the model type
+    if model_type == 'cohere':
+        chat_request = oci.generative_ai_inference.models.CohereChatRequest()
+        chat_request.max_tokens = 2000
+        chat_request.temperature = 0.25
+        chat_request.frequency_penalty = 0
+        chat_request.top_p = 0.75
+        chat_request.top_k = 0
+    elif model_type == 'llama':
+        chat_request = oci.generative_ai_inference.models.GenericChatRequest()
+        chat_request.api_format = oci.generative_ai_inference.models.BaseChatRequest.API_FORMAT_GENERIC
+        chat_request.max_tokens = 2000
+        chat_request.temperature = 1
+        chat_request.frequency_penalty = 0
+        chat_request.presence_penalty = 0
+        chat_request.top_p = 0.75
+        chat_request.top_k = -1
+    else:
+        raise ValueError("Unsupported model type")
+
+    # Process the prompts
+    if isinstance(prompts, str):
+        if model_type == 'cohere':
+            chat_request.message = prompts
+        else:
+            content = oci.generative_ai_inference.models.TextContent()
+            content.text = prompts
+            message = oci.generative_ai_inference.models.Message()
+            message.role = "USER"
+            message.content = [content]
+            chat_request.messages = [message]
+    elif isinstance(prompts, list):
+        chat_request.messages = prompts
+    else:
+        raise ValueError("Invalid input type for generate_ai_response")
+
+    # Set up the chat detail object
+    chat_detail.chat_request = chat_request
+    on_demand_mode = oci.generative_ai_inference.models.OnDemandServingMode(model_id=model_id)
+    chat_detail.serving_mode = on_demand_mode
+    chat_detail.compartment_id = compartment_id
+
+    # Send the request and get the response
+    chat_response = generative_ai_inference_client.chat(chat_detail)
+
+    # Validate the compartment ID
     if "<compartment_ocid>" in compartment_id:
-        print("ERROR:Please update your compartment id in target python file")
+        print("ERROR: Please update your compartment id in target python file")
         quit()
 
-    summarize_text_response = generative_ai_inference_client.summarize_text(summarize_text_detail)
-
-    # Print result
-    #print("**************************Summarize Texts Result**************************")
-    #print(summarize_text_response.data)
+    # Print the chat result
+    print("**************************Chat Result**************************")
+    print(vars(chat_response))
 
-    return summarize_text_response.data
+    return chat_response
 
 async def parse_pdf(file: BytesIO) -> List[str]:
     pdf = PdfReader(file)
     output = []
     for page in pdf.pages:
         text = page.extract_text()
-        # Merge hyphenated words
         text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
-        # Fix newlines in the middle of sentences
         text = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", text.strip())
-        # Remove multiple newlines
         text = re.sub(r"\n\s*\n", "\n\n", text)
         output.append(text)
     return output
@@ -102,34 +105,73 @@ async def handle_websocket(websocket, path):
         while True:
             data = await websocket.recv()
             if isinstance(data, str):
-                # if we are dealing with text, make it JSON
-                objData = json.loads(data, object_hook=lambda d: SimpleNamespace(**d))
+                objData = json.loads(data, object_hook=lambda d: SimpleNamespace(**d))
                 if objData.msgType == "question":
                     prompt = objData.data
-                if objData.msgType == "question":
-                    response = await generate_ai_response(prompt)
-                    answer = response.data.inference_response.generated_texts[0].text
-                    buidJSON = {"msgType":"answer","data":answer}
-                    await websocket.send(json.dumps(buidJSON))
-            # if it's not text, we have a binary and we will treat it as a PDF
-            if not isinstance(data, str):
-                # split the ArrayBuffer into metadata and the actual PDF file
+                    response = await generate_ai_response(prompt)
+
+                    if model_type == 'llama':
+                        answer = response.data.chat_response.choices[0].message.content[0].text
+                    elif model_type == 'cohere':
+                        answer = response.data.chat_response.text
+                    else:
+                        answer = ""
+
+                    buidJSON = {"msgType": "answer", "data": answer}
+                    await websocket.send(json.dumps(buidJSON))
+                elif objData.msgType == "summary":
+                    pdfFileObj = BytesIO(objData.data)
+                    output = await parse_pdf(pdfFileObj)
+                    chunk_size = 512
+                    chunks = [' '.join(output[i:i + chunk_size]) for i in range(0, len(output), chunk_size)]
+
+                    print(f"Processing {len(chunks)} chunks...")
+
+                    summaries = []
+                    for index, chunk in enumerate(chunks):
+                        print(f"Processing chunk {index + 1}/{len(chunks)}...")
+                        response = await generate_ai_response(f"Summarize: {chunk}")
+                        if model_type == 'llama':
+                            summary = response.data.chat_response.choices[0].message.content[0].text
+                        elif model_type == 'cohere':
+                            summary = response.data.chat_response.text
+                        else:
+                            summary = ""
+                        summaries.append(summary)
+
+                    final_summary = ' '.join(summaries)
+                    buidJSON = {"msgType": "summary", "data": final_summary}
+                    await websocket.send(json.dumps(buidJSON))
+            else:
                 objData = data.split(b'\r\n\r\n')
-                # decode the metadata and parse the JSON data. Creating Dict properties from the JSON
-                metadata = json.loads(objData[0].decode('utf-8'), object_hook=lambda d: SimpleNamespace(**d))
+                metadata = json.loads(objData[0].decode('utf-8'), object_hook=lambda d: SimpleNamespace(**d))
                 pdfFileObj = BytesIO(objData[1])
                 output = await parse_pdf(pdfFileObj)
-                response = await generate_ai_summary(''.join(output), metadata.msgPrompt)
-                summary = response.summary
-                buidJSON = {"msgType":"summary","data": summary}
+                chunk_size = 512
+                chunks = [' '.join(output[i:i + chunk_size]) for i in range(0, len(output), chunk_size)]
+
+                print(f"Processing {len(chunks)} chunks...")
+
+                summaries = []
+                for index, chunk in enumerate(chunks):
+                    print(f"Processing chunk {index + 1}/{len(chunks)}...")
+                    response = await generate_ai_response(f"Summarize: {chunk}")
+                    if model_type == 'llama':
+                        summary = response.data.chat_response.choices[0].message.content[0].text
+                    elif model_type == 'cohere':
+                        summary = response.data.chat_response.text
+                    else:
+                        summary = ""
+                    summaries.append(summary)
+
+                final_summary = ' '.join(summaries)
+                buidJSON = {"msgType": "summary", "data": final_summary}
                 await websocket.send(json.dumps(buidJSON))
     except websockets.exceptions.ConnectionClosedOK as e:
         print(f"Connection closed: {e}")
-
-
+
 async def start_server():
-    await websockets.serve(handle_websocket, "localhost", 1986, max_size=200000000)
-
+    async with websockets.serve(handle_websocket, "localhost", 1986, max_size=200000000):
+        await asyncio.Future()  # run forever
 
-asyncio.get_event_loop().run_until_complete(start_server())
-asyncio.get_event_loop().run_forever()
+asyncio.run(start_server())
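For a quick end-to-end check of the server, a minimal client can connect to ws://localhost:1986 (the address and max_size used in start_server) and exchange one question/answer pair. This is only an illustrative sketch, not part of the change:

# Minimal test client; assumes the server above is already running locally.
import asyncio
import json
import websockets

async def ask(question: str) -> str:
    async with websockets.connect("ws://localhost:1986", max_size=200000000) as ws:
        # The server expects a JSON text message with msgType "question"...
        await ws.send(json.dumps({"msgType": "question", "data": question}))
        # ...and replies with {"msgType": "answer", "data": <generated text>}.
        reply = json.loads(await ws.recv())
        return reply["data"]

if __name__ == "__main__":
    print(asyncio.run(ask("Summarize what OCI Generative AI does.")))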