Skip to content

Commit d0bdb22

Browse files
fix: Put back in common.py the function read_config(). Extend the unit tests. (#36)
Signed-off-by: Nikos Livathinos <[email protected]>
1 parent 3d88489 commit d0bdb22

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

docling_ibm_models/tableformer/common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ def validate_config(config):
4848
return True
4949

5050

51+
def read_config(config_filename):
52+
with open(config_filename, "r") as fd:
53+
config = json.load(fd)
54+
55+
# Validate the config file
56+
validate_config(config)
57+
58+
return config
59+
60+
5161
def safe_get_parameter(input_dict, index_path, default=None, required=False):
5262
r"""
5363
Safe get parameter from a nested dictionary.

tests/test_common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
# Copyright IBM Corp. 2024 - 2024
33
# SPDX-License-Identifier: MIT
44
#
5+
import json
6+
import tempfile
7+
58
import docling_ibm_models.tableformer.common as c
69

10+
711
test_config_a = {
812
"base_dir": "./tests/test_data/",
913
"curr_dir": "./tests/test_data/test_common/",
@@ -70,3 +74,16 @@ def test_config_validation():
7074
assert val, "Valid configuration didn't pass the validation test"
7175
except AssertionError:
7276
assert i == 2, "Configuration validation error"
77+
78+
def test_read_config():
79+
r"""
80+
Testing the read_config() function
81+
"""
82+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as fp:
83+
# Write a tmp file
84+
json.dump(test_config_a, fp)
85+
fp.close()
86+
87+
# Read the tmp file and extract the config
88+
config = c.read_config(fp.name)
89+
assert isinstance(config, dict)

0 commit comments

Comments
 (0)