11from typing import Union
22from PIL import Image
33import numpy as np
4+ import warnings
45
5- from .catalog import PathManager , LABEL_MAP_CATALOG
6+ from .catalog import MODEL_CATALOG , PathManager , LABEL_MAP_CATALOG
67from ..base_layoutmodel import BaseLayoutModel
78from ...elements import Rectangle , TextBlock , Layout
89from ...file_utils import is_torch_cuda_available , is_detectron2_available
@@ -30,9 +31,9 @@ class Detectron2LayoutModel(BaseLayoutModel):
3031 word labels (strings). If the config is from one of the supported
3132 datasets, Layout Parser will automatically initialize the label_map.
3233 Defaults to `None`.
33- enforce_cpu (:obj:`bool `, optional):
34- When set to `True`, it will enforce using cpu even if it is on a CUDA
35- available device.
34+ device (:obj:`str `, optional):
35+ Whether to use cuda or cpu devices. If not set, LayoutParser will
36+ automatically determine the device to initialize the models on .
3637 extra_config (:obj:`list`, optional):
3738 Extra configuration passed to the Detectron2 model
3839 configuration. The argument will be used in the `merge_from_list
@@ -49,70 +50,55 @@ class Detectron2LayoutModel(BaseLayoutModel):
4950
5051 DEPENDENCIES = ["detectron2" ]
5152 DETECTOR_NAME = "detectron2"
53+ MODEL_CATALOG = MODEL_CATALOG
5254
5355 def __init__ (
5456 self ,
5557 config_path ,
5658 model_path = None ,
5759 label_map = None ,
5860 extra_config = None ,
59- enforce_cpu = False ,
61+ enforce_cpu = None ,
62+ device = None ,
6063 ):
6164
65+ if enforce_cpu is not None :
66+ warnings .warn (
67+ "Setting enforce_cpu is deprecated. Please set `device` instead." ,
68+ DeprecationWarning ,
69+ )
70+
6271 if extra_config is None :
6372 extra_config = []
6473
65- if config_path .startswith ("lp://" ) and label_map is None :
66- dataset_name = config_path .lstrip ("lp://" ).split ("/" )[0 ]
67- label_map = LABEL_MAP_CATALOG [dataset_name ]
68-
69- if enforce_cpu :
70- extra_config .extend (["MODEL.DEVICE" , "cpu" ])
74+ config_path , model_path = self .config_parser (
75+ config_path , model_path , allow_empty_path = True
76+ )
77+ config_path = PathManager .get_local_path (config_path )
7178
7279 cfg = detectron2 .config .get_cfg ()
73- config_path = self ._reconstruct_path_with_detector_name (config_path )
74- config_path = PathManager .get_local_path (config_path )
7580 cfg .merge_from_file (config_path )
7681 cfg .merge_from_list (extra_config )
7782
7883 if model_path is not None :
79- model_path = self ._reconstruct_path_with_detector_name (model_path )
84+ model_path = PathManager .get_local_path (model_path )
85+ # Because it will be forwarded to the detectron2 paths
8086 cfg .MODEL .WEIGHTS = model_path
81-
82- if not enforce_cpu :
83- cfg .MODEL .DEVICE = "cuda" if is_torch_cuda_available () else "cpu"
87+
88+ if is_torch_cuda_available ():
89+ if device is None :
90+ device = "cuda"
91+ else :
92+ device = "cpu"
93+ cfg .MODEL .DEVICE = device
8494
8595 self .cfg = cfg
8696
8797 self .label_map = label_map
8898 self ._create_model ()
8999
90- def _reconstruct_path_with_detector_name (self , path : str ) -> str :
91- """This function will add the detector name (detectron2) into the
92- lp model config path to get the "canonical" model name.
93-
94- For example, for a given config_path `lp://HJDataset/faster_rcnn_R_50_FPN_3x/config`,
95- it will transform it into `lp://detectron2/HJDataset/faster_rcnn_R_50_FPN_3x/config`.
96- However, if the config_path already contains the detector name, we won't change it.
97-
98- This function is a general step to support multiple backends in the layout-parser
99- library.
100-
101- Args:
102- path (str): The given input path that might or might not contain the detector name.
103-
104- Returns:
105- str: a modified path that contains the detector name.
106- """
107- if path .startswith ("lp://" ): # TODO: Move "lp://" to a constant
108- model_name = path [len ("lp://" ) :]
109- model_name_segments = model_name .split ("/" )
110- if (
111- len (model_name_segments ) == 3
112- and self .DETECTOR_NAME not in model_name_segments
113- ):
114- return "lp://" + self .DETECTOR_NAME + "/" + path [len ("lp://" ) :]
115- return path
100+ def _create_model (self ):
101+ self .model = detectron2 .engine .DefaultPredictor (self .cfg )
116102
117103 def gather_output (self , outputs ):
118104
@@ -136,9 +122,6 @@ def gather_output(self, outputs):
136122
137123 return layout
138124
139- def _create_model (self ):
140- self .model = detectron2 .engine .DefaultPredictor (self .cfg )
141-
142125 def detect (self , image ):
143126 """Detect the layout of a given image.
144127
0 commit comments