Skip to content

Commit d429620

Browse files
author
Zhi Zhou
committed
Add content_understanding_face_client
1 parent 03e770d commit d429620

File tree

1 file changed

+316
-0
lines changed

1 file changed

+316
-0
lines changed
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
import base64
2+
import requests
3+
from requests.models import Response
4+
import logging
5+
6+
7+
class AzureContentUnderstandingFaceClient:
8+
def __init__(
9+
self,
10+
endpoint: str,
11+
api_version: str,
12+
subscription_key: str = None,
13+
token_provider: callable = None,
14+
x_ms_useragent: str = "cu-face-sample-code",
15+
):
16+
if not subscription_key and not token_provider:
17+
raise ValueError(
18+
"Either subscription key or token provider must be provided."
19+
)
20+
if not api_version:
21+
raise ValueError("API version must be provided.")
22+
if not endpoint:
23+
raise ValueError("Endpoint must be provided.")
24+
25+
self._endpoint = endpoint.rstrip("/")
26+
self._api_version = api_version
27+
self._logger = logging.getLogger(__name__)
28+
self._headers = self._get_headers(
29+
subscription_key, token_provider(), x_ms_useragent
30+
)
31+
32+
def _get_face_url(self, endpoint, api_version, action):
33+
return (
34+
f"{endpoint}/contentunderstanding/faces:{action}?api-version={api_version}"
35+
)
36+
37+
def _get_person_directory_url(self, endpoint, api_version, path=None):
38+
url = f"{endpoint}/contentunderstanding/personDirectories"
39+
if path:
40+
url += f"/{path}"
41+
return f"{url}?api-version={api_version}"
42+
43+
def _get_headers(self, subscription_key, api_token, x_ms_useragent):
44+
"""Returns the headers for the HTTP requests.
45+
Args:
46+
subscription_key (str): The subscription key for the service.
47+
api_token (str): The API token for the service.
48+
Returns:
49+
dict: A dictionary containing the headers for the HTTP requests.
50+
"""
51+
headers = (
52+
{"Ocp-Apim-Subscription-Key": subscription_key}
53+
if subscription_key
54+
else {"Authorization": f"Bearer {api_token}"}
55+
)
56+
headers["x-ms-useragent"] = x_ms_useragent
57+
return headers
58+
59+
def _handle_response(self, response: Response, action: str):
60+
if response.status_code == 204:
61+
self._logger.info(f"{action} completed successfully with status 204.")
62+
return None
63+
if response.status_code != 200:
64+
self._logger.error(
65+
f"Error in {action}: {response.status_code} - {response.text}"
66+
)
67+
raise Exception(
68+
f"Error in {action}: {response.status_code} - {response.text}"
69+
)
70+
return response.json()
71+
72+
def detect_faces(self, url: str = None, data: str = None):
73+
request_body = {"url": url, "data": data}
74+
response = requests.post(
75+
self._get_face_url(self._endpoint, self._api_version, "detect"),
76+
headers=self._headers,
77+
json=request_body,
78+
)
79+
return self._handle_response(response, "detect_faces")
80+
81+
def compare_faces(self, data1: str, data2: str):
82+
request_body = {
83+
"faceSource1": {"data": data1},
84+
"faceSource2": {"data": data2},
85+
}
86+
response = requests.post(
87+
self._get_face_url(self._endpoint, self._api_version, "compare"),
88+
headers=self._headers,
89+
json=request_body,
90+
)
91+
return self._handle_response(response, "compare_faces")
92+
93+
def get_person_directories(self):
94+
response = requests.get(
95+
self._get_person_directory_url(self._endpoint, self._api_version),
96+
headers=self._headers,
97+
)
98+
return self._handle_response(response, "get_person_directories")
99+
100+
def get_person_directory(self, person_directory_id: str):
101+
response = requests.get(
102+
self._get_person_directory_url(
103+
self._endpoint, self._api_version, person_directory_id
104+
),
105+
headers=self._headers,
106+
)
107+
return self._handle_response(response, "get_person_directory")
108+
109+
def create_person_directory(
110+
self, person_directory_id: str, description: str = None, tags: dict = None
111+
):
112+
request_body = {"description": description, "tags": tags}
113+
response = requests.put(
114+
self._get_person_directory_url(
115+
self._endpoint, self._api_version, person_directory_id
116+
),
117+
headers=self._headers,
118+
json=request_body,
119+
)
120+
return self._handle_response(response, "create_person_directory")
121+
122+
def update_person_directory(
123+
self, person_directory_id: str, description: str = None, tags: dict = None
124+
):
125+
request_body = {"description": description, "tags": tags}
126+
response = requests.patch(
127+
self._get_person_directory_url(
128+
self._endpoint, self._api_version, person_directory_id
129+
),
130+
headers=self._headers,
131+
json=request_body,
132+
)
133+
return self._handle_response(response, "update_person_directory")
134+
135+
def delete_person_directory(self, person_directory_id: str):
136+
response = requests.delete(
137+
self._get_person_directory_url(
138+
self._endpoint, self._api_version, person_directory_id
139+
),
140+
headers=self._headers,
141+
)
142+
return self._handle_response(response, "delete_person_directory")
143+
144+
def get_persons(self, person_directory_id: str):
145+
response = requests.get(
146+
self._get_person_directory_url(
147+
self._endpoint, self._api_version, f"{person_directory_id}/persons"
148+
),
149+
headers=self._headers,
150+
)
151+
return self._handle_response(response, "get_persons")
152+
153+
def get_person(self, person_directory_id: str, person_id: str):
154+
response = requests.get(
155+
self._get_person_directory_url(
156+
self._endpoint,
157+
self._api_version,
158+
f"{person_directory_id}/persons/{person_id}",
159+
),
160+
headers=self._headers,
161+
)
162+
return self._handle_response(response, "get_person")
163+
164+
def add_person(
165+
self, person_directory_id: str, tags: dict = None, face_ids: list = None
166+
):
167+
request_body = {"tags": tags, "faceIds": face_ids}
168+
response = requests.post(
169+
self._get_person_directory_url(
170+
self._endpoint,
171+
self._api_version,
172+
f"{person_directory_id}/persons",
173+
),
174+
headers=self._headers,
175+
json=request_body,
176+
)
177+
return self._handle_response(response, "add_person")
178+
179+
def update_person(
180+
self,
181+
person_directory_id: str,
182+
person_id: str,
183+
tags: dict = None,
184+
face_ids: list = None,
185+
):
186+
request_body = {"tags": tags, "faceIds": face_ids}
187+
response = requests.patch(
188+
self._get_person_directory_url(
189+
self._endpoint,
190+
self._api_version,
191+
f"{person_directory_id}/persons/{person_id}",
192+
),
193+
headers=self._headers,
194+
json=request_body,
195+
)
196+
return self._handle_response(response, "update_person")
197+
198+
def delete_person(self, person_directory_id: str, person_id: str):
199+
response = requests.delete(
200+
self._get_person_directory_url(
201+
self._endpoint,
202+
self._api_version,
203+
f"{person_directory_id}/persons/{person_id}",
204+
),
205+
headers=self._headers,
206+
)
207+
return self._handle_response(response, "delete_person")
208+
209+
def get_faces(self, person_directory_id: str):
210+
response = requests.get(
211+
self._get_person_directory_url(
212+
self._endpoint,
213+
self._api_version,
214+
f"{person_directory_id}/faces",
215+
),
216+
headers=self._headers,
217+
)
218+
return self._handle_response(response, "get_faces")
219+
220+
def get_face(self, person_directory_id: str, face_id: str):
221+
response = requests.get(
222+
self._get_person_directory_url(
223+
self._endpoint,
224+
self._api_version,
225+
f"{person_directory_id}/faces/{face_id}",
226+
),
227+
headers=self._headers,
228+
)
229+
return self._handle_response(response, "get_face")
230+
231+
def add_face(self, person_directory_id: str, data: str):
232+
request_body = {"faceSource": {"data": data}}
233+
response = requests.post(
234+
self._get_person_directory_url(
235+
self._endpoint,
236+
self._api_version,
237+
f"{person_directory_id}/faces",
238+
),
239+
headers=self._headers,
240+
json=request_body,
241+
)
242+
return self._handle_response(response, "add_face")
243+
244+
def update_face(self, person_directory_id: str, face_id: str, person_id: str):
245+
request_body = {"personId": person_id}
246+
response = requests.patch(
247+
self._get_person_directory_url(
248+
self._endpoint,
249+
self._api_version,
250+
f"{person_directory_id}/faces/{face_id}",
251+
),
252+
headers=self._headers,
253+
json=request_body,
254+
)
255+
return self._handle_response(response, "update_face")
256+
257+
def delete_face(self, person_directory_id: str, face_id: str):
258+
response = requests.delete(
259+
self._get_person_directory_url(
260+
self._endpoint,
261+
self._api_version,
262+
f"{person_directory_id}/faces/{face_id}",
263+
),
264+
headers=self._headers,
265+
)
266+
return self._handle_response(response, "delete_face")
267+
268+
def identify_person(self, person_directory_id: str, data: str):
269+
request_body = {
270+
"faceSource": {"data": data},
271+
}
272+
response = requests.post(
273+
self._get_person_directory_url(
274+
self._endpoint,
275+
self._api_version,
276+
f"{person_directory_id}/persons:identify",
277+
),
278+
headers=self._headers,
279+
json=request_body,
280+
)
281+
return self._handle_response(response, "identify")
282+
283+
def verify_person(self, person_directory_id: str, person_id: str, data: str):
284+
request_body = {
285+
"faceSource": {"data": data},
286+
}
287+
response = requests.post(
288+
self._get_person_directory_url(
289+
self._endpoint,
290+
self._api_version,
291+
f"{person_directory_id}/persons/{person_id}:verify",
292+
),
293+
headers=self._headers,
294+
json=request_body,
295+
)
296+
return self._handle_response(response, "verify")
297+
298+
def find_similar_faces(self, person_directory_id: str, data: str):
299+
request_body = {
300+
"faceSource": {"data": data},
301+
}
302+
response = requests.post(
303+
self._get_person_directory_url(
304+
self._endpoint,
305+
self._api_version,
306+
f"{person_directory_id}/faces:find",
307+
),
308+
headers=self._headers,
309+
json=request_body,
310+
)
311+
return self._handle_response(response, "find_similar_faces")
312+
313+
def read_file_to_base64(file_path: str) -> str:
314+
with open(file_path, "rb") as f:
315+
file_data = f.read()
316+
return base64.b64encode(file_data).decode("utf-8")

0 commit comments

Comments
 (0)