
Commit 45663b0

updating python
1 parent a457737 commit 45663b0

File tree: 3 files changed (+133 additions, -90 deletions)

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -75,4 +75,5 @@ build/
 .idea
 bin/
 dist/
-application-local.yaml
+application-local.yaml
+service/python/config.json
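Note: the updated server.py reads its settings from service/python/config.json, which this commit adds to .gitignore. A minimal sketch of what that file might look like, inferred only from the keys the script reads (compartment_id, config_profile, service_endpoint, model_type, and <model_type>_model_id); every value below is a placeholder, not a real OCID or model name:

{
    "compartment_id": "ocid1.compartment.oc1..<your_compartment_ocid>",
    "config_profile": "DEFAULT",
    "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    "model_type": "cohere",
    "cohere_model_id": "<cohere_model_ocid_or_name>",
    "llama_model_id": "<llama_model_ocid_or_name>"
}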

service/python/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ cached-property==1.5.2
 certifi==2024.7.4
 cffi==1.16.0
 circuitbreaker==1.4.0
-cryptography==43.0.1
+cryptography==42.0.6
 oci==2.126.2
 pycparser==2.21
 pyOpenSSL==24.1.0

service/python/server.py

Lines changed: 130 additions & 88 deletions
@@ -1,98 +1,101 @@
+import oci
 import asyncio
 import websockets
 import json
-import oci
 from throttler import throttle
 from pypdf import PdfReader
 from io import BytesIO
 from typing import Any, Dict, List
 import re
 from types import SimpleNamespace
 
-# TODO: Please update config profile name and use the compartmentId that has policies grant permissions for using Generative AI Service
-compartment_id = "<compartment_ocid>"
-CONFIG_PROFILE = "DEFAULT"
-config = oci.config.from_file('~/.oci/config', CONFIG_PROFILE)
-
-# Service endpoint
-endpoint = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
-generative_ai_inference_client = (
-    oci.generative_ai_inference.GenerativeAiInferenceClient(
-        config=config,
-        service_endpoint=endpoint,
-        retry_strategy=oci.retry.NoneRetryStrategy(),
-        timeout=(10, 240),
-    )
-)
-
-@throttle(rate_limit=15, period=65.0)
-async def generate_ai_response(prompts):
-    prompt = ""
-    llm_inference_request = (
-        oci.generative_ai_inference.models.CohereLlmInferenceRequest()
-    )
-    llm_inference_request.prompt = prompts
-    llm_inference_request.max_tokens = 1000
-    llm_inference_request.temperature = 0.75
-    llm_inference_request.top_p = 0.7
-    llm_inference_request.frequency_penalty = 1.0
+with open('config.json') as f:
+    config = json.load(f)
 
-    generate_text_detail = oci.generative_ai_inference.models.GenerateTextDetails()
-    generate_text_detail.serving_mode = oci.generative_ai_inference.models.DedicatedServingMode(endpoint_id="ocid1.generativeaiendpoint.oc1.us-chicago-1.amaaaaaaeras5xiavrsefrftfupp42lnniddgjnxuwbv5jypl64i7ktan65a")
+# Load configuration parameters
+compartment_id = config['compartment_id']
+CONFIG_PROFILE = config['config_profile']
+endpoint = config['service_endpoint']
+model_type = config['model_type']
+model_id = config[f'{model_type}_model_id']
 
-    generate_text_detail.compartment_id = compartment_id
-    generate_text_detail.inference_request = llm_inference_request
-
-    if "<compartment_ocid>" in compartment_id:
-        print("ERROR:Please update your compartment id in target python file")
-        quit()
+config = oci.config.from_file('~/.oci/config', CONFIG_PROFILE)
 
-    generate_text_response = generative_ai_inference_client.generate_text(generate_text_detail)
-    # Print result
-    print("**************************Generate Texts Result**************************")
-    print(vars(generate_text_response))
+generative_ai_inference_client = oci.generative_ai_inference.GenerativeAiInferenceClient(
+    config=config,
+    service_endpoint=endpoint,
+    retry_strategy=oci.retry.NoneRetryStrategy(),
+    timeout=(10, 240)
+)
 
-    return generate_text_response
+chat_detail = oci.generative_ai_inference.models.ChatDetails()
 
+# Define a function to generate an AI response
 @throttle(rate_limit=15, period=65.0)
-async def generate_ai_summary(summary_txt, prompt):
-    # You can also load the summary text from a file, or as a parameter in main
-    #with open('files/summarize_data.txt', 'r') as file:
-    #    text_to_summarize = file.read()
-
-    summarize_text_detail = oci.generative_ai_inference.models.SummarizeTextDetails()
-    summarize_text_detail.serving_mode = oci.generative_ai_inference.models.OnDemandServingMode(model_id="cohere.command")
-    summarize_text_detail.compartment_id = compartment_id
-    #summarize_text_detail.input = text_to_summarize
-    summarize_text_detail.input = summary_txt
-    summarize_text_detail.additional_command = prompt
-    summarize_text_detail.extractiveness = "AUTO" # HIGH, LOW
-    summarize_text_detail.format = "AUTO" # brackets, paragraph
-    summarize_text_detail.length = "LONG" # high, AUTO
-    summarize_text_detail.temperature = .25 # [0,1]
-
+async def generate_ai_response(prompts):
+    # Determine the request type based on the model type
+    if model_type == 'cohere':
+        chat_request = oci.generative_ai_inference.models.CohereChatRequest()
+        chat_request.max_tokens = 2000
+        chat_request.temperature = 0.25
+        chat_request.frequency_penalty = 0
+        chat_request.top_p = 0.75
+        chat_request.top_k = 0
+    elif model_type == 'llama':
+        chat_request = oci.generative_ai_inference.models.GenericChatRequest()
+        chat_request.api_format = oci.generative_ai_inference.models.BaseChatRequest.API_FORMAT_GENERIC
+        chat_request.max_tokens = 2000
+        chat_request.temperature = 1
+        chat_request.frequency_penalty = 0
+        chat_request.presence_penalty = 0
+        chat_request.top_p = 0.75
+        chat_request.top_k = -1
+    else:
+        raise ValueError("Unsupported model type")
+
+    # Process the prompts
+    if isinstance(prompts, str):
+        if model_type == 'cohere':
+            chat_request.message = prompts
+        else:
+            content = oci.generative_ai_inference.models.TextContent()
+            content.text = prompts
+            message = oci.generative_ai_inference.models.Message()
+            message.role = "USER"
+            message.content = [content]
+            chat_request.messages = [message]
+    elif isinstance(prompts, list):
+        chat_request.messages = prompts
+    else:
+        raise ValueError("Invalid input type for generate_ai_response")
+
+    # Set up the chat detail object
+    chat_detail.chat_request = chat_request
+    on_demand_mode = oci.generative_ai_inference.models.OnDemandServingMode(model_id=model_id)
+    chat_detail.serving_mode = on_demand_mode
+    chat_detail.compartment_id = compartment_id
+
+    # Send the request and get the response
+    chat_response = generative_ai_inference_client.chat(chat_detail)
+
+    # Validate the compartment ID
     if "<compartment_ocid>" in compartment_id:
-        print("ERROR:Please update your compartment id in target python file")
+        print("ERROR: Please update your compartment id in target python file")
         quit()
 
-    summarize_text_response = generative_ai_inference_client.summarize_text(summarize_text_detail)
-
-    # Print result
-    #print("**************************Summarize Texts Result**************************")
-    #print(summarize_text_response.data)
+    # Print the chat result
+    print("**************************Chat Result**************************")
+    print(vars(chat_response))
 
-    return summarize_text_response.data
+    return chat_response
 
 async def parse_pdf(file: BytesIO) -> List[str]:
     pdf = PdfReader(file)
     output = []
     for page in pdf.pages:
         text = page.extract_text()
-        # Merge hyphenated words
         text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
-        # Fix newlines in the middle of sentences
         text = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", text.strip())
-        # Remove multiple newlines
         text = re.sub(r"\n\s*\n", "\n\n", text)
         output.append(text)
     return output
@@ -102,34 +105,73 @@ async def handle_websocket(websocket, path):
         while True:
             data = await websocket.recv()
             if isinstance(data, str):
-                # if we are dealing with text, make it JSON
-                objData = json.loads(data,object_hook=lambda d: SimpleNamespace(**d))
+                objData = json.loads(data, object_hook=lambda d: SimpleNamespace(**d))
                 if objData.msgType == "question":
                     prompt = objData.data
-                if objData.msgType == "question":
-                    response = await generate_ai_response(prompt)
-                    answer = response.data.inference_response.generated_texts[0].text
-                    buidJSON = {"msgType":"answer","data":answer}
-                    await websocket.send(json.dumps(buidJSON))
-            # if it's not text, we have a binary and we will treat it as a PDF
-            if not isinstance(data,str):
-                # split the ArrayBuffer into metadata and the actual PDF file
+                    response = await generate_ai_response(prompt)
+
+                    if model_type == 'llama':
+                        answer = response.data.chat_response.choices[0].message.content[0].text
+                    elif model_type == 'cohere':
+                        answer = response.data.chat_response.text
+                    else:
+                        answer = ""
+
+                    buidJSON = {"msgType": "answer", "data": answer}
+                    await websocket.send(json.dumps(buidJSON))
+                elif objData.msgType == "summary":
+                    pdfFileObj = BytesIO(objData.data)
+                    output = await parse_pdf(pdfFileObj)
+                    chunk_size = 512
+                    chunks = [' '.join(output[i:i + chunk_size]) for i in range(0, len(output), chunk_size)]
+
+                    print(f"Processing {len(chunks)} chunks...")
+
+                    summaries = []
+                    for index, chunk in enumerate(chunks):
+                        print(f"Processing chunk {index+1}/{len(chunks)}...")
+                        response = await generate_ai_response(f"Summarize: {chunk}")
+                        if model_type == 'llama':
+                            summary = response.data.chat_response.choices[0].message.content[0].text
+                        elif model_type == 'cohere':
+                            summary = response.data.chat_response.text
+                        else:
+                            summary = ""
+                        summaries.append(summary)
+
+                    final_summary = ' '.join(summaries)
+                    buidJSON = {"msgType": "summary", "data": final_summary}
+                    await websocket.send(json.dumps(buidJSON))
+            else:
                 objData = data.split(b'\r\n\r\n')
-                # decode the metadata and parse the JSON data. Creating Dict properties from the JSON
-                metadata = json.loads(objData[0].decode('utf-8'),object_hook=lambda d: SimpleNamespace(**d))
+                metadata = json.loads(objData[0].decode('utf-8'), object_hook=lambda d: SimpleNamespace(**d))
                 pdfFileObj = BytesIO(objData[1])
                 output = await parse_pdf(pdfFileObj)
-                response = await generate_ai_summary(''.join(output),metadata.msgPrompt)
-                summary = response.summary
-                buidJSON = {"msgType":"summary","data": summary}
+                chunk_size = 512
+                chunks = [' '.join(output[i:i + chunk_size]) for i in range(0, len(output), chunk_size)]
+
+                print(f"Processing {len(chunks)} chunks...")
+
+                summaries = []
+                for index, chunk in enumerate(chunks):
+                    print(f"Processing chunk {index+1}/{len(chunks)}...")
+                    response = await generate_ai_response(f"Summarize: {chunk}")
+                    if model_type == 'llama':
+                        summary = response.data.chat_response.choices[0].message.content[0].text
+                    elif model_type == 'cohere':
+                        summary = response.data.chat_response.text
+                    else:
+                        summary = ""
+                    summaries.append(summary)
+
+                final_summary = ' '.join(summaries)
+                buidJSON = {"msgType": "summary", "data": final_summary}
                 await websocket.send(json.dumps(buidJSON))
     except websockets.exceptions.ConnectionClosedOK as e:
         print(f"Connection closed: {e}")
-
-
+
 async def start_server():
-    await websockets.serve(handle_websocket, "localhost", 1986, max_size=200000000)
-
+    async with websockets.serve(handle_websocket, "localhost", 1986, max_size=200000000):
+        await asyncio.Future()  # run forever
 
-asyncio.get_event_loop().run_until_complete(start_server())
-asyncio.get_event_loop().run_forever()
+asyncio.run(start_server())
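As a quick smoke test of the reworked websocket protocol, a minimal client sketch in Python (the localhost:1986 address, the 200000000-byte max_size, and the {"msgType": ..., "data": ...} message shape come from handle_websocket above; the sample question text is purely illustrative):

import asyncio
import json
import websockets

async def ask(question: str) -> str:
    # Connect to the running server.py and exercise the "question" -> "answer" round trip.
    async with websockets.connect("ws://localhost:1986", max_size=200000000) as ws:
        await ws.send(json.dumps({"msgType": "question", "data": question}))
        reply = json.loads(await ws.recv())  # expected shape: {"msgType": "answer", "data": "..."}
        return reply["data"]

if __name__ == "__main__":
    print(asyncio.run(ask("Summarize what OCI Generative AI offers in one sentence.")))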
