Skip to content

Commit 74eaccc

Browse files
Test7
1 parent e39044e commit 74eaccc

File tree

15 files changed

+796
-478
lines changed

15 files changed

+796
-478
lines changed

.pylintrc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ max-line-length=120
2323
max-args=10
2424
max-locals=30
2525
max-branches=20
26-
max-lines=1500
2726
max-statements=100
2827

2928
[LOGGING]

backend/auth/auth_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,17 @@ def get_authenticated_user_details(request_headers):
55
if "X-Ms-Client-Principal-Id" not in request_headers.keys():
66
# if it's not, assume we're in development mode and return a default user
77
from . import sample_user
8+
89
raw_user_object = sample_user.sample_user
910
else:
1011
# if it is, get the user details from the EasyAuth headers
1112
raw_user_object = {k: v for k, v in request_headers.items()}
1213

13-
user_object['user_principal_id'] = raw_user_object.get('X-Ms-Client-Principal-Id')
14-
user_object['user_name'] = raw_user_object.get('X-Ms-Client-Principal-Name')
15-
user_object['auth_provider'] = raw_user_object.get('X-Ms-Client-Principal-Idp')
16-
user_object['auth_token'] = raw_user_object.get('X-Ms-Token-Aad-Id-Token')
17-
user_object['client_principal_b64'] = raw_user_object.get('X-Ms-Client-Principal')
18-
user_object['aad_id_token'] = raw_user_object.get('X-Ms-Token-Aad-Id-Token')
14+
user_object["user_principal_id"] = raw_user_object.get("X-Ms-Client-Principal-Id")
15+
user_object["user_name"] = raw_user_object.get("X-Ms-Client-Principal-Name")
16+
user_object["auth_provider"] = raw_user_object.get("X-Ms-Client-Principal-Idp")
17+
user_object["auth_token"] = raw_user_object.get("X-Ms-Token-Aad-Id-Token")
18+
user_object["client_principal_b64"] = raw_user_object.get("X-Ms-Client-Principal")
19+
user_object["aad_id_token"] = raw_user_object.get("X-Ms-Token-Aad-Id-Token")
1920

2021
return user_object

backend/auth/sample_user.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
"Max-Forwards": "10",
1212
"Origin": "https://your_app_service.azurewebsites.net",
1313
"Referer": "https://your_app_service.azurewebsites.net/",
14-
"Sec-Ch-Ua": "\"Microsoft Edge\";v=\"113\", \"Chromium\";v=\"113\", \"Not-A.Brand\";v=\"24\"",
14+
"Sec-Ch-Ua": '"Microsoft Edge";v="113", "Chromium";v="113", "Not-A.Brand";v="24"',
1515
"Sec-Ch-Ua-Mobile": "?0",
16-
"Sec-Ch-Ua-Platform": "\"Windows\"",
16+
"Sec-Ch-Ua-Platform": '"Windows"',
1717
"Sec-Fetch-Dest": "empty",
1818
"Sec-Fetch-Mode": "cors",
1919
"Sec-Fetch-Site": "same-origin",
@@ -35,5 +35,5 @@
3535
"X-Ms-Token-Aad-Id-Token": "your_aad_id_token",
3636
"X-Original-Url": "/chatgpt",
3737
"X-Site-Deployment-Id": "your_app_service",
38-
"X-Waws-Unencoded-Url": "/chatgpt"
38+
"X-Waws-Unencoded-Url": "/chatgpt",
3939
}

backend/history/cosmosdbservice.py

Lines changed: 72 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,58 @@
44
from azure.cosmos import exceptions
55

66

7-
class CosmosConversationClient():
8-
9-
def __init__(self, cosmosdb_endpoint: str, credential: any, database_name: str, container_name: str, enable_message_feedback: bool = False):
7+
class CosmosConversationClient:
8+
def __init__(
9+
self,
10+
cosmosdb_endpoint: str,
11+
credential: any,
12+
database_name: str,
13+
container_name: str,
14+
enable_message_feedback: bool = False,
15+
):
1016
self.cosmosdb_endpoint = cosmosdb_endpoint
1117
self.credential = credential
1218
self.database_name = database_name
1319
self.container_name = container_name
1420
self.enable_message_feedback = enable_message_feedback
1521
try:
16-
self.cosmosdb_client = CosmosClient(self.cosmosdb_endpoint, credential=credential)
22+
self.cosmosdb_client = CosmosClient(
23+
self.cosmosdb_endpoint, credential=credential
24+
)
1725
except exceptions.CosmosHttpResponseError as e:
1826
if e.status_code == 401:
1927
raise ValueError("Invalid credentials") from e
2028
else:
2129
raise ValueError("Invalid CosmosDB endpoint") from e
2230

