1- from abc import ABC , abstractmethod
2- from enum import IntEnum
3- import importlib
41import io
52import os
63import json
7- import csv
84import warnings
9- import pickle
105
116import numpy as np
12- import pandas as pd
137from cv2 import imencode
148
15- from .elements import *
16- from .io import load_dataframe
17-
18- __all__ = ["GCVFeatureType" , "GCVAgent" , "TesseractFeatureType" , "TesseractAgent" ]
9+ from .base import BaseOCRAgent , BaseOCRElementType
10+ from ..elements import Layout , TextBlock , Quadrilateral , TextBlock
1911
2012
2113def _cvt_GCV_vertices_to_points (vertices ):
2214 return np .array ([[vertex .x , vertex .y ] for vertex in vertices ])
2315
2416
class BaseOCRElementType(IntEnum):
    """Base integer enum for the element types reported by an OCR engine.

    Concrete enums subclass this and override :attr:`attr_name`.
    """

    @property
    @abstractmethod
    def attr_name(self):
        """Name associated with this element type.

        Placeholder meant to be overridden by concrete enums; this base
        implementation yields ``None``.
        """
        return None
30-
31-
class BaseOCRAgent(ABC):
    """Abstract base class for OCR agents that lazily import their backends.

    Subclasses declare ``DEPENDENCIES`` (pip package names) and ``MODULES``
    (how to import them) and implement :meth:`detect`. The modules listed in
    ``MODULES`` are imported and attached to the class every time an instance
    is created (see :meth:`__new__`).
    """

    @property
    @abstractmethod
    def DEPENDENCIES(self):
        """DEPENDENCIES lists all necessary dependencies for the class."""
        pass

    @property
    @abstractmethod
    def MODULES(self):
        """MODULES instructs how to import these necessary libraries.

        Note:
            Sometimes a python module has a different installation name and module
            name (e.g., `pip install tensorflow-gpu` when installing and
            `import tensorflow` when using). And sometimes we only need to import a
            submodule but not the whole module. MODULES is designed for this purpose.

        Returns:
            :obj: list(dict): A list of dict indicate how the model is imported.

            Example::

                [{
                    "import_name": "_vision",
                    "module_path": "google.cloud.vision"
                }]

            is equivalent to self._vision = importlib.import_module("google.cloud.vision")
        """
        pass

    @classmethod
    def _import_module(cls):
        """Import every module listed in ``MODULES`` and bind it on the class.

        Each entry's module is stored as the class attribute named by its
        ``"import_name"``.

        Raises:
            ModuleNotFoundError: If a listed module (or one of its parent
                packages) cannot be found; the message lists the packages in
                ``DEPENDENCIES`` to install.
        """
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly.
        import importlib.util

        for m in cls.MODULES:
            try:
                spec = importlib.util.find_spec(m["module_path"])
            except ModuleNotFoundError:
                # For dotted paths find_spec imports the parent package and
                # raises if it is missing; report that the same way as a
                # missing leaf module so the install hint is always shown.
                spec = None

            if not spec:
                raise ModuleNotFoundError(
                    f"\n "
                    f"\n Please install the following libraries to support the class {cls.__name__} :"
                    f"\n pip install {' '.join(cls.DEPENDENCIES)} "
                    f"\n "
                )

            setattr(cls, m["import_name"], importlib.import_module(m["module_path"]))

    def __new__(cls, *args, **kwargs):
        # Resolve the backend modules before the instance exists so that a
        # missing dependency fails fast with the install hint.
        cls._import_module()
        return super().__new__(cls)

    @abstractmethod
    def detect(self, image):
        """Run OCR on ``image``; must be implemented by subclasses."""
        pass
