diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..980524a
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,9 @@
+language: python
+python:
+ - "3.6"
+# command to install dependencies
+install:
+ - python setup.py develop
+# command to run tests
+script:
+ - pytest --doctest-modules
diff --git a/README.md b/README.md
index b1d036c..7c471bd 100644
--- a/README.md
+++ b/README.md
@@ -29,25 +29,23 @@ Instead of running these commands manually you can run the ./setup.sh script whi
Or if you want to run the commands manually...
-```
-# From the root of the repo go to ./fullEndToEndDemo
-cd ./fullEndToEndDemo
-
-# Grab all the dependencies, this install is pretty huge
-sudo apt-get update
-sudo apt-get install git cmake g++ redis-server libboost-all-dev libopencv-dev python-opencv python-numpy python-scipy -y
-
-#Make it
-cmake .
-make
-
-# This step is optional. It removes a pointless annoying error opencv spits out
-# About: https://stackoverflow.com/questions/12689304/ctypes-error-libdc1394-error-failed-to-initialize-libdc1394
-sudo ln /dev/null /dev/raw1394
-
-# Then run either ./runDemo1.sh or ./runDemo2.sh to run the demo
-
-
+```console
+$ # From the root of the repo go to ./fullEndToEndDemo
+$ cd ./fullEndToEndDemo
+$
+$ # Grab all the dependencies, this install is pretty huge
+$ sudo apt-get update
+$ sudo apt-get install git cmake g++ redis-server libboost-all-dev libopencv-dev python-opencv python-numpy python-scipy -y
+$
+$ #Make it
+$ cmake .
+$ make
+$
+$ # This step is optional. It removes a pointless annoying error opencv spits out
+$ # About: https://stackoverflow.com/questions/12689304/ctypes-error-libdc1394-error-failed-to-initialize-libdc1394
+$ sudo ln /dev/null /dev/raw1394
+$
+$ # Then run either ./runDemo1.sh or ./runDemo2.sh to run the demo
```
# Python setup
@@ -58,19 +56,19 @@ This setup was tested on a newly deployed vm on Ubuntu 18.04 LTS, YMMV on differ
To use python package, do the following:
-```
-sudo apt-get update
-sudo apt-get install python3-pip python3-opencv redis-server -y
-
-# On some systems this path is missing
-# read more here: https://github.com/pypa/pip/issues/3813
-PATH="$PATH:~/.local/bin"
-
-#cd to project directory
-pip3 install .
+```console
+$ sudo apt-get update
+$ sudo apt-get install python3-pip python3-opencv redis-server -y
+$
+$ # On some systems this path is missing
+$ # read more here: https://github.com/pypa/pip/issues/3813
+$ PATH="$PATH:~/.local/bin"
+$
+$ # cd to project directory
+$ pip3 install .
```
-You also need install redis.
+If you get a permission error, install it in a virtual environment or use the `--user` flag.
# Demo 1
@@ -94,8 +92,8 @@ The demo takes 2 minutes (1 minute 38 seconds*) to run on a quad core VM
*Thanks to [meowcoder](https://github.com/meowcoder) for the speed up!
-```
-user@instance-1:~/transformationInvariantImageSearch/fullEndToEndDemo$ time ./runDemo1.sh
+```console
+$ time ./fullEndToEndDemo/runDemo1.sh
Loading image: inputImages/cat1.png ... done
Added 46725 image fragments to DB
Loading image: inputImages/cat2.png ... done
@@ -135,8 +133,8 @@ sys 0m6.592s
python example
```console
-$ time transformation-invariant-image-search insert fullEndToEndDemo/inputImages/cat* && \
- time transformation-invariant-image-search lookup fullEndToEndDemo/inputImages/cat_original.png
+$ time transformation-invariant-image-search insert fullEndToEndDemo/inputImages/cat* && \
+  time transformation-invariant-image-search lookup fullEndToEndDemo/inputImages/cat_original.png
loading fullEndToEndDemo/inputImages/cat1.png
100%|██| 3/3 [00:07<00:00, 2.66s/it]
@@ -219,8 +217,8 @@ Here the two images mona.jpg and van_gogh.jpg are inserted into the database and
*Thanks to [meowcoder](https://github.com/meowcoder) for the speed up!
-```
-user@instance-1:~/transformationInvariantImageSearch/fullEndToEndDemo$ time ./runDemo2.sh
+```console
+$ time ./fullEndToEndDemo/runDemo2.sh
Loading image: ./inputImages/mona.jpg ... done
Added 26991 image fragments to DB
Loading image: ./inputImages/van_gogh.jpg ... done
@@ -239,8 +237,9 @@ sys 0m18.224s
python example
```console
-$ time transformation-invariant-image-search insert ./fullEndToEndDemo/inputImages/mona.jpg ./fullEndToEndDemo/inputImages/van_gogh.jpg && \
- time transformation-invariant-image-search lookup ./fullEndToEndDemo/inputImages/monaComposite.jpg
+$ time transformation-invariant-image-search insert \
+  ./fullEndToEndDemo/inputImages/mona.jpg ./fullEndToEndDemo/inputImages/van_gogh.jpg && \
+  time transformation-invariant-image-search lookup ./fullEndToEndDemo/inputImages/monaComposite.jpg
loading ./fullEndToEndDemo/inputImages/mona.jpg
100%|███| 3/3 [00:03<00:00, 1.24s/it]
diff --git a/setup.py b/setup.py
index 3d56dba..2330d6d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,10 +1,6 @@
#!/usr/bin/env python
from setuptools import setup, find_packages
-"""
-TODO
-- copy or link `python` folder to `transformation_invariant_image_search`
-"""
def readme():
with open('README.md') as f:
@@ -29,16 +25,29 @@ def readme():
zip_safe=False,
python_requires='>=3.6',
install_requires=[
+ 'appdirs>=1.4.3',
+ 'Flask-Admin==1.5.3',
+ 'Flask-SQLAlchemy>=2.3.2',
+ 'Flask>=1.0.2',
'hiredis',
'numpy',
+ 'opencv-python>=4.0.0.21',
+ 'Pillow>=5.4.1',
'redis',
'scikit-learn',
'scipy',
+ 'SQLAlchemy-Utils>=0.33.11',
'tqdm>=4.29.1',
],
+ extras_require={
+ 'dev': [
+ 'docutils==0.14',
+ 'pytest==4.2.0',
+ ],
+ },
entry_points={
'console_scripts': [
- 'transformation-invariant-image-search = transformation_invariant_image_search.main:main']
+ 'transformation-invariant-image-search = transformation_invariant_image_search.main:cli']
},
classifiers=[
'Development Status :: 3 - Alpha',
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..b7980d3
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,98 @@
+import json
+import os
+import shutil
+import tempfile
+
+from click.testing import CliRunner
+from flask import current_app
+import click
+import pytest
+
+from transformation_invariant_image_search import main
+
+
+@pytest.fixture
+def client():
+ db_fd, config_db = tempfile.mkstemp()
+ image_fd = tempfile.mkdtemp()
+ db_uri = 'sqlite:///{}'.format(config_db)
+ app = main.create_app(db_uri=db_uri, image_dir=image_fd)
+ app.config['DATABASE'] = config_db
+ app.config['TESTING'] = True
+ client = app.test_client()
+
+ yield client
+
+ os.close(db_fd)
+ os.unlink(app.config['DATABASE'])
+ shutil.rmtree(image_fd)
+
+
+def test_empty_db(client):
+ """Start with a blank database."""
+ rv = client.get('/')
+ assert b'Home - Transformation Invariant Image Search' in rv.data
+
+
+def test_checksum_get(client):
+ """test checksum with a blank database."""
+ url = '/api/checksum'
+ rv = client.get(url)
+ assert rv.get_json() == []
+
+
+def test_checksum_post(client):
+ """Start with a blank database."""
+ csm_value = '54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf'
+ url = '/api/checksum'
+ exp_dict = dict(value=csm_value, id=1, ext='png', trash=False)
+ rv = client.post(url, data=dict(value=csm_value, ext='png'))
+ assert rv.get_json() == exp_dict
+ rv = client.get(url)
+ assert rv.get_json() == [exp_dict]
+
+
+def test_image_post(client):
+ url = '/api/image'
+ filename = 'fullEndToEndDemo/inputImages/cat_original.png'
+ csm_value = '54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf'
+ ext = 'png'
+ exp_dict = dict(id=1, value=csm_value, ext=ext, trash=False)
+ rv = client.post(url)
+ assert rv.get_json()['error']
+ file_data = {'file': open(filename, 'rb')}
+ rv = client.post(url, data=file_data)
+ post_exp_dict = exp_dict.copy()
+ post_exp_dict['url'] = ['http://localhost/i/{}.{}'.format(csm_value, ext)]
+ assert rv.get_json() == post_exp_dict
+ image_dir = client.application.config.get('IMAGE_DIR')
+ exp_dst_file = os.path.join(image_dir, csm_value[:2], '{}.{}'.format(csm_value, ext))
+ assert os.path.isfile(exp_dst_file)
+ rv = client.get(url)
+ assert rv.get_json() == [exp_dict]
+
+
+def test_upload_api(client):
+ filename = 'fullEndToEndDemo/inputImages/cat1.png'
+ upload_url = '/api/image'
+ rv = client.post(upload_url, data={'file': open(filename, 'rb')})
+ assert rv.get_json() == {
+ 'ext': 'png', 'id': 1, 'trash': False,
+ 'url': ['http://localhost/i/4aba099f752d609aad2ed4c28f972ae96d02ad2579d0dd3f16b1ac29a88caf6d.png'],
+ 'value': '4aba099f752d609aad2ed4c28f972ae96d02ad2579d0dd3f16b1ac29a88caf6d'
+ }
+
+
+@pytest.mark.parametrize(
+ 'args,word',
+ [
+ ('--help', 'Usage:'),
+ ('--version', 'Transformation Invariant Image Search')
+ ]
+)
+def test_cli(args, word):
+ runner = CliRunner()
+ result = runner.invoke(main.cli, [args])
+ assert result.exit_code == 0
+ if word is not None:
+ assert word in result.output
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..2125daf
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,13 @@
+def test_checksum():
+ from transformation_invariant_image_search import models, main
+ app = main.create_app(db_uri='sqlite://')
+ csm_value = '54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf'
+ with app.app_context():
+ models.DB.create_all()
+ m = models.Checksum(value=csm_value, ext='png')
+ models.DB.session.add(m)
+ models.DB.session.commit()
+ assert m.id == 1
+
+ res = models.DB.session.query(models.Checksum).filter_by(id=1).first()
+ assert res.value == csm_value
diff --git a/transformation_invariant_image_search/keypoints.py b/transformation_invariant_image_search/keypoints.py
index c46716a..6407b7e 100644
--- a/transformation_invariant_image_search/keypoints.py
+++ b/transformation_invariant_image_search/keypoints.py
@@ -56,6 +56,16 @@ def recolour(img, gauss_width=41):
def compute_keypoints(img):
+ """Compute keypoints.
+
+ >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png'
+ >>> img = cv2.imread(filename)
+ >>> res = compute_keypoints(img)
+ >>> len(res) == 50
+ True
+ >>> sorted(res)[0]
+ (1.0, 26.0)
+ """
gauss_width = 21
img = recolour(img, gauss_width)
b, _, _ = cv2.split(img)
diff --git a/transformation_invariant_image_search/main.py b/transformation_invariant_image_search/main.py
index 84d2097..e5b8b00 100644
--- a/transformation_invariant_image_search/main.py
+++ b/transformation_invariant_image_search/main.py
@@ -2,20 +2,72 @@
Usage: main.py lookup ...
main.py insert ...
"""
-import sys
-import multiprocessing
from collections import Counter
from os import cpu_count
+import hashlib
+import multiprocessing
+import os
+import platform
+import shutil
+import sys
+import tempfile
+import pathlib
+import logging
+from appdirs import user_data_dir
+from flask.cli import FlaskGroup
+from flask_admin import Admin, AdminIndexView
+from flask_sqlalchemy import SQLAlchemy
+from PIL import Image
+from sqlalchemy_utils import database_exists, create_database
+import click
import cv2
-import redis
+import flask
import numpy as np
+import redis
+import tqdm
+from flask import (
+ current_app,
+ Flask,
+ jsonify,
+ request,
+ send_from_directory,
+ url_for,
+)
+from . import models
from .keypoints import compute_keypoints
-from .phash import triangles_from_keypoints, hash_triangles
+from .models import (
+ DB,
+ Checksum,
+ DATA_DIR,
+ DEFAULT_IMAGE_DIR
+)
+from .phash import (
+ triangles_from_keypoints,
+ hash_triangles,
+ TRIANGLE_LOWER,
+ TRIANGLE_UPPER,
+)
+
+
+__version__ = '0.0.1'
+DEFAULT_DB_URI = 'sqlite:///{}'.format(os.path.join(DATA_DIR, 'tiis.db'))
def phash_triangles(img, triangles, batch_size=None):
+ """Get phash from triangles.
+
+ >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png'
+ >>> img = cv2.imread(filename)
+ >>> keypoints = compute_keypoints(img)
+ >>> triangles = triangles_from_keypoints(keypoints)
+ >>> res = phash_triangles(img, triangles)
+ >>> len(res)
+ 34770
+ >>> sorted(res)[0]
+ '0000563b8d730d07'
+ """
n = len(triangles)
if batch_size is None:
@@ -32,6 +84,104 @@ def phash_triangles(img, triangles, batch_size=None):
return results
+def get_duplicate(
+ session, filename=None, csm_m=None, img_dir=DEFAULT_IMAGE_DIR,
+ triangle_lower=TRIANGLE_LOWER, triangle_upper=TRIANGLE_UPPER):
+ """Get duplicate data.
+ >>> import tempfile
+ >>> from . import main
+ >>> filename1 = 'fullEndToEndDemo/inputImages/cat_original.png'
+ >>> filename2 = 'fullEndToEndDemo/inputImages/cat1.png'
+ >>> filename3 = 'fullEndToEndDemo/inputImages/mona.jpg'
+ >>> image_fd = tempfile.mkdtemp()
+ >>> app = main.create_app(db_uri='sqlite://')
+ >>> app.app_context().push()
+ >>> DB.create_all()
+ >>> triangle_lower = 100
+ >>> triangle_upper = 300
+ >>> # Get duplicate from image filename
+ >>> get_duplicate(
+ ... DB.session, filename1,
+ ... triangle_lower=triangle_lower, triangle_upper=triangle_upper)
+ []
+ >>> # Get duplicate from checksum model
+ >>> m = DB.session.query(Checksum).filter_by(id=1).first()
+ >>> get_duplicate(
+ ... DB.session, csm_m=m,
+ ... triangle_lower=triangle_lower, triangle_upper=triangle_upper)
+ []
+ >>> len(m.phashes) > 0
+ True
+ >>> get_duplicate(
+ ... DB.session, filename2,
+ ... triangle_lower=triangle_lower, triangle_upper=triangle_upper)
+ []
+ >>> get_duplicate(DB.session, csm_m=m, triangle_lower=triangle_lower)
+ []
+ >>> get_duplicate(
+ ... DB.session, filename3,
+ ... triangle_lower=triangle_lower, triangle_upper=triangle_upper)
+ []
+ """
+ if csm_m is not None and filename is not None:
+ raise ValueError('Only either checksum model or filename is required')
+ if csm_m:
+ m, created = csm_m, False
+ else:
+ m, created = models.get_or_create_checksum_model(
+ session, filename, img_dir=img_dir)
+ res = []
+ if created:
+ session.add(m)
+ session.commit()
+ hash_list = None
+ if not m.phashes:
+ if filename:
+ img = cv2.imread(filename)
+ else:
+ img = cv2.imread(
+ models.get_image_path(m.value, m.ext, img_dir))
+ keypoints = compute_keypoints(img)
+ triangles = triangles_from_keypoints(
+ keypoints, lower=triangle_lower, upper=triangle_upper)
+ hash_list = set(phash_triangles(img, triangles))
+ hash_list_ms = []
+ logging.debug('getting existing phash on db')
+ for hash_group in tqdm.tqdm(
+ list(models.grouper(hash_list, 999))):
+ hash_list_ms.extend(
+ session.query(models.Phash)
+ .filter(models.Phash.value.in_(filter(lambda x: x, hash_group)))
+ .all())
+ hash_list_ms_values = [x.value for x in hash_list_ms]
+ not_in_db_hash_list = \
+ [x for x in hash_list if x not in hash_list_ms_values]
+ if not_in_db_hash_list:
+ logging.debug('insert phash')
+ for hash_group in tqdm.tqdm(
+ list(models.grouper(not_in_db_hash_list, 1000))):
+ session.add_all(
+ [models.Phash(value=i) for i in hash_group if i])
+ session.flush()
+ session.commit()
+ logging.debug('getting rest of phash')
+ for hash_group in tqdm.tqdm(
+ list(models.grouper(not_in_db_hash_list, 999))):
+ hash_list_ms.extend(
+ session.query(models.Phash) \
+ .filter(models.Phash.value.in_(filter(lambda x: x, hash_group))) \
+ .all())
+ m.phashes.extend(hash_list_ms)
+ session.add(m)
+ session.commit()
+ if session.query(Checksum).count() > 1:
+ res = session.query(Checksum).join(models.Phash.checksums) \
+ .distinct(Checksum.id) \
+ .filter(models.Phash.checksums.any(Checksum.value == m.value)) \
+ .filter(Checksum.id != m.id).all()
+ return res
+
+
def pipeline(r, data, chunk_size):
npartitions = len(data) // chunk_size
pipe = r.pipeline()
@@ -40,7 +190,7 @@ def pipeline(r, data, chunk_size):
yield pipe, chunk
-def insert(chunks, filename):
+def insert_(chunks, filename):
n = 0
for pipe, keys in chunks:
@@ -52,7 +202,7 @@ def insert(chunks, filename):
print(f'added {n} fragments for {filename}')
-def lookup(chunks, filename):
+def lookup_(chunks, filename):
count = Counter()
for pipe, keys in chunks:
@@ -68,13 +218,178 @@ def lookup(chunks, filename):
print(f'{num:<10d} {key.decode("utf-8")}')
-def main():
- if len(sys.argv) < 3:
- print(__doc__)
- exit(1)
+def create_app(script_info=None, db_uri=DEFAULT_DB_URI, image_dir=DEFAULT_IMAGE_DIR):
+ """create app."""
+ app = Flask(__name__)
+ app.config['SQLALCHEMY_DATABASE_URI'] = db_uri # NOQA
+ app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
+ app.config['SECRET_KEY'] = os.getenv('TIIS_SECRET_KEY') or os.urandom(24)
+ app.config['WTF_CSRF_ENABLED'] = False
+ app.config['IMAGE_DIR'] = image_dir
+ DB.init_app(app)
+ if not database_exists(db_uri):
+ create_database(db_uri)
+ with app.app_context():
+ DB.create_all()
+
+ @app.shell_context_processor
+ def shell_context():
+ return {'app': app, 'db': DB, 'models': models, 'session': DB.session}
+
+ # Migrate(app, DB)
+ # flask-admin
+ app_admin = Admin(
+ app, name='Transformation Invariant Image Search', template_mode='bootstrap3',
+ index_view=AdminIndexView(
+ name='Home',
+ template='tiis/index.html',
+ url='/'
+ )
+ )
+ # index_view=views.HomeView(name='Home', template='transformation_invariant_image_search/index.html', url='/')) # NOQA
+ app.add_url_rule('/api/checksum', 'checksum_list', checksum_list, methods=['GET', 'POST'])
+ app.add_url_rule('/api/checksum/<int:cid>/duplicate', 'checksum_duplicate', checksum_duplicate)
+ app.add_url_rule('/api/image', 'image_list', image_list, methods=['GET', 'POST'])
+ app.add_url_rule('/i/<filename>', 'image_url', image_url)
+ return app
+
+
+def image_url(filename):
+ img_dir = current_app.config.get('IMAGE_DIR')
+ return send_from_directory(
+ img_dir, os.path.join(filename[:2], filename))
+
+
+def checksum_duplicate(cid):
+ m = DB.session.query(Checksum).filter_by(id=cid).first_or_404()
+ res = get_duplicate(
+ DB.session, csm_m=m, triangle_lower=100, triangle_upper=300)
+ dict_list = [x.to_dict() for x in res]
+ list(map(
+ lambda x: x.update({'url': url_for(
+ '.image_url', _external=True,
+ filename='{}.{}'.format(x['value'], x['ext']))}),
+ dict_list
+ ))
+ return jsonify(dict_list)
+
+
+def checksum_list():
+ if request.method == 'POST':
+ csm_value = request.form.get('value', None)
+ if not csm_value:
+ return jsonify({})
+ m = DB.session.query(Checksum).filter_by(value=csm_value).first()
+ if m is None:
+ kwargs = dict(value=csm_value)
+ ext = request.form.get('ext', None)
+ if ext is not None:
+ kwargs['ext'] = ext
+ trash = request.form.get('trash', None)
+ if trash is not None:
+ kwargs['trash'] = trash
+ m = Checksum(**kwargs)
+ DB.session.add(m)
+ DB.session.commit()
+ return jsonify(m.to_dict())
+ ms = DB.session.query(Checksum).paginate(1, 10).items
+ return jsonify([x.to_dict() for x in ms])
+
+
+def image_list():
+ if request.method == 'POST':
+ # check if the post request has the file part
+ if 'file' not in request.files:
+ return jsonify({'error': 'No file part'})
+ file_ = request.files['file']
+ # if user does not select file, browser also
+ # submit an empty part without filename
+ if file_.filename == '':
+ return jsonify({'error': 'No selected file'})
+ with tempfile.NamedTemporaryFile(delete=False) as f:
+ file_.save(f.name)
+ pil_img = Image.open(f.name)
+ sha256 = hashlib.sha256()
+ with open(f.name, 'rb') as f:
+ for block in iter(lambda: f.read(128*1024), b''):
+ sha256.update(block)
+ sha256_csum = sha256.hexdigest()
+ image_dir = current_app.config.get('IMAGE_DIR', None)
+ if image_dir is None:
+ return jsonify({'error': 'Image dir is not specified'})
+ ext = pil_img.format.lower()
+ dst_file = os.path.join(
+ image_dir, sha256_csum[:2], '{}.{}'.format(sha256_csum, ext))
+ m = models.get_or_create(DB.session, Checksum, value=sha256_csum)[0]
+ m.ext = ext
+ m.trash = False
+ pathlib.Path(os.path.dirname(dst_file)).mkdir(parents=True, exist_ok=True)
+ shutil.move(f.name, dst_file)
+ DB.session.add(m)
+ DB.session.commit()
+ dict_res = m.to_dict()
+ dict_res['url'] = url_for(
+ '.image_url', _external=True,
+ filename='{}.{}'.format(m.value, m.ext)),
+ return jsonify(dict_res)
+ ms = DB.session.query(Checksum).filter_by(trash=False).paginate(1, 10).items
+ return jsonify([x.to_dict() for x in ms])
+
+
+def get_custom_version(ctx, param, value):
+ """Output modified --version flag result.
+
+ Modified from:
+ https://github.com/pallets/flask/blob/master/flask/cli.py
+ """
+ if not value or ctx.resilient_parsing:
+ return
+ import werkzeug
+ message = (
+ '%(app_name)s %(app_version)s\n'
+ 'Python %(python)s\n'
+ 'Flask %(flask)s\n'
+ 'Werkzeug %(werkzeug)s'
+ )
+ click.echo(message % {
+ 'app_name': 'Transformation Invariant Image Search',
+ 'app_version': __version__,
+ 'python': platform.python_version(),
+ 'flask': flask.__version__,
+ 'werkzeug': werkzeug.__version__,
+ }, color=ctx.color)
+ ctx.exit()
+
+
+class CustomFlaskGroup(FlaskGroup):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.params[0].help = 'Show the program version.'
+ self.params[0].callback = get_custom_version
+
+
+@click.group(cls=CustomFlaskGroup, create_app=create_app)
+def cli():
+ """CLI interface for Transformation Invariant Image Search."""
+ pass
+
+
+@cli.command()
+@click.argument('image', nargs=-1)
+def insert(image):
+ """Insert image's triangle phashes to database."""
+ main('insert', image)
+
+
+@cli.command()
+@click.argument('image', nargs=-1)
+def lookup(image):
+ """Lookup image's triangle phashes in database."""
+ main('lookup', image)
+
- command, *filenames = sys.argv[1:]
- command = insert if command == 'insert' else lookup
+def main(command, filenames):
+ command = insert_ if command == 'insert' else lookup_
r = redis.StrictRedis(host='localhost', port=6379, db=0)
try:
@@ -97,4 +412,4 @@ def main():
if __name__ == '__main__':
- main()
+ cli()
diff --git a/transformation_invariant_image_search/models.py b/transformation_invariant_image_search/models.py
new file mode 100644
index 0000000..0fe58d9
--- /dev/null
+++ b/transformation_invariant_image_search/models.py
@@ -0,0 +1,124 @@
+from itertools import zip_longest
+import hashlib
+import os
+import pathlib
+import shutil
+
+from appdirs import user_data_dir
+from flask import Flask
+from flask_sqlalchemy import SQLAlchemy
+from PIL import Image
+import cv2
+import tqdm
+
+from .keypoints import compute_keypoints
+from .phash import triangles_from_keypoints, hash_triangles, TRIANGLE_LOWER, TRIANGLE_UPPER
+
+
+DB = SQLAlchemy()
+DATA_DIR = user_data_dir('transformation_invariant_image_search', 'Tom Murphy')
+pathlib.Path(DATA_DIR).mkdir(parents=True, exist_ok=True)
+DEFAULT_IMAGE_DIR = os.path.join(DATA_DIR, 'image')
+
+checksum_phashes = DB.Table(
+ 'checksum_phashes',
+ DB.Column('checksum_id', DB.Integer, DB.ForeignKey('checksum.id'), primary_key=True),
+ DB.Column('phash_id', DB.Integer, DB.ForeignKey('phash.id'), primary_key=True))
+
+
+class Base(DB.Model):
+ __abstract__ = True
+ id = DB.Column(DB.Integer, primary_key=True)
+
+
+class Checksum(Base):
+ value = DB.Column(DB.String(), unique=True, nullable=False)
+ trash = DB.Column(DB.Boolean(), default=False)
+ ext = DB.Column(DB.String(), nullable=False)
+ phashes = DB.relationship('Phash', secondary=checksum_phashes, lazy='subquery',
+ backref=DB.backref('checksums', lazy=True))
+
+
+ def __repr__(self):
+ templ = ''
+ return templ.format(self, self.value[:7])
+
+ def to_dict(self):
+ keys = ['value', 'trash', 'ext', 'id']
+ return {k: getattr(self, k) for k in keys}
+
+
+class Phash(Base):
+ value = DB.Column(DB.String(), unique=True, nullable=False)
+
+ def __repr__(self):
+ templ = ''
+ return templ.format(self)
+
+
+def get_or_create(session, model, **kwargs):
+ """Creates an object or returns the object if exists."""
+ instance = session.query(model).filter_by(**kwargs).first()
+ created = False
+ if not instance:
+ instance = model(**kwargs)
+ session.add(instance)
+ created = True
+ return instance, created
+
+
+def get_image_path(checksum_value, ext, img_dir=DEFAULT_IMAGE_DIR):
+ """Get image path.
+ >>> import tempfile
+ >>> image_fd = tempfile.mkdtemp()
+ >>> get_image_path(
+ ... '54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf',
+ ... 'png', image_fd)
+ '.../54/54abb6e1eb59cccf61ae356aff7e491894c5ca606dfda4240d86743424c65faf.png'
+ """
+ return os.path.join(img_dir, checksum_value[:2], '{}.{}'.format(checksum_value, ext))
+
+
+def get_or_create_checksum_model(session, filename, img_dir=DEFAULT_IMAGE_DIR):
+ """Get or create checksum model.
+ >>> import tempfile
+ >>> from . import main
+ >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png'
+ >>> image_fd = tempfile.mkdtemp()
+ >>> app = main.create_app(db_uri='sqlite://')
+ >>> app.app_context().push()
+ >>> DB.create_all()
+ >>> _ = Checksum.query.delete()
+ >>> get_or_create_checksum_model(DB.session, filename, image_fd)
+ (, True)
+ >>> res = get_or_create_checksum_model(DB.session, filename, image_fd)
+ >>> res
+ (, False)
+ >>> m = res[0]
+ >>> os.path.isfile(get_image_path(m.value, m.ext, image_fd))
+ True
+ """
+ pil_img = Image.open(filename)
+ sha256 = hashlib.sha256()
+ with open(filename, 'rb') as f:
+ for block in iter(lambda: f.read(128*1024), b''):
+ sha256.update(block)
+ sha256_csum = sha256.hexdigest()
+ m, created = get_or_create(session, Checksum, value=sha256_csum)
+ m.ext = pil_img.format.lower()
+ m.trash = False
+ dst_file = get_image_path(m.value, m.ext, img_dir)
+ pathlib.Path(os.path.dirname(dst_file)).mkdir(parents=True, exist_ok=True)
+ shutil.copy(filename, dst_file)
+ return m, created
+
+
+def grouper(iterable, n, fillvalue=None):
+ """Collect data into fixed-length chunks or blocks.
+ taken from:
+ https://docs.python.org/3/library/itertools.html#itertools.zip_longest
+ >>> list(grouper('ABCDEFG', 3, 'x'))
+ [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
+ """
+ args = [iter(iterable)] * n
+ return zip_longest(*args, fillvalue=fillvalue)
diff --git a/transformation_invariant_image_search/phash.py b/transformation_invariant_image_search/phash.py
index fd63bb6..811854e 100644
--- a/transformation_invariant_image_search/phash.py
+++ b/transformation_invariant_image_search/phash.py
@@ -6,6 +6,8 @@
HEX_STRINGS = np.array([f'{x:02x}' for x in range(256)])
BIN_POWERS = 2 ** np.arange(8)
+TRIANGLE_LOWER = 50
+TRIANGLE_UPPER = 400
def phash(image, hash_size=8, highfreq_factor=4):
@@ -26,6 +28,19 @@ def hash_to_hex(a):
def hash_triangles(img, triangles):
+ """Get hash triangles.
+ >>> from .keypoints import compute_keypoints
+ >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png'
+ >>> img = cv2.imread(filename)
+ >>> keypoints = compute_keypoints(img)
+ >>> triangles = triangles_from_keypoints(keypoints)
+ >>> res = hash_triangles(img, triangles)
+ >>> len(res), sorted(res)[0]
+ (34770, '0000563b8d730d07')
+ >>> res = hash_triangles(img, [triangles[0]])
+ >>> len(res), sorted(res)
+ (3, ['709a3765dd04b0f3', 'b8dd5c4e7a352cea', 'de433036010bb391'])
+ """
n = len(triangles)
triangles = np.asarray(triangles)
@@ -51,7 +66,7 @@ def hash_triangles(img, triangles):
# rotate triangles 3 times, one for each edge of the triangle
rotations = (0, 1, 2), (1, 2, 0), (2, 0, 1)
- for i, rotation in enumerate(tqdm.tqdm(rotations)):
+ for i, rotation in enumerate(rotations):
p = triangles[:, rotation, :]
p0 = p[:, 0]
@@ -71,7 +86,8 @@ def hash_triangles(img, triangles):
transform = target_points @ input_points_inverse @ transpose_m
transform = transform[:, :2, :]
- for k in tqdm.tqdm(range(n)):
+ range_list = tqdm.tqdm(range(n)) if len(range(n)) > 1 else range(n)
+ for k in range_list:
image = cv2.warpAffine(img, transform[k], size)
# calculate dct for perceptual hash
@@ -89,7 +105,25 @@ def hash_triangles(img, triangles):
return hash_to_hex(hashes)
-def triangles_from_keypoints(keypoints, lower=50, upper=400):
+def triangles_from_keypoints(keypoints, lower=TRIANGLE_LOWER, upper=TRIANGLE_UPPER):
+ """Get Triangles from keypoints.
+
+ >>> from .keypoints import compute_keypoints
+ >>> filename = 'fullEndToEndDemo/inputImages/cat_original.png'
+ >>> img = cv2.imread(filename)
+ >>> keypoints = compute_keypoints(img)
+ >>> res = triangles_from_keypoints(keypoints)
+ >>> len(res)
+ 11590
+ >>> print(list(map(lambda x: x.tolist(), res[0])))
+ [[162.0, 203.0], [261.0, 76.0], [131.0, 63.0]]
+ >>> res2 = triangles_from_keypoints(keypoints, lower=10)
+ >>> len(res2)
+ 14238
+ >>> res3 = triangles_from_keypoints(keypoints, upper=100)
+ >>> len(res3)
+ 315
+ """
keypoints = np.asarray(keypoints, dtype=float)
tree = BallTree(keypoints, leaf_size=10)
diff --git a/transformation_invariant_image_search/templates/tiis/index.html b/transformation_invariant_image_search/templates/tiis/index.html
new file mode 100644
index 0000000..e2eac28
--- /dev/null
+++ b/transformation_invariant_image_search/templates/tiis/index.html
@@ -0,0 +1,79 @@
+{% extends 'admin/master.html' %}
+
+{% block head %}
+{{ super() }}
+
+{% endblock %}
+
+{% block body %}
+
+
+
+
![]()
+
+
+{% endblock %}