Skip to content

Commit b6ba0c7

Browse files
fix: Remove left-over code which is not needed for prediction (#35)
Reduce the calls of opencv to the minimum that is actually used. Signed-off-by: Nikos Livathinos <[email protected]>
1 parent beea804 commit b6ba0c7

File tree

10 files changed

+8
-2823
lines changed

10 files changed

+8
-2823
lines changed

docling_ibm_models/tableformer/common.py

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -48,32 +48,6 @@ def validate_config(config):
4848
return True
4949

5050

51-
def parse_arguments():
52-
r"""
53-
Parse the input arguments
54-
A ValueError exception will be thrown in case the config file is invalid
55-
"""
56-
parser = argparse.ArgumentParser(description="Train the TableModel")
57-
parser.add_argument(
58-
"-c", "--config", required=True, default=None, help="configuration file (JSON)"
59-
)
60-
args = parser.parse_args()
61-
config_filename = args.config
62-
63-
assert os.path.isfile(config_filename), "FAILURE: Config file not found."
64-
return read_config(config_filename)
65-
66-
67-
def read_config(config_filename):
68-
with open(config_filename, "r") as fd:
69-
config = json.load(fd)
70-
71-
# Validate the config file
72-
validate_config(config)
73-
74-
return config
75-
76-
7751
def safe_get_parameter(input_dict, index_path, default=None, required=False):
7852
r"""
7953
Safe get parameter from a nested dictionary.
@@ -130,71 +104,3 @@ def get_prepared_data_filename(prepared_data_part, dataset_name):
130104
if "<POSTFIX>" in template:
131105
template = template.replace("<POSTFIX>", dataset_name)
132106
return template
133-
134-
135-
def create_dataset_and_model(config, purpose, fixed_padding=False):
136-
r"""
137-
Gets a model from configuration
138-
139-
Parameters
140-
---------
141-
config : Dictionary
142-
The configuration of the model
143-
purpose : string
144-
One of "train", "eval", "predict"
145-
fixed_padding : bool
146-
Parameter passed to the constructor of the DataLoader
147-
148-
Returns
149-
-------
150-
In case a Model cannot be initialized return None, None, None. Otherwise:
151-
152-
device : selected device
153-
dataset : Instance of the DataLoader
154-
model : Instance of the model
155-
"""
156-
from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset
157-
158-
model_type = config["model"]["type"]
159-
model = None
160-
161-
# Get env vars:
162-
use_cpu_only = os.environ.get("USE_CPU_ONLY", False)
163-
use_cuda_only = not use_cpu_only
164-
165-
# Use the cpu for the evaluation
166-
device = "cpu" # Default, run on CPU
167-
num_gpus = torch.cuda.device_count() # Check if GPU is available
168-
if use_cuda_only:
169-
device = "cuda:0" if num_gpus > 0 else "cpu" # Run on first available GPU
170-
else:
171-
device = "cpu"
172-
173-
# Create the DataLoader
174-
# loader = DataLoader(config, purpose, fixed_padding=fixed_padding)
175-
dataset = TFDataset(config, purpose, fixed_padding=fixed_padding)
176-
dataset.set_device(device)
177-
dataset_val = None
178-
if config["train"]["validation"] and purpose == "train":
179-
dataset_val = TFDataset(config, "val", fixed_padding=fixed_padding)
180-
dataset_val.set_device(device)
181-
if model_type == "TableModel04_rs":
182-
from docling_ibm_models.tableformer.models.table04_rs.tablemodel04_rs import ( # noqa: F401
183-
TableModel04_rs,
184-
)
185-
# Find the model class and create an instance of it
186-
for candidate in BaseModel.__subclasses__():
187-
if candidate.__name__ == model_type:
188-
init_data = dataset.get_init_data()
189-
model = candidate(config, init_data, purpose, device)
190-
191-
if model is None:
192-
logger.warn("Not found model: " + str(model_type))
193-
return None, None, None
194-
195-
logger.info("Found model: " + str(model_type))
196-
197-
if purpose == s.PREDICT_PURPOSE:
198-
return device, dataset, model
199-
else:
200-
return device, dataset, dataset_val, model

0 commit comments

Comments
 (0)