diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..980524a --- /dev/null +++ b/.travis.yml @@ -0,0 +1,9 @@ +language: python +python: + - "3.6" +# command to install dependencies +install: + - python setup.py develop +# command to run tests +script: + - pytest --doctest-module diff --git a/README.md b/README.md index b1d036c..7c471bd 100644 --- a/README.md +++ b/README.md @@ -29,25 +29,23 @@ Instead of running these commands manually you can run the ./setup.sh script whi Or if you want to run the commands manually... -``` -# From the root of the repo go to ./fullEndToEndDemo -cd ./fullEndToEndDemo - -# Grab all the dependencies, this install is pretty huge -sudo apt-get update -sudo apt-get install git cmake g++ redis-server libboost-all-dev libopencv-dev python-opencv python-numpy python-scipy -y - -#Make it -cmake . -make - -# This step is optional. It removes a pointless annoying error opencv spits out -# About: https://stackoverflow.com/questions/12689304/ctypes-error-libdc1394-error-failed-to-initialize-libdc1394 -sudo ln /dev/null /dev/raw1394 - -# Then run either ./runDemo1.sh or ./runDemo2.sh to run the demo - - +```console +$ # From the root of the repo go to ./fullEndToEndDemo +$ cd ./fullEndToEndDemo +$ +$ # Grab all the dependencies, this install is pretty huge +$ sudo apt-get update +$ sudo apt-get install git cmake g++ redis-server libboost-all-dev libopencv-dev python-opencv python-numpy python-scipy -y +$ +$ #Make it +$ cmake . +$ make +$ +$ # This step is optional. 
It removes a pointless annoying error opencv spits out +$ # About: https://stackoverflow.com/questions/12689304/ctypes-error-libdc1394-error-failed-to-initialize-libdc1394 +$ sudo ln /dev/null /dev/raw1394 +$ +$ # Then run either ./runDemo1.sh or ./runDemo2.sh to run the demo ``` # Python setup @@ -58,19 +56,19 @@ This setup was tested on a newly deployed vm on Ubuntu 18.04 LTS, YMMV on differ To use python package, do the following: -``` -sudo apt-get update -sudo apt-get install python3-pip python3-opencv redis-server -y - -# On some systems this path is missing -# read more here: https://github.com/pypa/pip/issues/3813 -PATH="$PATH:~/.local/bin" - -#cd to project directory -pip3 install . +```console +$ sudo apt-get update +$ sudo apt-get install python3-pip python3-opencv redis-server -y +$ +$ # On some systems this path is missing +$ # read more here: https://github.com/pypa/pip/issues/3813 +$ PATH="$PATH:~/.local/bin" +$ +$ # cd to project directory +$ pip3 install . ``` -You also need install redis. +If you get a permission error, install it in a virtual environment or use the `--user` flag. # Demo 1 @@ -94,8 +92,8 @@ The demo takes 2 minutes (1 minute 38 seconds*) to run on a quad core VM *Thanks to [meowcoder](https://github.com/meowcoder) for the speed up! -``` -user@instance-1:~/transformationInvariantImageSearch/fullEndToEndDemo$ time ./runDemo1.sh +```console +$ time ./fullEndToEndDemo/runDemo1.sh Loading image: inputImages/cat1.png ... done Added 46725 image fragments to DB Loading image: inputImages/cat2.png ... 
done @@ -135,8 +133,8 @@ sys 0m6.592s python example ```console -$ time transformation-invariant-image-search insert fullEndToEndDemo/inputImages/cat* && \ - time transformation-invariant-image-search lookup fullEndToEndDemo/inputImages/cat_original.png +$ time transformation-invariant-image-search insert fullEndToEndDemo/inputImages/cat* && \ +$ time transformation-invariant-image-search lookup fullEndToEndDemo/inputImages/cat_original.png loading fullEndToEndDemo/inputImages/cat1.png 100%|██| 3/3 [00:07<00:00, 2.66s/it] @@ -219,8 +217,8 @@ Here the two images mona.jpg and van_gogh.jpg are inserted into the database and *Thanks to [meowcoder](https://github.com/meowcoder) for the speed up! -``` -user@instance-1:~/transformationInvariantImageSearch/fullEndToEndDemo$ time ./runDemo2.sh +```console +$ time ./fullEndToEndDemo/runDemo2.sh Loading image: ./inputImages/mona.jpg ... done Added 26991 image fragments to DB Loading image: ./inputImages/van_gogh.jpg ... done @@ -239,8 +237,9 @@ sys 0m18.224s python example ```console -$ time transformation-invariant-image-search insert ./fullEndToEndDemo/inputImages/mona.jpg ./fullEndToEndDemo/inputImages/van_gogh.jpg && \ - time transformation-invariant-image-search lookup ./fullEndToEndDemo/inputImages/monaComposite.jpg +$ time transformation-invariant-image-search insert \ +$ ./fullEndToEndDemo/inputImages/mona.jpg ./fullEndToEndDemo/inputImages/van_gogh.jpg && \ +$ time transformation-invariant-image-search lookup ./fullEndToEndDemo/inputImages/monaComposite.jpg loading ./fullEndToEndDemo/inputImages/mona.jpg 100%|███| 3/3 [00:03<00:00, 1.24s/it] diff --git a/setup.py b/setup.py index 3d56dba..2330d6d 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,6 @@ #!/usr/bin/env python from setuptools import setup, find_packages -""" -TODO -- copy or link `python` folder to `transformation_invariant_image_search` -""" def readme(): with open('README.md') as f: @@ -29,16 +25,29 @@ def readme(): zip_safe=False, 
python_requires='>=3.6', install_requires=[ + 'appdirs>=1.4.3', + 'Flask-Admin==1.5.3', + 'Flask-SQLAlchemy>=2.3.2', + 'Flask>=1.0.2', 'hiredis', 'numpy', + 'opencv-python>=4.0.0.21', + 'Pillow>=5.4.1', 'redis', 'scikit-learn', 'scipy', + 'SQLAlchemy-Utils>=0.33.11', 'tqdm>=4.29.1', ], + extras_require={ + 'dev': [ + 'docutils==0.14', + 'pytest==4.2.0', + ], + }, entry_points={ 'console_scripts': [ - 'transformation-invariant-image-search = transformation_invariant_image_search.main:main'] + 'transformation-invariant-image-search = transformation_invariant_image_search.main:cli'] }, classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..b7980d3 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,98 @@ +import json +import os +import shutil +import tempfile + +from click.testing import CliRunner +from flask import current_app +import click +import pytest + +from transformation_invariant_image_search import main + + +@pytest.fixture +def client(): + db_fd, config_db = tempfile.mkstemp() + image_fd = tempfile.mkdtemp() + db_uri = 'sqlite:///{}'.format(config_db) + app = main.create_app(db_uri=db_uri, image_dir=image_fd) + app.config['DATABASE'] = config_db + app.config['TESTING'] = True + client = app.test_client() + + yield client + + os.close(db_fd) + os.unlink(app.config['DATABASE']) + shutil.rmtree(image_fd) + + +def test_empty_db(client): + """Start with a blank database.""" + rv = client.get('/') + assert b'Home - Transformation Invariant Image Search' in rv.data + + +def test_checksum_get(client): + """test checksum with a blank database.""" + url = '/api/checksum' + rv = client.get(url) + assert rv.get_json() == [] + + +def test_checksum_post(client): + """Start with a blank database.""" + csm_value = '54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf' + url = '/api/checksum' + exp_dict = dict(value=csm_value, id=1, ext='png', trash=False) + rv = client.post(url, 
data=dict(value=csm_value, ext='png')) + assert rv.get_json() == exp_dict + rv = client.get(url) + assert rv.get_json() == [exp_dict] + + +def test_image_post(client): + url = '/api/image' + filename = 'fullEndToEndDemo/inputImages/cat_original.png' + csm_value = '54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf' + ext = 'png' + exp_dict = dict(id=1, value=csm_value, ext=ext, trash=False) + rv = client.post(url) + assert rv.get_json()['error'] + file_data = {'file': open(filename, 'rb')} + rv = client.post(url, data=file_data) + post_exp_dict = exp_dict.copy() + post_exp_dict['url'] = ['http://localhost/i/{}.{}'.format(csm_value, ext)] + assert rv.get_json() == post_exp_dict + image_dir = client.application.config.get('IMAGE_DIR') + exp_dst_file = os.path.join(image_dir, csm_value[:2], '{}.{}'.format(csm_value, ext)) + assert os.path.isfile(exp_dst_file) + rv = client.get(url) + assert rv.get_json() == [exp_dict] + + +def test_upload_api(client): + filename = 'fullEndToEndDemo/inputImages/cat1.png' + upload_url = '/api/image' + rv = client.post(upload_url, data={'file': open(filename, 'rb')}) + assert rv.get_json() == { + 'ext': 'png', 'id': 1, 'trash': False, + 'url': ['http://localhost/i/4aba099f752d609aad2ed4c28f972ae96d02ad2579d0dd3f16b1ac29a88caf6d.png'], + 'value': '4aba099f752d609aad2ed4c28f972ae96d02ad2579d0dd3f16b1ac29a88caf6d' + } + + +@pytest.mark.parametrize( + 'args,word', + [ + ('--help', 'Usage:'), + ('--version', 'Transformation Invariant Image Search') + ] +) +def test_cli(args, word): + runner = CliRunner() + result = runner.invoke(main.cli, [args]) + assert result.exit_code == 0 + if word is not None: + assert word in result.output diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..2125daf --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,13 @@ +def test_checksum(): + from transformation_invariant_image_search import models, main + app = main.create_app(db_uri='sqlite://') + csm_value = 
'54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf' + with app.app_context(): + models.DB.create_all() + m = models.Checksum(value=csm_value, ext='png') + models.DB.session.add(m) + models.DB.session.commit() + assert m.id == 1 + + res = models.DB.session.query(models.Checksum).filter_by(id=1).first() + assert res.value == csm_value diff --git a/transformation_invariant_image_search/keypoints.py b/transformation_invariant_image_search/keypoints.py index c46716a..6407b7e 100644 --- a/transformation_invariant_image_search/keypoints.py +++ b/transformation_invariant_image_search/keypoints.py @@ -56,6 +56,16 @@ def recolour(img, gauss_width=41): def compute_keypoints(img): + """Compute keypoints. + + >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png' + >>> img = cv2.imread(filename) + >>> res = compute_keypoints(img) + >>> len(res) == 50 + True + >>> sorted(res)[0] + (1.0, 26.0) + """ gauss_width = 21 img = recolour(img, gauss_width) b, _, _ = cv2.split(img) diff --git a/transformation_invariant_image_search/main.py b/transformation_invariant_image_search/main.py index 84d2097..e5b8b00 100644 --- a/transformation_invariant_image_search/main.py +++ b/transformation_invariant_image_search/main.py @@ -2,20 +2,72 @@ Usage: main.py lookup ... main.py insert ... """ -import sys -import multiprocessing from collections import Counter from os import cpu_count +import hashlib +import multiprocessing +import os +import platform +import shutil +import sys +import tempfile +import pathlib +import logging +from appdirs import user_data_dir +from flask.cli import FlaskGroup +from flask_admin import Admin, AdminIndexView +from flask_sqlalchemy import SQLAlchemy +from PIL import Image +from sqlalchemy_utils import database_exists, create_database +import click import cv2 -import redis +import flask import numpy as np +import redis +import tqdm +from flask import ( + current_app, + Flask, + jsonify, + request, + send_from_directory, + url_for, +) +from . 
import models from .keypoints import compute_keypoints -from .phash import triangles_from_keypoints, hash_triangles +from .models import ( + DB, + Checksum, + DATA_DIR, + DEFAULT_IMAGE_DIR +) +from .phash import ( + triangles_from_keypoints, + hash_triangles, + TRIANGLE_LOWER, + TRIANGLE_UPPER, +) + + +__version__ = '0.0.1' +DEFAULT_DB_URI = 'sqlite:///{}'.format(os.path.join(DATA_DIR, 'tiis.db')) def phash_triangles(img, triangles, batch_size=None): + """Get phash from triangles. + + >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png' + >>> img = cv2.imread(filename) + >>> keypoints = compute_keypoints(img) + >>> triangles = triangles_from_keypoints(keypoints) + >>> res = phash_triangles(img, triangles) + >>> len(res) + 34770 + >>> sorted(res)[0] + '0000563b8d730d07' + """ n = len(triangles) if batch_size is None: @@ -32,6 +84,104 @@ def phash_triangles(img, triangles, batch_size=None): return results +def get_duplicate( + session, filename=None, csm_m=None, img_dir=DEFAULT_IMAGE_DIR, + triangle_lower=TRIANGLE_LOWER, triangle_upper=TRIANGLE_UPPER): + """Get duplicate data. + >>> import tempfile + >>> from . import main + >>> filename1 = 'fullEndToEndDemo/inputImages/cat_original.png' + >>> filename2 = 'fullEndToEndDemo/inputImages/cat1.png' + >>> filename3 = 'fullEndToEndDemo/inputImages/mona.jpg' + >>> image_fd = tempfile.mkdtemp() + >>> app = main.create_app(db_uri='sqlite://') + >>> app.app_context().push() + >>> DB.create_all() + >>> triangle_lower = 100 + >>> triangle_upper = 300 + >>> # Get duplicate from image filename + >>> get_duplicate( + ... DB.session, filename1, + ... triangle_lower=triangle_lower, triangle_upper=triangle_upper) + [] + >>> # Get duplicate from checksum model + >>> m = DB.session.query(Checksum).filter_by(id=1).first() + >>> get_duplicate( + ... DB.session, csm_m=m, + ... triangle_lower=triangle_lower, triangle_upper=triangle_upper) + [] + >>> len(m.phashes) > 0 + True + >>> get_duplicate( + ... DB.session, filename2, + ... 
triangle_lower=triangle_lower, triangle_upper=triangle_upper) + [] + >>> get_duplicate(DB.session, csm_m=m, triangle_lower=triangle_lower) + [] + >>> get_duplicate( + ... DB.session, filename3, + ... triangle_lower=triangle_lower, triangle_upper=triangle_upper) + [] + """ + if csm_m is not None and filename is not None: + raise ValueError('Only either checksum model or filename is required') + if csm_m: + m, created = csm_m, False + else: + m, created = models.get_or_create_checksum_model( + session, filename, img_dir=img_dir) + res = [] + if created: + session.add(m) + session.commit() + hash_list = None + if not m.phashes: + if filename: + img = cv2.imread(filename) + else: + img = cv2.imread( + models.get_image_path(m.value, m.ext, img_dir)) + keypoints = compute_keypoints(img) + triangles = triangles_from_keypoints( + keypoints, lower=triangle_lower, upper=triangle_upper) + hash_list = set(phash_triangles(img, triangles)) + hash_list_ms = [] + logging.debug('getting existing phash on db') + for hash_group in tqdm.tqdm( + list(models.grouper(hash_list, 999))): + hash_list_ms.extend( + session.query(models.Phash) + .filter(models.Phash.value.in_(filter(lambda x: x, hash_group))) + .all()) + hash_list_ms_values = [x.value for x in hash_list_ms] + not_in_db_hash_list = \ + [x for x in hash_list if x not in hash_list_ms_values] + if not_in_db_hash_list: + logging.debug('insert phash') + for hash_group in tqdm.tqdm( + list(models.grouper(not_in_db_hash_list, 1000))): + session.add_all( + [models.Phash(value=i) for i in hash_group if i]) + session.flush + session.commit() + logging.debug('getting rest of phash') + for hash_group in tqdm.tqdm( + list(models.grouper(not_in_db_hash_list, 999))): + hash_list_ms.extend( + session.query(models.Phash) \ + .filter(models.Phash.value.in_(filter(lambda x: x, hash_group))) \ + .all()) + m.phashes.extend(hash_list_ms) + session.add(m) + session.commit() + if session.query(Checksum).count() > 1: + res = 
session.query(Checksum).join(models.Phash.checksums) \ + .distinct(Checksum.id) \ + .filter(models.Phash.checksums.any(Checksum.value == m.value)) \ + .filter(Checksum.id != m.id).all() + return res + + def pipeline(r, data, chunk_size): npartitions = len(data) // chunk_size pipe = r.pipeline() @@ -40,7 +190,7 @@ def pipeline(r, data, chunk_size): yield pipe, chunk -def insert(chunks, filename): +def insert_(chunks, filename): n = 0 for pipe, keys in chunks: @@ -52,7 +202,7 @@ def insert(chunks, filename): print(f'added {n} fragments for {filename}') -def lookup(chunks, filename): +def lookup_(chunks, filename): count = Counter() for pipe, keys in chunks: @@ -68,13 +218,178 @@ def lookup(chunks, filename): print(f'{num:<10d} {key.decode("utf-8")}') -def main(): - if len(sys.argv) < 3: - print(__doc__) - exit(1) +def create_app(script_info=None, db_uri=DEFAULT_DB_URI, image_dir=DEFAULT_IMAGE_DIR): + """create app.""" + app = Flask(__name__) + app.config['SQLALCHEMY_DATABASE_URI'] = db_uri # NOQA + app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False + app.config['SECRET_KEY'] = os.getenv('TIIS_SECRET_KEY') or os.urandom(24) + app.config['WTF_CSRF_ENABLED'] = False + app.config['IMAGE_DIR'] = image_dir + DB.init_app(app) + if not database_exists(db_uri): + create_database(db_uri) + with app.app_context(): + DB.create_all() + + @app.shell_context_processor + def shell_context(): + return {'app': app, 'db': DB, 'models': models, 'session': DB.session} + + # Migrate(app, DB) + # flask-admin + app_admin = Admin( + app, name='Transformation Invariant Image Search', template_mode='bootstrap3', + index_view=AdminIndexView( + name='Home', + template='tiis/index.html', + url='/' + ) + ) + # index_view=views.HomeView(name='Home', template='transformation_invariant_image_search/index.html', url='/')) # NOQA + app.add_url_rule('/api/checksum', 'checksum_list', checksum_list, methods=['GET', 'POST']) + app.add_url_rule('/api/checksum//duplicate', 'checksum_duplicate', 
checksum_duplicate) + app.add_url_rule('/api/image', 'image_list', image_list, methods=['GET', 'POST']) + app.add_url_rule('/i/', 'image_url', image_url) + return app + + +def image_url(filename): + img_dir = current_app.config.get('IMAGE_DIR') + return send_from_directory( + img_dir, os.path.join(filename[:2], filename)) + + +def checksum_duplicate(cid): + m = DB.session.query(Checksum).filter_by(id=cid).first_or_404() + res = get_duplicate( + DB.session, csm_m=m, triangle_lower=100, triangle_upper=300) + dict_list = [x.to_dict() for x in res] + list(map( + lambda x: x.update({'url': url_for( + '.image_url', _external=True, + filename='{}.{}'.format(x['value'], x['ext']))}), + dict_list + )) + return jsonify(dict_list) + + +def checksum_list(): + if request.method == 'POST': + csm_value = request.form.get('value', None) + if not csm_value: + return jsonify({}) + m = DB.session.query(Checksum).filter_by(value=csm_value).first() + if m is None: + kwargs = dict(value=csm_value) + ext = request.form.get('ext', None) + if ext is not None: + kwargs['ext'] = ext + trash = request.form.get('trash', None) + if trash is not None: + kwargs['trash'] = trash + m = Checksum(**kwargs) + DB.session.add(m) + DB.session.commit() + return jsonify(m.to_dict()) + ms = DB.session.query(Checksum).paginate(1, 10).items + return jsonify([x.to_dict() for x in ms]) + + +def image_list(): + if request.method == 'POST': + # check if the post request has the file part + if 'file' not in request.files: + return jsonify({'error': 'No file part'}) + file_ = request.files['file'] + # if user does not select file, browser also + # submit an empty part without filename + if file_.filename == '': + return jsonify({'error': 'No selected file'}) + with tempfile.NamedTemporaryFile(delete=False) as f: + file_.save(f.name) + pil_img = Image.open(f.name) + sha256 = hashlib.sha256() + with open(f.name, 'rb') as f: + for block in iter(lambda: f.read(128*1024), b''): + sha256.update(block) + sha256_csum = 
sha256.hexdigest() + image_dir = current_app.config.get('IMAGE_DIR', None) + if image_dir is None: + return jsonify({'error': 'Image dir is not specified'}) + ext = pil_img.format.lower() + dst_file = os.path.join( + image_dir, sha256_csum[:2], '{}.{}'.format(sha256_csum, ext)) + m = models.get_or_create(DB.session, Checksum, value=sha256_csum)[0] + m.ext = ext + m.trash = False + pathlib.Path(os.path.dirname(dst_file)).mkdir(parents=True, exist_ok=True) + shutil.move(f.name, dst_file) + DB.session.add(m) + DB.session.commit() + dict_res = m.to_dict() + dict_res['url'] = url_for( + '.image_url', _external=True, + filename='{}.{}'.format(m.value, m.ext)), + return jsonify(dict_res) + ms = DB.session.query(Checksum).filter_by(trash=False).paginate(1, 10).items + return jsonify([x.to_dict() for x in ms]) + + +def get_custom_version(ctx, param, value): + """Output modified --version flag result. + + Modified from: + https://github.com/pallets/flask/blob/master/flask/cli.py + """ + if not value or ctx.resilient_parsing: + return + import werkzeug + message = ( + '%(app_name)s %(app_version)s\n' + 'Python %(python)s\n' + 'Flask %(flask)s\n' + 'Werkzeug %(werkzeug)s' + ) + click.echo(message % { + 'app_name': 'Transformation Invariant Image Search', + 'app_version': __version__, + 'python': platform.python_version(), + 'flask': flask.__version__, + 'werkzeug': werkzeug.__version__, + }, color=ctx.color) + ctx.exit() + + +class CustomFlaskGroup(FlaskGroup): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.params[0].help = 'Show the program version.' 
+ self.params[0].callback = get_custom_version + + +@click.group(cls=CustomFlaskGroup, create_app=create_app) +def cli(): + """CLI interface for Transformation Invariant Image Search.""" + pass + + +@cli.command() +@click.argument('image', nargs=-1) +def insert(image): + """Insert image's triangle phashes to database.""" + main('insert', image) + + +@cli.command() +@click.argument('image', nargs=-1) +def lookup(image): + """Lookup image's triangle phashes in database.""" + main('lookup', image) + - command, *filenames = sys.argv[1:] - command = insert if command == 'insert' else lookup +def main(command, filenames): + command = insert_ if command == 'insert' else lookup_ r = redis.StrictRedis(host='localhost', port=6379, db=0) try: @@ -97,4 +412,4 @@ def main(): if __name__ == '__main__': - main() + cli() diff --git a/transformation_invariant_image_search/models.py b/transformation_invariant_image_search/models.py new file mode 100644 index 0000000..0fe58d9 --- /dev/null +++ b/transformation_invariant_image_search/models.py @@ -0,0 +1,124 @@ +from itertools import zip_longest +import hashlib +import os +import pathlib +import shutil + +from appdirs import user_data_dir +from flask import Flask +from flask_sqlalchemy import SQLAlchemy +from PIL import Image +import cv2 +import tqdm + +from .keypoints import compute_keypoints +from .phash import triangles_from_keypoints, hash_triangles, TRIANGLE_LOWER, TRIANGLE_UPPER + + +DB = SQLAlchemy() +DATA_DIR = user_data_dir('transformation_invariant_image_search', 'Tom Murphy') +pathlib.Path(DATA_DIR).mkdir(parents=True, exist_ok=True) +DEFAULT_IMAGE_DIR = os.path.join(DATA_DIR, 'image') + +checksum_phashes = DB.Table( + 'checksum_phashes', + DB.Column('checksum_id', DB.Integer, DB.ForeignKey('checksum.id'), primary_key=True), + DB.Column('phash_id', DB.Integer, DB.ForeignKey('phash.id'), primary_key=True)) + + +class Base(DB.Model): + __abstract__ = True + id = DB.Column(DB.Integer, primary_key=True) + + +class 
Checksum(Base): + value = DB.Column(DB.String(), unique=True, nullable=False) + trash = DB.Column(DB.Boolean(), default=False) + ext = DB.Column(DB.String(), nullable=False) + phashes = DB.relationship('Phash', secondary=checksum_phashes, lazy='subquery', + backref=DB.backref('checksums', lazy=True)) + + + def __repr__(self): + templ = '' + return templ.format(self, self.value[:7]) + + def to_dict(self): + keys = ['value', 'trash', 'ext', 'id'] + return {k: getattr(self, k) for k in keys} + + +class Phash(Base): + value = DB.Column(DB.String(), unique=True, nullable=False) + + def __repr__(self): + templ = '' + return templ.format(self) + + +def get_or_create(session, model, **kwargs): + """Creates an object or returns the object if exists.""" + instance = session.query(model).filter_by(**kwargs).first() + created = False + if not instance: + instance = model(**kwargs) + session.add(instance) + created = True + return instance, created + + +def get_image_path(checksum_value, ext, img_dir=DEFAULT_IMAGE_DIR): + """Get image path. + >>> import tempfile + >>> image_fd = tempfile.mkdtemp() + >>> get_image_path( + ... '54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf', + ... 'png', image_fd) + '.../54/54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf.png' + """ + return os.path.join(img_dir, checksum_value[:2], '{}.{}'.format(checksum_value, ext)) + + +def get_or_create_checksum_model(session, filename, img_dir=DEFAULT_IMAGE_DIR): + """Get or create checksum model. + >>> import tempfile + >>> from . 
import main + >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png' + >>> image_fd = tempfile.mkdtemp() + >>> app = main.create_app(db_uri='sqlite://') + >>> app.app_context().push() + >>> DB.create_all() + >>> _ = Checksum.query.delete() + >>> get_or_create_checksum_model(DB.session, filename, image_fd) + (, True) + >>> res = get_or_create_checksum_model(DB.session, filename, image_fd) + >>> res + (, False) + >>> m = res[0] + >>> os.path.isfile(get_image_path(m.value, m.ext, image_fd)) + True + """ + pil_img = Image.open(filename) + sha256 = hashlib.sha256() + with open(filename, 'rb') as f: + for block in iter(lambda: f.read(128*1024), b''): + sha256.update(block) + sha256_csum = sha256.hexdigest() + m, created = get_or_create(session, Checksum, value=sha256_csum) + m.ext = pil_img.format.lower() + m.trash = False + dst_file = get_image_path(m.value, m.ext, img_dir) + pathlib.Path(os.path.dirname(dst_file)).mkdir(parents=True, exist_ok=True) + shutil.copy(filename, dst_file) + return m, created + + +def grouper(iterable, n, fillvalue=None): + """Collect data into fixed-length chunks or blocks. + taken from: + https://docs.python.org/3/library/itertools.html#itertools.zip_longest + >>> list(grouper('ABCDEFG', 3, 'x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + """ + args = [iter(iterable)] * n + return zip_longest(*args, fillvalue=fillvalue) diff --git a/transformation_invariant_image_search/phash.py b/transformation_invariant_image_search/phash.py index fd63bb6..811854e 100644 --- a/transformation_invariant_image_search/phash.py +++ b/transformation_invariant_image_search/phash.py @@ -6,6 +6,8 @@ HEX_STRINGS = np.array([f'{x:02x}' for x in range(256)]) BIN_POWERS = 2 ** np.arange(8) +TRIANGLE_LOWER = 50 +TRIANGLE_UPPER = 400 def phash(image, hash_size=8, highfreq_factor=4): @@ -26,6 +28,19 @@ def hash_to_hex(a): def hash_triangles(img, triangles): + """Get hash triangles. 
+ >>> from .keypoints import compute_keypoints + >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png' + >>> img = cv2.imread(filename) + >>> keypoints = compute_keypoints(img) + >>> triangles = triangles_from_keypoints(keypoints) + >>> res = hash_triangles(img, triangles) + >>> len(res), sorted(res)[0] + (34770, '0000563b8d730d07') + >>> res = hash_triangles(img, [triangles[0]]) + >>> len(res), sorted(res) + (3, ['709a3765dd04b0f3', 'b8dd5c4e7a352cea', 'de433036010bb391']) + """ n = len(triangles) triangles = np.asarray(triangles) @@ -51,7 +66,7 @@ def hash_triangles(img, triangles): # rotate triangles 3 times, one for each edge of the triangle rotations = (0, 1, 2), (1, 2, 0), (2, 0, 1) - for i, rotation in enumerate(tqdm.tqdm(rotations)): + for i, rotation in enumerate(rotations): p = triangles[:, rotation, :] p0 = p[:, 0] @@ -71,7 +86,8 @@ def hash_triangles(img, triangles): transform = target_points @ input_points_inverse @ transpose_m transform = transform[:, :2, :] - for k in tqdm.tqdm(range(n)): + range_list = tqdm.tqdm(range(n)) if len(range(n)) > 1 else range(n) + for k in range_list: image = cv2.warpAffine(img, transform[k], size) # calculate dct for perceptual hash @@ -89,7 +105,25 @@ def hash_triangles(img, triangles): return hash_to_hex(hashes) -def triangles_from_keypoints(keypoints, lower=50, upper=400): +def triangles_from_keypoints(keypoints, lower=TRIANGLE_LOWER, upper=TRIANGLE_UPPER): + """Get Triangles from keypoints. 
+ + >>> from .keypoints import compute_keypoints + >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png' + >>> img = cv2.imread(filename) + >>> keypoints = compute_keypoints(img) + >>> res = triangles_from_keypoints(keypoints) + >>> len(res) + 11590 + >>> print(list(map(lambda x: x.tolist(), res[0]))) + [[162.0, 203.0], [261.0, 76.0], [131.0, 63.0]] + >>> res2 = triangles_from_keypoints(keypoints, lower=10) + >>> len(res2) + 14238 + >>> res3 = triangles_from_keypoints(keypoints, upper=100) + >>> len(res3) + 315 + """ keypoints = np.asarray(keypoints, dtype=float) tree = BallTree(keypoints, leaf_size=10) diff --git a/transformation_invariant_image_search/templates/tiis/index.html b/transformation_invariant_image_search/templates/tiis/index.html new file mode 100644 index 0000000..e2eac28 --- /dev/null +++ b/transformation_invariant_image_search/templates/tiis/index.html @@ -0,0 +1,79 @@ +{% extends 'admin/master.html' %} + +{% block head %} +{{ super() }} + +{% endblock %} + +{% block body %} +
+
+
+ + +
+
+ +
+ +
+
+
+
+
+ +
+ +{% endblock %}