|
1 | 1 | from flask_restplus import Namespace, Resource, reqparse |
2 | 2 | from werkzeug.datastructures import FileStorage |
3 | | -from imantics import Image as ImanticsImage |
| 3 | +from imantics import Mask |
4 | 4 | from flask_login import login_required |
5 | 5 | from ..config import Config |
6 | 6 | from PIL import Image |
| 7 | +from ..models import ImageModel |
7 | 8 |
|
8 | 9 | import os |
9 | 10 |
|
10 | | -MODEL_LOADED = len(Config.MASK_RCNN_FILE) != 0 and os.path.isfile(Config.MASK_RCNN_FILE) |
11 | | - |
12 | | -if MODEL_LOADED: |
| 11 | +MASKRCNN_LOADED = os.path.isfile(Config.MASK_RCNN_FILE) |
| 12 | +if MASKRCNN_LOADED: |
13 | 13 | from ..util.mask_rcnn import model as maskrcnn |
14 | 14 | else: |
15 | 15 | print("MaskRCNN model is disabled.", flush=True) |
16 | 16 |
|
| 17 | +DEXTR_LOADED = os.path.isfile(Config.DEXTR_FILE) |
| 18 | +if DEXTR_LOADED: |
| 19 | + from ..util.dextr import model as dextr |
| 20 | +else: |
| 21 | + print("DEXTR model is disabled.", flush=True) |
| 22 | + |
17 | 23 | api = Namespace('model', description='Model related operations') |
18 | 24 |
|
19 | 25 |
|
20 | 26 | image_upload = reqparse.RequestParser() |
21 | 27 | image_upload.add_argument('image', location='files', type=FileStorage, required=True, help='Image') |
22 | 28 |
|
| 29 | +dextr_args = reqparse.RequestParser() |
| 30 | +dextr_args.add_argument('points', location='json', type=list, required=True) |
| 31 | +dextr_args.add_argument('padding', location='json', type=int, default=50) |
| 32 | +dextr_args.add_argument('threshold', location='json', type=int, default=80) |
| 33 | + |
| 34 | + |
| 35 | +@api.route('/dextr/<int:image_id>') |
| 36 | +class MaskRCNN(Resource): |
| 37 | + |
| 38 | + @login_required |
| 39 | + @api.expect(dextr_args) |
| 40 | + def post(self, image_id): |
| 41 | + """ COCO data test """ |
| 42 | + |
| 43 | + if not DEXTR_LOADED: |
| 44 | + return {"disabled": True, "message": "DEXTR is disabled"}, 400 |
| 45 | + |
| 46 | + args = dextr_args.parse_args() |
| 47 | + points = args.get('points') |
| 48 | + padding = args.get('padding') |
| 49 | + threshold = args.get('threshold') |
| 50 | + |
| 51 | + if len(points) != 4: |
| 52 | + return {"message": "Invalid points entered"}, 400 |
| 53 | + |
| 54 | + image_model = ImageModel.objects(id=image_id).first() |
| 55 | + if not image_model: |
| 56 | + return {"message": "Invalid image ID"}, 400 |
| 57 | + |
| 58 | + image = Image.open(image_model.path) |
| 59 | + result = dextr.predict_mask(image, points) |
| 60 | + |
| 61 | + return { "segmentaiton": Mask(result).polygons().segmentation } |
| 62 | + |
| 63 | + |
23 | 64 | @api.route('/maskrcnn') |
24 | 65 | class MaskRCNN(Resource): |
25 | 66 |
|
26 | 67 | @login_required |
27 | 68 | @api.expect(image_upload) |
28 | 69 | def post(self): |
29 | 70 | """ COCO data test """ |
30 | | - if not MODEL_LOADED: |
| 71 | + if not MASKRCNN_LOADED: |
31 | 72 | return {"disabled": True, "coco": {}} |
32 | 73 |
|
33 | 74 | args = image_upload.parse_args() |
|
0 commit comments