2331
try:
24-
self.database_client = self.cosmosdb_client.get_database_client(database_name)
32+
self.database_client = self.cosmosdb_client.get_database_client(
33+
database_name
34+
)
2535
except exceptions.CosmosResourceNotFoundError:
2636
raise ValueError("Invalid CosmosDB database name")
2737

2838
try:
29-
self.container_client = self.database_client.get_container_client(container_name)
39+
self.container_client = self.database_client.get_container_client(
40+
container_name
41+
)
3042
except exceptions.CosmosResourceNotFoundError:
3143
raise ValueError("Invalid CosmosDB container name")
3244

3345
async def ensure(self):
34-
if not self.cosmosdb_client or not self.database_client or not self.container_client:
46+
if (
47+
not self.cosmosdb_client
48+
or not self.database_client
49+
or not self.container_client
50+
):
3551
return False, "CosmosDB client not initialized correctly"
3652
try:
3753
database_info = await self.database_client.read()
3854
except:
39-
return False, f"CosmosDB database {self.database_name} on account {self.cosmosdb_endpoint} not found"
55+
return (
56+
False,
57+
f"CosmosDB database {self.database_name} on account {self.cosmosdb_endpoint} not found",
58+
)
4059

4160
try:
4261
container_info = await self.container_client.read()
@@ -45,14 +64,14 @@ async def ensure(self):
4564

4665
return True, "CosmosDB client initialized successfully"
4766

48-
async def create_conversation(self, user_id, title=''):
67+
async def create_conversation(self, user_id, title=""):
4968
conversation = {
50-
'id': str(uuid.uuid4()),
51-
'type': 'conversation',
52-
'createdAt': datetime.utcnow().isoformat(),
53-
'updatedAt': datetime.utcnow().isoformat(),
54-
'userId': user_id,
55-
'title': title
69+
"id": str(uuid.uuid4()),
70+
"type": "conversation",
71+
"createdAt": datetime.utcnow().isoformat(),
72+
"updatedAt": datetime.utcnow().isoformat(),
73+
"userId": user_id,
74+
"title": title,
5675
}
5776
# TODO: add some error handling based on the output of the upsert_item call
5877
resp = await self.container_client.upsert_item(conversation)
@@ -69,9 +88,13 @@ async def upsert_conversation(self, conversation):
6988
return False
7089

7190
async def delete_conversation(self, user_id, conversation_id):
72-
conversation = await self.container_client.read_item(item=conversation_id, partition_key=user_id)
91+
conversation = await self.container_client.read_item(
92+
item=conversation_id, partition_key=user_id
93+
)
7394
if conversation:
74-
resp = await self.container_client.delete_item(item=conversation_id, partition_key=user_id)
95+
resp = await self.container_client.delete_item(
96+
item=conversation_id, partition_key=user_id
97+
)
7598
return resp
7699
else:
77100
return True
@@ -82,41 +105,36 @@ async def delete_messages(self, conversation_id, user_id):
82105
response_list = []
83106
if messages:
84107
for message in messages:
85-
resp = await self.container_client.delete_item(item=message['id'], partition_key=user_id)
108+
resp = await self.container_client.delete_item(
109+
item=message["id"], partition_key=user_id
110+
)
86111
response_list.append(resp)
87112
return response_list
88113

89-
async def get_conversations(self, user_id, limit, sort_order='DESC', offset=0):
90-
parameters = [
91-
{
92-
'name': '@userId',
93-
'value': user_id
94-
}
95-
]
114+
async def get_conversations(self, user_id, limit, sort_order="DESC", offset=0):
115+
parameters = [{"name": "@userId", "value": user_id}]
96116
query = f"SELECT * FROM c where c.userId = @userId and c.type='conversation' order by c.updatedAt {sort_order}"
97117
if limit is not None:
98118
query += f" offset {offset} limit {limit}"
99119

100120
conversations = []
101-
async for item in self.container_client.query_items(query=query, parameters=parameters):
121+
async for item in self.container_client.query_items(
122+
query=query, parameters=parameters
123+
):
102124
conversations.append(item)
103125