87-
88-
8917class GCVFeatureType (BaseOCRElementType ):
9018 """
9119 The element types from Google Cloud Vision API
@@ -341,166 +269,4 @@ def save_response(self, res, file_name):
341269
342270 with open (file_name , "w" ) as f :
343271 json_file = json .loads (res )
344- json .dump (json_file , f )
345-
346-
class TesseractFeatureType(BaseOCRElementType):
    """
    The element types for Tesseract Detection API
    """

    PAGE = 0
    BLOCK = 1
    PARA = 2
    LINE = 3
    WORD = 4

    @property
    def attr_name(self):
        """Column name associated with this element type (e.g. ``"line_num"``)."""
        # Indexed by the member's integer value: PAGE=0 ... WORD=4.
        column_by_level = (
            "page_num",
            "block_num",
            "par_num",
            "line_num",
            "word_num",
        )
        return column_by_level[self.value]

    @property
    def group_levels(self):
        """Column names from the page level down to, and including, this level."""
        hierarchy = ["page_num", "block_num", "par_num", "line_num", "word_num"]
        depth = self + 1
        return hierarchy[:depth]
373-
374-
class TesseractAgent(BaseOCRAgent):
    """
    A wrapper for `Tesseract <https://github.com/tesseract-ocr/tesseract>`_ Text
    Detection APIs based on `PyTesseract <https://github.com/madmaze/pytesseract>`_.
    """

    DEPENDENCIES = ["pytesseract"]
    MODULES = [{"import_name": "_pytesseract", "module_path": "pytesseract"}]

    def __init__(self, languages="eng", **kwargs):
        """Create a Tesseract OCR Agent.

        Args:
            languages (:obj:`list` or :obj:`str`, optional):
                You can specify the language code(s) of the documents to detect to improve
                accuracy. The supported language and their code can be found on
                `its github repo <https://github.com/tesseract-ocr/langdata>`_.
                It supports two formats: 1) you can pass in the languages code as a string
                of format like `"eng+fra"`, or 2) you can pack them as a list of strings
                `["eng", "fra"]`.
                Defaults to 'eng'.
            **kwargs:
                Extra keyword arguments forwarded unchanged to every
                ``image_to_string`` / ``image_to_data`` call.
        """
        self.lang = languages if isinstance(languages, str) else "+".join(languages)
        self.configs = kwargs

    @classmethod
    def with_tesseract_executable(cls, tesseract_cmd_path, **kwargs):
        """Create an agent that uses the Tesseract binary at ``tesseract_cmd_path``.

        Args:
            tesseract_cmd_path (:obj:`str`):
                Value assigned to ``pytesseract.pytesseract.tesseract_cmd``.
            **kwargs:
                Forwarded to the class constructor.
        """
        # `_pytesseract` is only attached to the class by `_import_module`,
        # which otherwise runs on first instantiation; load it explicitly so
        # this constructor also works before any instance has been created.
        cls._import_module()
        cls._pytesseract.pytesseract.tesseract_cmd = tesseract_cmd_path
        return cls(**kwargs)

    def _detect(self, img_content):
        """Run Tesseract on ``img_content`` and collect the raw outputs.

        Returns:
            dict: ``"text"`` holds the ``image_to_string`` result and
            ``"data"`` holds the ``image_to_data`` output parsed into a
            :obj:`pandas.DataFrame`.
        """
        res = {}
        res["text"] = self._pytesseract.image_to_string(
            img_content, lang=self.lang, **self.configs
        )
        _data = self._pytesseract.image_to_data(
            img_content, lang=self.lang, **self.configs
        )
        # image_to_data emits tab-separated values; quoting is disabled so
        # quote characters inside recognised text are kept verbatim.
        res["data"] = pd.read_csv(
            io.StringIO(_data), quoting=csv.QUOTE_NONE, encoding="utf-8", sep="\t"
        )
        return res

    def detect(
        self, image, return_response=False, return_only_text=True, agg_output_level=None
    ):
        """Send the input image for OCR.

        Args:
            image (:obj:`np.ndarray` or :obj:`str`):
                The input image array or the name of the image file
            return_response (:obj:`bool`, optional):
                Whether directly return all output (string and boxes
                info) from Tesseract.
                Defaults to `False`.
            return_only_text (:obj:`bool`, optional):
                Whether return only the texts in the OCR results.
                Takes precedence over `agg_output_level`.
                Defaults to `True`.
            agg_output_level (:obj:`~TesseractFeatureType`, optional):
                When set, aggregate the Tesseract output with respect to the
                specified aggregation level. Only honoured when
                `return_only_text` is `False`. Defaults to `None`.

        Returns:
            The raw response dict when `return_response` is set; otherwise
            the recognised text, unless `return_only_text` is `False` and
            `agg_output_level` is given, in which case the result of
            :meth:`gather_data` is returned.
        """

        res = self._detect(image)

        if return_response:
            return res

        if return_only_text:
            return res["text"]

        if agg_output_level is not None:
            return self.gather_data(res, agg_output_level)

        return res["text"]

    @staticmethod
    def gather_data(response, agg_level):
        """
        Gather the OCR'ed text, bounding boxes, and confidence
        in a given aggregation level.

        Args:
            response (:obj:`dict`):
                A response as produced by :meth:`_detect`; only its
                ``"data"`` DataFrame is used.
            agg_level (:obj:`~TesseractFeatureType`):
                The level whose ``group_levels`` columns are used to group
                the rows.

        Returns:
            The result of ``load_dataframe`` applied to a DataFrame with one
            row per group and the columns ``id``, ``x_1``, ``y_1``,
            ``score``, ``text``, ``x_2``, ``y_2`` and ``block_type``.
        """
        assert isinstance(
            agg_level, TesseractFeatureType
        ), f"Invalid agg_level {agg_level} "
        res = response["data"]

        def _summarise(gp):
            # The box of a group is the union of its rows' boxes: the right
            # and bottom edges come from the furthest `left + width` /
            # `top + height`, not from the largest single width or height.
            x_1 = gp["left"].min()
            y_1 = gp["top"].min()
            x_2 = (gp["left"] + gp["width"]).max()
            y_2 = (gp["top"] + gp["height"]).max()
            return pd.Series(
                [
                    x_1,
                    y_1,
                    x_2 - x_1,
                    y_2 - y_1,
                    gp["conf"].mean(),
                    gp["text"].str.cat(sep=" "),
                ]
            )

        df = (
            res[~res.text.isna()]  # keep only the rows that carry text
            .groupby(agg_level.group_levels)
            .apply(_summarise)
            .reset_index(drop=True)
            .reset_index()  # the fresh positional index becomes the "id" column
            .rename(
                columns={
                    0: "x_1",
                    1: "y_1",
                    2: "w",
                    3: "h",
                    4: "score",
                    5: "text",
                    "index": "id",
                }
            )
            .assign(
                x_2=lambda x: x.x_1 + x.w,
                y_2=lambda x: x.y_1 + x.h,
                block_type="rectangle",
            )
            .drop(columns=["w", "h"])
        )

        return load_dataframe(df)

    @staticmethod
    def load_response(filename):
        """Load a response previously written by :meth:`save_response`.

        Args:
            filename (:obj:`str`): Path of the pickle file to read.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that come from a trusted source.
        with open(filename, "rb") as fp:
            res = pickle.load(fp)
        return res

    @staticmethod
    def save_response(res, file_name):
        """Pickle ``res`` to ``file_name`` using the highest pickle protocol.

        Args:
            res: The response object to store.
            file_name (:obj:`str`): Path of the file to write.
        """

        with open(file_name, "wb") as fp:
            pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL)
272+ json .dump (json_file , f )
0 commit comments