@@ -48,32 +48,6 @@ def validate_config(config):
4848 return True
4949
5050
51- def parse_arguments ():
52- r"""
53- Parse the input arguments
54- A ValueError exception will be thrown in case the config file is invalid
55- """
56- parser = argparse .ArgumentParser (description = "Train the TableModel" )
57- parser .add_argument (
58- "-c" , "--config" , required = True , default = None , help = "configuration file (JSON)"
59- )
60- args = parser .parse_args ()
61- config_filename = args .config
62-
63- assert os .path .isfile (config_filename ), "FAILURE: Config file not found."
64- return read_config (config_filename )
65-
66-
67- def read_config (config_filename ):
68- with open (config_filename , "r" ) as fd :
69- config = json .load (fd )
70-
71- # Validate the config file
72- validate_config (config )
73-
74- return config
75-
76-
7751def safe_get_parameter (input_dict , index_path , default = None , required = False ):
7852 r"""
7953 Safe get parameter from a nested dictionary.
@@ -130,71 +104,3 @@ def get_prepared_data_filename(prepared_data_part, dataset_name):
130104 if "<POSTFIX>" in template :
131105 template = template .replace ("<POSTFIX>" , dataset_name )
132106 return template
133-
134-
135- def create_dataset_and_model (config , purpose , fixed_padding = False ):
136- r"""
137- Gets a model from configuration
138-
139- Parameters
140- ---------
141- config : Dictionary
142- The configuration of the model
143- purpose : string
144- One of "train", "eval", "predict"
145- fixed_padding : bool
146- Parameter passed to the constructor of the DataLoader
147-
148- Returns
149- -------
150- In case a Model cannot be initialized return None, None, None. Otherwise:
151-
152- device : selected device
153- dataset : Instance of the DataLoader
154- model : Instance of the model
155- """
156- from docling_ibm_models .tableformer .data_management .tf_dataset import TFDataset
157-
158- model_type = config ["model" ]["type" ]
159- model = None
160-
161- # Get env vars:
162- use_cpu_only = os .environ .get ("USE_CPU_ONLY" , False )
163- use_cuda_only = not use_cpu_only
164-
165- # Use the cpu for the evaluation
166- device = "cpu" # Default, run on CPU
167- num_gpus = torch .cuda .device_count () # Check if GPU is available
168- if use_cuda_only :
169- device = "cuda:0" if num_gpus > 0 else "cpu" # Run on first available GPU
170- else :
171- device = "cpu"
172-
173- # Create the DataLoader
174- # loader = DataLoader(config, purpose, fixed_padding=fixed_padding)
175- dataset = TFDataset (config , purpose , fixed_padding = fixed_padding )
176- dataset .set_device (device )
177- dataset_val = None
178- if config ["train" ]["validation" ] and purpose == "train" :
179- dataset_val = TFDataset (config , "val" , fixed_padding = fixed_padding )
180- dataset_val .set_device (device )
181- if model_type == "TableModel04_rs" :
182- from docling_ibm_models .tableformer .models .table04_rs .tablemodel04_rs import ( # noqa: F401
183- TableModel04_rs ,
184- )
185- # Find the model class and create an instance of it
186- for candidate in BaseModel .__subclasses__ ():
187- if candidate .__name__ == model_type :
188- init_data = dataset .get_init_data ()
189- model = candidate (config , init_data , purpose , device )
190-
191- if model is None :
192- logger .warn ("Not found model: " + str (model_type ))
193- return None , None , None
194-
195- logger .info ("Found model: " + str (model_type ))
196-
197- if purpose == s .PREDICT_PURPOSE :
198- return device , dataset , model
199- else :
200- return device , dataset , dataset_val , model
0 commit comments