104126
return conversations
105127

106128
async def get_conversation(self, user_id, conversation_id):
107129
parameters = [
108-
{
109-
'name': '@conversationId',
110-
'value': conversation_id
111-
},
112-
{
113-
'name': '@userId',
114-
'value': user_id
115-
}
130+
{"name": "@conversationId", "value": conversation_id},
131+
{"name": "@userId", "value": user_id},
116132
]
117133
query = f"SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
118134
conversations = []
119-
async for item in self.container_client.query_items(query=query, parameters=parameters):
135+
async for item in self.container_client.query_items(
136+
query=query, parameters=parameters
137+
):
120138
conversations.append(item)
121139

122140
# if no conversations are found, return None
@@ -127,54 +145,52 @@ async def get_conversation(self, user_id, conversation_id):
127145

128146
async def create_message(self, uuid, conversation_id, user_id, input_message: dict):
129147
message = {
130-
'id': uuid,
131-
'type': 'message',
132-
'userId': user_id,
133-
'createdAt': datetime.utcnow().isoformat(),
134-
'updatedAt': datetime.utcnow().isoformat(),
135-
'conversationId': conversation_id,
136-
'role': input_message['role'],
137-
'content': input_message['content']
148+
"id": uuid,
149+
"type": "message",
150+
"userId": user_id,
151+
"createdAt": datetime.utcnow().isoformat(),
152+
"updatedAt": datetime.utcnow().isoformat(),
153+
"conversationId": conversation_id,
154+
"role": input_message["role"],
155+
"content": input_message["content"],
138156
}
139157

140158
if self.enable_message_feedback:
141-
message['feedback'] = ''
159+
message["feedback"] = ""
142160

143161
resp = await self.container_client.upsert_item(message)
144162
if resp:
145163
# update the parent conversations's updatedAt field with the current message's createdAt datetime value
146164
conversation = await self.get_conversation(user_id, conversation_id)
147165
if not conversation:
148166
return "Conversation not found"
149-
conversation['updatedAt'] = message['createdAt']
167+
conversation["updatedAt"] = message["createdAt"]
150168
await self.upsert_conversation(conversation)
151169
return resp
152170
else:
153171
return False
154172

155173
async def update_message_feedback(self, user_id, message_id, feedback):
156-
message = await self.container_client.read_item(item=message_id, partition_key=user_id)
174+
message = await self.container_client.read_item(
175+
item=message_id, partition_key=user_id
176+
)
157177
if message:
158-
message['feedback'] = feedback
178+
message["feedback"] = feedback
159179
resp = await self.container_client.upsert_item(message)
160180
return resp
161181
else:
162182
return False
163183

164184
async def get_messages(self, user_id, conversation_id):
165185
parameters = [
166-
{
167-
'name': '@conversationId',
168-
'value': conversation_id
169-
},
170-
{
171-
'name': '@userId',
172-
'value': user_id
173-
}
186+
{"name": "@conversationId", "value": conversation_id},
187+
{"name": "@userId", "value": user_id},
174188
]
175189
query = f"SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.timestamp ASC"
176190
messages = []
177-
async for item in self.container_client.query_items(query=query, parameters=parameters):
191+
async for item in self.container_client.query_items(
192+
query=query, parameters=parameters
193+
):
178194
messages.append(item)
179195

180196
return messages

backend/security/ms_defender_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33

44
def get_msdefender_user_json(authenticated_user_details, request_headers):
5-
auth_provider = authenticated_user_details.get('auth_provider')
6-
source_ip = request_headers.get('X-Forwarded-For', request_headers.get('Remote-Addr', ''))
5+
auth_provider = authenticated_user_details.get("auth_provider")
6+
source_ip = request_headers.get(
7+
"X-Forwarded-For", request_headers.get("Remote-Addr", "")
8+
)
79
user_args = {
8-
"EndUserId": authenticated_user_details.get('user_principal_id'),
10+
"EndUserId": authenticated_user_details.get("user_principal_id"),
911
"EndUserIdType": "EntraId" if auth_provider == "aad" else auth_provider,
10-
"SourceIp": source_ip.split(':')[0], # remove port
12+
"SourceIp": source_ip.split(":")[0], # remove port
1113
}
1214
return json.dumps(user_args)

0 commit comments

Comments
 (0)