22# Copyright IBM Corp. 2024 - 2024
33# SPDX-License-Identifier: MIT
44#
5+ import logging
56import os
67from collections .abc import Iterable
78from typing import Union
1011import torch
1112import torchvision .transforms as T
1213from PIL import Image
14+ from transformers import RTDetrForObjectDetection , RTDetrImageProcessor
1315
14- MODEL_CHECKPOINT_FN = "model.pt"
15- DEFAULT_NUM_THREADS = 4
16+ _log = logging .getLogger (__name__ )
1617
1718
1819class LayoutPredictor :
19- r """
20- Document layout prediction using torch
20+ """
21+ Document layout prediction using safe tensors
2122 """
2223
2324 def __init__ (
24- self , artifact_path : str , num_threads : int = None , use_cpu_only : bool = False
25+ self ,
26+ artifact_path : str ,
27+ device : str = "cpu" ,
28+ num_threads : int = 4 ,
2529 ):
26- r """
30+ """
2731 Provide the artifact path that contains the LayoutModel file
2832
29- The number of threads is decided, in the following order, by:
30- 1. The init method parameter `num_threads`, if it is set.
31- 2. The envvar "OMP_NUM_THREADS", if it is set.
32- 3. The default value DEFAULT_NUM_THREADS.
33-
34- The execution provided is decided, in the following order:
35- 1. If the init method parameter `cpu_only` is True or the envvar "USE_CPU_ONLY" is set,
36- it uses the "CPUExecutionProvider".
37- 3. Otherwise if the "CUDAExecutionProvider" is present, use:
38- ["CUDAExecutionProvider", "CPUExecutionProvider"]:
39-
4033 Parameters
4134 ----------
4235 artifact_path: Path of the directory that contains "model.safetensors", "config.json" and "preprocessor_config.json".
43- num_threads : (Optional) Number of threads to run the inference.
44- use_cpu_only : (Optional) If True, it forces CPU as the execution provider.
36+ device : (Optional) Device to run the inference on, e.g. "cpu". Defaults to "cpu".
37+ num_threads : (Optional) Number of threads used for inference when device is "cpu". Defaults to 4.
4538
4639 Raises
4740 ------
@@ -70,40 +63,51 @@ def __init__(
7063 }
7164
7265 # Blacklisted classes
73- self ._black_classes = set (["Form" , "Key-Value Region" ])
66+ self ._black_classes = set () # ["Form", "Key-Value Region"])
7467
7568 # Set basic params
76- self ._threshold = 0.6 # Score threshold
69+ self ._threshold = 0.3 # Score threshold
7770 self ._image_size = 640
7871 self ._size = np .asarray ([[self ._image_size , self ._image_size ]], dtype = np .int64 )
79- self ._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os .environ )
8072
81- # Model file
82- self ._torch_fn = os .path .join (artifact_path , MODEL_CHECKPOINT_FN )
83- if not os .path .isfile (self ._torch_fn ):
84- raise FileNotFoundError ("Missing torch file: {}" .format (self ._torch_fn ))
85-
86- # Get env vars
87- if num_threads is None :
88- num_threads = int (os .environ .get ("OMP_NUM_THREADS" , DEFAULT_NUM_THREADS ))
73+ # Set the device and, when running on CPU, the number of threads
74+ self ._device = torch .device (device )
8975 self ._num_threads = num_threads
76+ if device == "cpu" :
77+ torch .set_num_threads (self ._num_threads )
78+
79+ # Model file and configurations
80+ self ._st_fn = os .path .join (artifact_path , "model.safetensors" )
81+ if not os .path .isfile (self ._st_fn ):
82+ raise FileNotFoundError ("Missing safe tensors file: {}" .format (self ._st_fn ))
9083
91- self .model = torch .jit .load (self ._torch_fn )
84+ # Load model and move to device
85+ processor_config = os .path .join (artifact_path , "preprocessor_config.json" )
86+ model_config = os .path .join (artifact_path , "config.json" )
87+ self ._image_processor = RTDetrImageProcessor .from_json_file (processor_config )
88+ self ._model = RTDetrForObjectDetection .from_pretrained (
89+ artifact_path , config = model_config
90+ ).to (self ._device )
91+ self ._model .eval ()
92+
93+ _log .debug ("LayoutPredictor settings: {}" .format (self .info ()))
9294
9395 def info (self ) -> dict :
94- r """
96+ """
9597 Get information about the configuration of LayoutPredictor
9698 """
9799 info = {
98- "torch_file" : self ._torch_fn ,
99- "use_cpu_only" : self ._use_cpu_only ,
100+ "safe_tensors_file" : self ._st_fn ,
101+ "device" : self ._device .type ,
102+ "num_threads" : self ._num_threads ,
100103 "image_size" : self ._image_size ,
101104 "threshold" : self ._threshold ,
102105 }
103106 return info
104107
108+ @torch .inference_mode ()
105109 def predict (self , orig_img : Union [Image .Image , np .ndarray ]) -> Iterable [dict ]:
106- r """
110+ """
107111 Predict bounding boxes for a given image.
108112 The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
109113 [left, top, right, bottom]
@@ -128,40 +132,44 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
128132 else :
129133 raise TypeError ("Not supported input image format" )
130134
135+ resize = {"height" : self ._image_size , "width" : self ._image_size }
136+ inputs = self ._image_processor (
137+ images = page_img ,
138+ return_tensors = "pt" ,
139+ size = resize ,
140+ ).to (self ._device )
141+ outputs = self ._model (** inputs )
142+ results = self ._image_processor .post_process_object_detection (
143+ outputs ,
144+ target_sizes = torch .tensor ([page_img .size [::- 1 ]]),
145+ threshold = self ._threshold ,
146+ )
147+
131148 w , h = page_img .size
132- orig_size = torch .tensor ([w , h ])[None ]
133149
134- transforms = T .Compose (
135- [
136- T .Resize ((640 , 640 )),
137- T .ToTensor (),
138- ]
139- )
140- img = transforms (page_img )[None ]
141- # Predict
142- with torch .no_grad ():
143- labels , boxes , scores = self .model (img , orig_size )
150+ result = results [0 ]
151+ for score , label_id , box in zip (
152+ result ["scores" ], result ["labels" ], result ["boxes" ]
153+ ):
154+ score = float (score .item ())
155+
156+ label_id = int (label_id .item ()) + 1 # Advance the label_id
157+ label_str = self ._classes_map [label_id ]
144158
145- # Yield output
146- for label_idx , box , score in zip (labels [0 ], boxes [0 ], scores [0 ]):
147159 # Filter out blacklisted classes
148- label_idx = int (label_idx .item ())
149- score = float (score .item ())
150- label = self ._classes_map [label_idx + 1 ]
151- if label in self ._black_classes :
160+ if label_str in self ._black_classes :
152161 continue
153162
154- # Check against threshold
155- if score > self ._threshold :
156- l = min (w , max (0 , box [0 ]))
157- t = min (h , max (0 , box [1 ]))
158- r = min (w , max (0 , box [2 ]))
159- b = min (h , max (0 , box [3 ]))
160- yield {
161- "l" : l ,
162- "t" : t ,
163- "r" : r ,
164- "b" : b ,
165- "label" : label ,
166- "confidence" : score ,
167- }
163+ bbox_float = [float (b .item ()) for b in box ]
164+ l = min (w , max (0 , bbox_float [0 ]))
165+ t = min (h , max (0 , bbox_float [1 ]))
166+ r = min (w , max (0 , bbox_float [2 ]))
167+ b = min (h , max (0 , bbox_float [3 ]))
168+ yield {
169+ "l" : l ,
170+ "t" : t ,
171+ "r" : r ,
172+ "b" : b ,
173+ "label" : label_str ,
174+ "confidence" : score ,
175+ }
0 commit comments