 from typing import Union
 
 import numpy as np
-import onnxruntime as ort
+import torch
+import torchvision.transforms as T
 from PIL import Image
 
 MODEL_CHECKPOINT_FN = "model.pt"
 
 class LayoutPredictor:
     r"""
-    Document layout prediction using ONNX
+    Document layout prediction using torch
    """
 
     def __init__(
         self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
     ):
         r"""
-        Provide the artifact path that contains the LayoutModel ONNX file
+        Provide the artifact path that contains the LayoutModel file
 
         The number of threads is decided, in the following order, by:
         1. The init method parameter `num_threads`, if it is set.
@@ -38,13 +39,13 @@ def __init__(
 
         Parameters
         ----------
-        artifact_path: Path for the model ONNX file.
+        artifact_path: Path for the model torch file.
         num_threads: (Optional) Number of threads to run the inference.
         use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
 
         Raises
         ------
-        FileNotFoundError when the model's ONNX file is missing
+        FileNotFoundError when the model's torch file is missing
         """
         # Initialize classes map:
         self._classes_map = {
@@ -75,46 +76,27 @@ def __init__(
         self._threshold = 0.6  # Score threshold
         self._image_size = 640
         self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
+        self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
+
+        # Model file
+        self._torch_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
+        if not os.path.isfile(self._torch_fn):
+            raise FileNotFoundError("Missing torch file: {}".format(self._torch_fn))
 
         # Get env vars
-        self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
         if num_threads is None:
             num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
         self._num_threads = num_threads
 
-        # Decide the execution providers
-        if (
-            not self._use_cpu_only
-            and "CUDAExecutionProvider" in ort.get_available_providers()
-        ):
-            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
-        else:
-            providers = ["CPUExecutionProvider"]
-        self._providers = providers
-
-        # Model ONNX file
-        self._onnx_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
-        if not os.path.isfile(self._onnx_fn):
-            raise FileNotFoundError("Missing ONNX file: {}".format(self._onnx_fn))
-
-        # ONNX options
-        self._options = ort.SessionOptions()
-        self._options.intra_op_num_threads = self._num_threads
-        self.sess = ort.InferenceSession(
-            self._onnx_fn,
-            sess_options=self._options,
-            providers=self._providers,
-        )
+        self.model = torch.jit.load(self._torch_fn)
 
     def info(self) -> dict:
         r"""
         Get information about the configuration of LayoutPredictor
         """
         info = {
-            "onnx_file": self._onnx_fn,
-            "intra_op_num_threads": self._num_threads,
+            "torch_file": self._torch_fn,
             "use_cpu_only": self._use_cpu_only,
-            "providers": self._providers,
             "image_size": self._image_size,
             "threshold": self._threshold,
         }
@@ -147,33 +129,35 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
             raise TypeError("Not supported input image format")
 
         w, h = page_img.size
-        page_img = page_img.resize((self._image_size, self._image_size))
-        page_data = np.array(page_img, dtype=np.uint8) / np.float32(255.0)
-        page_data = np.expand_dims(np.transpose(page_data, axes=[2, 0, 1]), axis=0)
+        orig_size = torch.tensor([w, h])[None]
 
-        # Predict
-        labels, boxes, scores = self.sess.run(
-            output_names=None,
-            input_feed={
-                "images": page_data,
-                "orig_target_sizes": self._size,
-            },
+        transforms = T.Compose(
+            [
+                T.Resize((640, 640)),
+                T.ToTensor(),
+            ]
         )
+        img = transforms(page_img)[None]
+        # Predict
+        with torch.no_grad():
+            labels, boxes, scores = self.model(img, orig_size)
 
         # Yield output
         for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
             # Filter out blacklisted classes
-            label = self._classes_map[label_idx]
+            label_idx = int(label_idx.item())
+            score = float(score.item())
+            label = self._classes_map[label_idx + 1]
             if label in self._black_classes:
                 continue
 
             # Check against threshold
             if score > self._threshold:
                 yield {
-                    "l": box[0] / self._image_size * w,
-                    "t": box[1] / self._image_size * h,
-                    "r": box[2] / self._image_size * w,
-                    "b": box[3] / self._image_size * h,
+                    "l": box[0],
+                    "t": box[1],
+                    "r": box[2],
+                    "b": box[3],
                     "label": label,
                     "confidence": score,
                 }
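
For reference, a minimal usage sketch of the predictor after this change. The artifact directory and image path below are hypothetical placeholders; the directory is only assumed to contain the model.pt checkpoint that the constructor looks for.

    from PIL import Image

    # Hypothetical paths, for illustration only
    predictor = LayoutPredictor(artifact_path="/path/to/layout_artifacts", use_cpu_only=True)
    print(predictor.info())

    page = Image.open("page_0.png").convert("RGB")
    for pred in predictor.predict(page):
        # Each prediction is a dict with box coordinates, label and confidence
        print(pred["label"], pred["confidence"], pred["l"], pred["t"], pred["r"], pred["b"])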