File tree Expand file tree Collapse file tree 4 files changed +30
-7
lines changed
Expand file tree Collapse file tree 4 files changed +30
-7
lines changed Original file line number Diff line number Diff line change 99from collections import abc
1010from contextlib import contextmanager
1111from copy import deepcopy
12+ from dataclasses import is_dataclass
1213from typing import List , Tuple , Union
1314import cloudpickle
1415import yaml
@@ -45,7 +46,14 @@ def __init__(self, target):
4546 self ._target = target
4647
4748 def __call__ (self , ** kwargs ):
48- kwargs ["_target_" ] = self ._target
49+ if is_dataclass (self ._target ):
50+ # omegaconf object cannot hold dataclass type
51+ # https://github.com/omry/omegaconf/issues/784
52+ target = _convert_target_to_string (self ._target )
53+ else :
54+ target = self ._target
55+ kwargs ["_target_" ] = target
56+
4957 return DictConfig (content = kwargs , flags = {"allow_objects" : True })
5058
5159
Original file line number Diff line number Diff line change @@ -118,10 +118,12 @@ def __init__(
118118
119119 self ._metadata = MetadataCatalog .get (dataset_name )
120120 if not hasattr (self ._metadata , "json_file" ):
121- self ._logger .info (
122- f"'{ dataset_name } ' is not registered by `register_coco_instances`."
123- " Therefore trying to convert it to COCO format ..."
124- )
121+ if output_dir is None :
122+ raise ValueError (
123+ "output_dir must be provided to COCOEvaluator "
124+ "for datasets not in COCO format."
125+ )
126+ self ._logger .info (f"Trying to convert '{ dataset_name } ' to COCO format ..." )
125127
126128 cache_path = os .path .join (output_dir , f"{ dataset_name } _coco_format.json" )
127129 self ._metadata .json_file = cache_path
Original file line number Diff line number Diff line change 44# Run this script at project root by "./dev/linter.sh" before you commit
55
66{
7- black --version | grep -E " 21.4b2 " > /dev/null
7+ black --version | grep -E " 21\. " > /dev/null
88} || {
9- echo " Linter requires 'black==21.4b2 ' !"
9+ echo " Linter requires 'black==21.* ' !"
1010 exit 1
1111}
1212
Original file line number Diff line number Diff line change 66import yaml
77from omegaconf import OmegaConf
88from omegaconf import __version__ as oc_version
9+ from dataclasses import dataclass
910
1011from detectron2 .config import instantiate , LazyCall as L
1112from detectron2 .layers import ShapeSpec
@@ -24,6 +25,12 @@ def __call__(self, call_arg):
2425 return call_arg + self .int_arg
2526
2627
28+ @dataclass
29+ class TestDataClass :
30+ x : int
31+ y : str
32+
33+
2734@unittest .skipIf (OC_VERSION < (2 , 1 ), "omegaconf version too old" )
2835class TestConstruction (unittest .TestCase ):
2936 def test_basic_construct (self ):
@@ -85,3 +92,9 @@ def test_instantiate_namedtuple(self):
8592 def test_bad_lazycall (self ):
8693 with self .assertRaises (Exception ):
8794 L (3 )
95+
96+ def test_instantiate_dataclass (self ):
97+ a = L (TestDataClass )(x = 1 , y = "s" )
98+ a = instantiate (a )
99+ self .assertEqual (a .x , 1 )
100+ self .assertEqual (a .y , "s" )
You can’t perform that action at this time.
0 commit comments