|
1 | 1 | import json |
2 | 2 | import os |
3 | 3 | import boto3 |
4 | | -ssm = boto3.client('ssm') |
| 4 | +import uuid |
5 | 5 | sns = boto3.client('sns') |
6 | | -rekognition = boto3.client('rekognition') |
| 6 | +a2i = boto3.client('sagemaker-a2i-runtime') |
| 7 | +ssm = boto3.client('ssm') |
7 | 8 | dynamodb = boto3.client('dynamodb') |
8 | | -def get_parameters(): |
9 | | - response = ssm.get_parameters_by_path( |
10 | | - Path=os.environ['parameter_store_path'], |
11 | | - Recursive=True |
12 | | - ) |
13 | | - parameter_store = {} |
14 | | - for parameter in response['Parameters']: |
15 | | - parameter_name = (parameter['Name'].split('/'))[-1] |
16 | | - parameter_store[parameter_name] = parameter['Value'] |
17 | | - return parameter_store |
18 | | -def store_detection_results(dynamodb_table, detectlabel_request_id, detectlabel_date, s3_bucket_name, s3_object_key, s3_object_eTag, project_version, detected_label, confidence_level, minimum_confidence_level, A2I): |
19 | | - response = dynamodb.put_item( |
20 | | - TableName=dynamodb_table, |
21 | | - Item={ |
22 | | - 'DetectLabelRequestId': {'S':detectlabel_request_id}, |
23 | | - 'DetectLabelDate': {'S':detectlabel_date}, |
24 | | - 'S3Bucket': {'S':s3_bucket_name}, |
25 | | - 'S3ObjectKey': {'S':s3_object_key}, |
26 | | - 'S3ObjectEtag': {'S':s3_object_eTag}, |
27 | | - 'ProjectVersionArn': {'S':project_version}, |
28 | | - 'DectectedLabel': {'S':detected_label}, |
29 | | - 'DetectedConfidenceLevel': {'N':str(confidence_level)}, |
30 | | - 'MinimumConfidenceLevel': {'N':str(minimum_confidence_level)}, |
31 | | - 'A2IEnabled': {'BOOL':A2I} |
32 | | - } |
33 | | - ) |
| 9 | +def get_parameter(parameter_name): |
| 10 | + response = ssm.get_parameter(Name=parameter_name) |
| 11 | + parameter_value = json.loads(response['Parameter']['Value']) |
| 12 | + return parameter_value |
34 | 13 | def publish_message(sns_subject, sns_message, topic_arn): |
35 | 14 | response = sns.publish( |
36 | 15 | TopicArn=topic_arn, |
37 | 16 | Message=sns_message, |
38 | 17 | Subject=sns_subject |
39 | 18 | ) |
| 19 | +def append_a2i_request(dynamodb_table, detectlabel_request_id, human_loop_name, humanloop_request_id): |
| 20 | + response = dynamodb.update_item( |
| 21 | + TableName=dynamodb_table, |
| 22 | + Key={ |
| 23 | + 'DetectLabelRequestId': {'S':detectlabel_request_id} |
| 24 | + }, |
| 25 | + ExpressionAttributeNames={ |
| 26 | + '#HLN':'HumanLoopName', |
| 27 | + '#HLRI':'HumanLoopRequestId' |
| 28 | + }, |
| 29 | + ExpressionAttributeValues={ |
| 30 | + ':n': {'S':human_loop_name}, |
| 31 | + ':i': {'S':humanloop_request_id} |
| 32 | + }, |
| 33 | + UpdateExpression='SET #HLN=:n, #HLRI=:i' |
| 34 | + ) |
40 | 35 | def handler(event, context): |
41 | | - parameter_store = get_parameters() |
42 | | - a2i = True if parameter_store['Enable-A2I-Workflow'].lower().capitalize() == 'True' else False |
43 | | - sys_vars = json.loads(parameter_store['For-System-Use-Only']) |
44 | | - project_version = sys_vars['rekognition_project_version_arn'] |
| 36 | + parameter_name = os.environ['parameter_store_path'] + 'For-System-Use-Only' |
| 37 | + sys_vars = get_parameter(parameter_name) |
45 | 38 | dynamodb_table = sys_vars['dynamodb_table'] |
46 | | - minimum_confidence_level=parameter_store['Minimum-Label-Detection-Confidence'] |
47 | | - s3_bucket_name=event['s3event']['s3']['bucket']['name'] |
48 | | - s3_object_key=event['s3event']['s3']['object']['key'] |
49 | | - s3_object_eTag=event['s3event']['s3']['object']['eTag'] |
50 | | - s3_event_time=event['s3event']['eventTime'] |
51 | | - response = rekognition.detect_custom_labels( |
52 | | - ProjectVersionArn=project_version, |
53 | | - Image={ |
54 | | - 'S3Object': { |
55 | | - 'Bucket': s3_bucket_name, |
56 | | - 'Name': s3_object_key |
57 | | - } |
58 | | - }, |
59 | | - MaxResults=1, |
60 | | - MinConfidence=0 |
| 39 | + detectlabel_request_id=event['message']['ResponseMetadata']['RequestId'] |
| 40 | + s3_bucket_name=event['message']['s3event']['s3']['bucket']['name'] |
| 41 | + s3_object_key=event['message']['s3event']['s3']['object']['key'] |
| 42 | + detected_label=event['message']['CustomLabels'][0]['Name'] |
| 43 | + confidence_level=event['message']['CustomLabels'][0]['Confidence'] |
| 44 | + human_loop_name=str(uuid.uuid4()) |
| 45 | + response = a2i.start_human_loop( |
| 46 | + HumanLoopName = human_loop_name, |
| 47 | + FlowDefinitionArn = sys_vars['flow_definition_arn'], |
| 48 | + HumanLoopInput = { |
| 49 | + 'InputContent': json.dumps({ |
| 50 | + 'initialValue': confidence_level, |
| 51 | + 'detectLabelRequestId': detectlabel_request_id, |
| 52 | + 'taskObject': 's3://'+s3_bucket_name+'/'+s3_object_key |
| 53 | + }) |
| 54 | + } |
61 | 55 | ) |
62 | | - confidence_level = response['CustomLabels'][0]['Confidence'] |
63 | | - detected_label = response['CustomLabels'][0]['Name'] |
64 | | - detectlabel_request_id = response['ResponseMetadata']['RequestId'] |
65 | | - if confidence_level < float(parameter_store['Minimum-Label-Detection-Confidence']): |
66 | | - response['A2I'] = True if a2i else False |
67 | | - store_detection_results(dynamodb_table, detectlabel_request_id, s3_event_time, s3_bucket_name, s3_object_key, s3_object_eTag, project_version, detected_label, confidence_level, minimum_confidence_level, a2i) |
68 | | - publish_message('Rekgnition Custom Labels Detection Invoked', json.dumps(response), sys_vars['sns-topic']) |
69 | | - response['s3event'] = event['s3event'] |
| 56 | + humanloop_request_id=response['ResponseMetadata']['RequestId'] |
| 57 | + append_a2i_request(dynamodb_table, detectlabel_request_id, human_loop_name, humanloop_request_id) |
| 58 | + response['s3event'] = event['message']['s3event'] |
| 59 | + publish_message('A2I Human Loop Initiated', json.dumps(response), sys_vars['sns-topic']) |
70 | 60 | return { |
71 | 61 | 'message': response |
72 | 62 | } |
0 commit comments