Skip to content

Commit 730ccef

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
let LazyCall support dataclasses
Summary: it was unsupported due to omegaconf Reviewed By: zhanghang1989 Differential Revision: D30753898 fbshipit-source-id: d47b014bdc806023ebc95364359d8695865cd3c7
1 parent a8b8aa3 commit 730ccef

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

detectron2/config/lazy.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections import abc
1010
from contextlib import contextmanager
1111
from copy import deepcopy
12+
from dataclasses import is_dataclass
1213
from typing import List, Tuple, Union
1314
import cloudpickle
1415
import yaml
@@ -45,7 +46,14 @@ def __init__(self, target):
4546
self._target = target
4647

4748
def __call__(self, **kwargs):
48-
kwargs["_target_"] = self._target
49+
if is_dataclass(self._target):
50+
# omegaconf object cannot hold dataclass type
51+
# https://github.com/omry/omegaconf/issues/784
52+
target = _convert_target_to_string(self._target)
53+
else:
54+
target = self._target
55+
kwargs["_target_"] = target
56+
4957
return DictConfig(content=kwargs, flags={"allow_objects": True})
5058

5159

detectron2/evaluation/coco_evaluation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@ def __init__(
118118

119119
self._metadata = MetadataCatalog.get(dataset_name)
120120
if not hasattr(self._metadata, "json_file"):
121-
self._logger.info(
122-
f"'{dataset_name}' is not registered by `register_coco_instances`."
123-
" Therefore trying to convert it to COCO format ..."
124-
)
121+
if output_dir is None:
122+
raise ValueError(
123+
"output_dir must be provided to COCOEvaluator "
124+
"for datasets not in COCO format."
125+
)
126+
self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...")
125127

126128
cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
127129
self._metadata.json_file = cache_path

dev/linter.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
# Run this script at project root by "./dev/linter.sh" before you commit
55

66
{
7-
black --version | grep -E "21.4b2" > /dev/null
7+
black --version | grep -E "21\." > /dev/null
88
} || {
9-
echo "Linter requires 'black==21.4b2' !"
9+
echo "Linter requires 'black==21.*' !"
1010
exit 1
1111
}
1212

tests/config/test_instantiate_config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import yaml
77
from omegaconf import OmegaConf
88
from omegaconf import __version__ as oc_version
9+
from dataclasses import dataclass
910

1011
from detectron2.config import instantiate, LazyCall as L
1112
from detectron2.layers import ShapeSpec
@@ -24,6 +25,12 @@ def __call__(self, call_arg):
2425
return call_arg + self.int_arg
2526

2627

28+
@dataclass
29+
class TestDataClass:
30+
x: int
31+
y: str
32+
33+
2734
@unittest.skipIf(OC_VERSION < (2, 1), "omegaconf version too old")
2835
class TestConstruction(unittest.TestCase):
2936
def test_basic_construct(self):
@@ -85,3 +92,9 @@ def test_instantiate_namedtuple(self):
8592
def test_bad_lazycall(self):
8693
with self.assertRaises(Exception):
8794
L(3)
95+
96+
def test_instantiate_dataclass(self):
97+
a = L(TestDataClass)(x=1, y="s")
98+
a = instantiate(a)
99+
self.assertEqual(a.x, 1)
100+
self.assertEqual(a.y, "s")

0 commit comments

Comments
 (0)