Skip to content

Commit b5e7fe1

Browse files
authored
[DLMED] add allow_missing_reference (#4079)
Signed-off-by: Nic Ma <[email protected]>
1 parent b229e9c commit b5e7fe1

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

monai/bundle/reference_resolver.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import os
1213
import re
14+
import warnings
1315
from typing import Any, Dict, Optional, Sequence, Set
1416

1517
from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
@@ -50,6 +52,8 @@ class ReferenceResolver:
5052
ref = ID_REF_KEY # reference prefix
5153
# match a reference string, e.g. "@id#key", "@id#key#0", "@_target_#key"
5254
id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*")
55+
# if `allow_missing_reference` and can't find a reference ID, will just raise a warning and don't update the config
56+
allow_missing_reference = False if os.environ.get("MONAI_ALLOW_MISSING_REFERENCE", "0") == "0" else True
5357

5458
def __init__(self, items: Optional[Sequence[ConfigItem]] = None):
5559
# save the items in a dictionary with the `ConfigItem.id` as key
@@ -140,7 +144,12 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
140144
try:
141145
look_up_option(d, self.items, print_all_options=False)
142146
except ValueError as err:
143-
raise ValueError(f"the referring item `@{d}` is not defined in the config content.") from err
147+
msg = f"the referring item `@{d}` is not defined in the config content."
148+
if self.allow_missing_reference:
149+
warnings.warn(msg)
150+
continue
151+
else:
152+
raise ValueError(msg) from err
144153
# recursively resolve the reference first
145154
self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs)
146155
waiting_list.discard(d)
@@ -210,7 +219,12 @@ def update_refs_pattern(cls, value: str, refs: Dict) -> str:
210219
for item in result:
211220
ref_id = item[len(cls.ref) :] # remove the ref prefix "@"
212221
if ref_id not in refs:
213-
raise KeyError(f"can not find expected ID '{ref_id}' in the references.")
222+
msg = f"can not find expected ID '{ref_id}' in the references."
223+
if cls.allow_missing_reference:
224+
warnings.warn(msg)
225+
continue
226+
else:
227+
raise KeyError(msg)
214228
if value_is_expr:
215229
# replace with local code, will be used in the `evaluate` logic with `locals={"refs": ...}`
216230
value = value.replace(item, f"{cls._vars}['{ref_id}']")

tests/test_config_parser.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from parameterized import parameterized
1818

19-
from monai.bundle.config_parser import ConfigParser
19+
from monai.bundle import ConfigParser, ReferenceResolver
2020
from monai.data import DataLoader, Dataset
2121
from monai.transforms import Compose, LoadImaged, RandTorchVisiond
2222
from monai.utils import min_version, optional_import
@@ -86,6 +86,8 @@ def __call__(self, a, b):
8686
}
8787
]
8888

89+
TEST_CASE_4 = [{"A": 1, "B": "@A", "C": "@D", "E": "$'test' + '@F'"}]
90+
8991

9092
class TestConfigParser(unittest.TestCase):
9193
def test_config_content(self):
@@ -154,6 +156,27 @@ def test_macro_replace(self):
154156
parser.resolve_macro_and_relative_ids()
155157
self.assertEqual(str(parser.get()), str({"A": {"B": 1, "C": 2}, "D": [3, 1, 3, 4]}))
156158

159+
@parameterized.expand([TEST_CASE_4])
160+
def test_allow_missing_reference(self, config):
161+
default = ReferenceResolver.allow_missing_reference
162+
ReferenceResolver.allow_missing_reference = True
163+
parser = ConfigParser(config=config)
164+
165+
for id in config:
166+
item = parser.get_parsed_content(id=id)
167+
if id in ("A", "B"):
168+
self.assertEqual(item, 1)
169+
elif id == "C":
170+
self.assertEqual(item, "@D")
171+
elif id == "E":
172+
self.assertEqual(item, "test@F")
173+
174+
# restore the default value
175+
ReferenceResolver.allow_missing_reference = default
176+
with self.assertRaises(ValueError):
177+
parser.parse()
178+
parser.get_parsed_content(id="E")
179+
157180

158181
if __name__ == "__main__":
159182
unittest.main()

0 commit comments

Comments
 (0)