|
2 | 2 | import shutil |
3 | 3 | import pickle |
4 | 4 |
|
5 | | -import cv2 |
6 | | -import face_recognition |
7 | 5 | from flask import request |
8 | 6 | from flask_restful import Resource |
9 | 7 | from flask_uploads import UploadNotAllowed |
10 | 8 | from flask_jwt_extended import jwt_required |
11 | 9 |
|
12 | 10 | from src.db import Session |
13 | 11 | from src.libs import image_helper |
| 12 | +from src.libs.train_classifier import TrainClassifier |
14 | 13 | from src.libs.strings import gettext |
15 | 14 | from src.models import StudentModel |
16 | 15 | from src.schemas import StudentSchema, ImageSchema |
17 | | -from src.settings import DATASET_PATH, ENCODINGS_FILE, DLIB_MODEL |
| 16 | +from src.settings import DATASET_PATH, ENCODINGS_FILE |
18 | 17 |
|
19 | 18 |
|
20 | 19 | student_schema = StudentSchema() |
@@ -101,68 +100,10 @@ def post(cls, student_id: int): |
101 | 100 | save_as_filename = filename + extension |
102 | 101 | image_path = image_helper.save_image(data["image"], folder=folder, name=save_as_filename) |
103 | 102 | basename = image_helper.get_basename(image_path) |
104 | | - return {"message": gettext("image_uploaded").format(basename)}, 201 |
105 | 103 | except UploadNotAllowed: |
106 | 104 | extension = image_helper.get_extension(data["image"]) |
107 | 105 | return {"message": gettext("image_illegal_extension").format(extension)}, 400 |
108 | 106 |
|
109 | | - |
110 | | -class TrainClassifier(Resource): |
111 | | - """Train KNN Classifier by storing results in `files/encodings.pickle` file""" |
112 | | - # TODO: Store encodings in SQL database rather than `files/encodings.pickle` file |
113 | | - @classmethod |
114 | | - @jwt_required |
115 | | - def get(cls): |
116 | | - try: |
117 | | - print("[INFO] loading encodings...") |
118 | | - data = pickle.loads(open(ENCODINGS_FILE, "rb").read()) |
119 | | - # initialize the list of known encodings and known names |
120 | | - known_encodings = data["encodings"] |
121 | | - known_ids = data["ids"] |
122 | | - except FileNotFoundError: |
123 | | - # initialize the list of known encodings and known names |
124 | | - known_encodings = [] |
125 | | - known_ids = [] |
126 | | - |
127 | | - # get single unique ids by converting into set |
128 | | - # for each _id convert it into int |
129 | | - unique_ids = [int(_id) for _id in set(known_ids)] |
130 | | - |
131 | | - # get all id_paths and join the path of the parent folder to each id_path |
132 | | - id_paths = [os.path.join(DATASET_PATH, f) for f in os.listdir(DATASET_PATH)] |
133 | | - # print(">>> ID paths:", id_paths) |
134 | | - |
135 | | - # now looping through all the id_paths and loading the images in that id_path |
136 | | - for id_path in id_paths: |
137 | | - # getting the ID from the image |
138 | | - _id = int(os.path.split(id_path)[1]) |
139 | | - if _id in unique_ids: |
140 | | - continue |
141 | | - # grab the paths to the input images of that ID |
142 | | - image_paths = [os.path.join(id_path, f) for f in os.listdir(id_path)] |
143 | | - for i, image_path in enumerate(image_paths): |
144 | | - print(f"[INFO] ID: {_id}, processing image {i + 1}/{len(image_paths)}") |
145 | | - # load the input image and convert it from RGB (OpenCV ordering) |
146 | | - # to dlib ordering (RGB) |
147 | | - image = cv2.imread(image_path) |
148 | | - rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
149 | | - |
150 | | - # detect the (x, y)-coordinates of the bounding boxes |
151 | | - # corresponding to each face in the input frame, then compute |
152 | | - # the facial embeddings for each face |
153 | | - boxes = face_recognition.face_locations(rgb, model=DLIB_MODEL) |
154 | | - # compute the facial embedding for the face |
155 | | - encodings = face_recognition.face_encodings(rgb, boxes) |
156 | | - # loop over the encodings |
157 | | - for encoding in encodings: |
158 | | - # add each encoding + name to our set of known names and |
159 | | - # encodings |
160 | | - known_encodings.append(encoding) |
161 | | - known_ids.append(_id) |
162 | | - |
163 | | - # dump the facial encodings + names to disk |
164 | | - print("[INFO] serializing encodings...") |
165 | | - data = {"encodings": known_encodings, "ids": known_ids} |
166 | | - f = open(ENCODINGS_FILE, "wb") |
167 | | - f.write(pickle.dumps(data)) |
168 | | - f.close() |
| 107 | + # train images when submitted successfully |
| 108 | + TrainClassifier.train() |
| 109 | + return {"message": gettext("image_uploaded").format(basename)}, 201 |
0 commit comments