@@ -45,19 +45,20 @@ def __init__(self, num_iters: int = 5)
4545but it's a work-in-progress currently.
4646
4747""" # noqa: W291
48- from .models .topic_model import TopicModel
4948from .cubes import RegularizersModifierCube , CubeCreator
5049from .experiment import Experiment
5150from .dataset import Dataset
52-
51+ from .models import scores as tnscores
52+ from .models import TopicModel
5353
5454from .cubes import PerplexityStrategy , GreedyStrategy
55- from .model_constructor import init_simple_default_model
55+ from .model_constructor import init_simple_default_model , create_default_topics
5656
5757import artm
5858
5959from inspect import signature , Parameter
60- from strictyaml import Map , Str , Int , Seq , Any , Optional , Float , EmptyNone , Bool
60+ from strictyaml import Map , Str , Int , Seq , Float , Bool
61+ from strictyaml import Any , Optional , EmptyDict , EmptyNone , EmptyList
6162from strictyaml import dirty_load
6263
6364# TODO: use stackoverflow.com/questions/37929851/parse-numpydoc-docstring-and-access-components
@@ -80,15 +81,16 @@ def __init__(self, num_iters: int = 5)
8081element = Any ()
8182base_schema = Map ({
8283 'regularizers' : Seq (element ),
84+ Optional ('scores' ): Seq (element ),
8385 'stages' : Seq (element ),
8486 'model' : Map ({
8587 "dataset_path" : Str (),
8688 "modalities_to_use" : Seq (Str ()),
8789 "main_modality" : Str ()
8890 }),
8991 'topics' : Map ({
90- "background_topics" : Seq (Str ()),
91- "specific_topics" : Seq (Str ()),
92+ "background_topics" : Seq (Str ()) | Int () | EmptyList () ,
93+ "specific_topics" : Seq (Str ()) | Int () | EmptyList () ,
9294 })
9395})
9496SUPPORTED_CUBES = [CubeCreator , RegularizersModifierCube ]
@@ -150,6 +152,44 @@ def build_schema_from_signature(class_of_object, use_optional=True):
150152 if param .name != 'self' }
151153
152154
def wrap_in_map(dictionary):
    """
    Turn a dict of key -> validator into a strictyaml mapping validator.

    Parameters
    ----------
    dictionary : dict
        keys are either plain keys or strictyaml.Optional instances

    Returns
    -------
    strictyaml validator
        Map(dictionary) when at least one key is not an Optional instance,
        otherwise Map(dictionary) | EmptyDict()
    """
    has_required_key = any(
        not isinstance(key, Optional) for key in dictionary
    )
    schema = Map(dictionary)
    if has_required_key:
        return schema
    # every key is Optional (or there are no keys at all),
    # so the EmptyDict() alternative is offered as well
    return schema | EmptyDict()
160+
161+
def build_schema_for_scores():
    """
    Collect validation schemas for the score classes found in
    `artm.scores` and `tnscores`.

    Only names listed in a module's `__all__` that contain the substring
    "Score" are used. `artm.scores` is processed first and `tnscores`
    second, so on a name clash the `tnscores` entry replaces the
    `artm.scores` one.

    Returns
    -------
    dict
        maps class name (str) to a strictyaml.Map of the form
        {class name: schema produced by build_schema_from_signature
        and wrapped by wrap_in_map}
    """
    schemas = {}
    # order matters: a later module overrides an earlier one on a name clash
    for module in (artm.scores, tnscores):
        for elem in module.__all__:
            if "Score" not in elem:
                continue
            class_of_object = getattr(module, elem)
            # TODO: support custom names for scores taken from tnscores
            res = wrap_in_map(build_schema_from_signature(class_of_object))

            class_name = class_of_object.__name__
            schemas[class_name] = Map({class_name: res})

    return schemas
191+
192+
153193def build_schema_for_regs ():
154194 """
155195 Returns
@@ -161,7 +201,7 @@ def build_schema_for_regs():
161201 for elem in artm .regularizers .__all__ :
162202 if "Regularizer" in elem :
163203 class_of_object = getattr (artm .regularizers , elem )
164- res = Map (build_schema_from_signature (class_of_object ))
204+ res = wrap_in_map (build_schema_from_signature (class_of_object ))
165205
166206 specific_schema = Map ({class_of_object .__name__ : res })
167207 schemas [class_of_object .__name__ ] = specific_schema
@@ -280,7 +320,28 @@ def handle_special_cases(elem_args, kwargs):
280320 kwargs ["strategy" ] = strategy # or None if failed to identify it
281321
282322
283- def build_regularizer (elemtype , elem_args , parsed ):
def build_score(elemtype, elem_args, is_artm_score):
    """
    Create a score object from its class name and constructor arguments.

    Parameters
    ----------
    elemtype : str
        name of the score class to look up
    elem_args : dict
        keyword arguments forwarded to the class constructor
    is_artm_score : bool
        if True the class is looked up in artm.scores,
        otherwise in tnscores

    Returns
    -------
    instance of the class found under the name `elemtype`
    """
    if is_artm_score:
        source_module = artm.scores
    else:
        source_module = tnscores
    score_class = getattr(source_module, elemtype)

    # the constructor receives a fresh dict built from elem_args' items
    constructor_kwargs = dict(elem_args.items())
    return score_class(**constructor_kwargs)
342+
343+
344+ def build_regularizer (elemtype , elem_args , specific_topic_names , background_topic_names ):
284345 """
285346 Parameters
286347 ----------
@@ -299,9 +360,9 @@ def build_regularizer(elemtype, elem_args, parsed):
299360 # special case: shortcut for topic_names
300361 if "topic_names" in kwargs :
301362 if kwargs ["topic_names" ] == "background_topics" :
302- kwargs ["topic_names" ] = parsed . data [ "topics" ][ "background_topics" ]
363+ kwargs ["topic_names" ] = background_topic_names
303364 if kwargs ["topic_names" ] == "specific_topics" :
304- kwargs ["topic_names" ] = parsed . data [ "topics" ][ "specific_topics" ]
365+ kwargs ["topic_names" ] = specific_topic_names
305366
306367 return class_of_object (** kwargs )
307368
@@ -345,8 +406,16 @@ def parse(yaml_string):
345406 dataset: Dataset
346407 """
347408 parsed = dirty_load (yaml_string , base_schema , allow_flow_style = True )
409+
410+ specific_topic_names , background_topic_names = create_default_topics (
411+ parsed .data ["topics" ]["specific_topics" ],
412+ parsed .data ["topics" ]["background_topics" ]
413+ )
414+
348415 revalidate_section (parsed , "stages" )
349416 revalidate_section (parsed , "regularizers" )
417+ if "scores" in parsed :
418+ revalidate_section (parsed , "scores" )
350419
351420 cube_settings = []
352421 regularizers = []
@@ -362,12 +431,23 @@ def parse(yaml_string):
362431 for stage in parsed .data ['regularizers' ]:
363432 for elemtype , elem_args in stage .items ():
364433
365- regularizer_object = build_regularizer (elemtype , elem_args , parsed )
366-
434+ regularizer_object = build_regularizer (
435+ elemtype , elem_args , specific_topic_names , background_topic_names
436+ )
367437 regularizers .append (regularizer_object )
368438 model .regularizers .add (regularizer_object , overwrite = True )
369439
370440 topic_model = TopicModel (model )
441+
442+ for score in parsed .data .get ('scores' , []):
443+ for elemtype , elem_args in score .items ():
444+ is_artm_score = elemtype in artm .scores .__all__
445+ score_object = build_score (elemtype , elem_args , is_artm_score )
446+ if is_artm_score :
447+ model .scores .add (score_object , overwrite = True )
448+ else :
449+ topic_model .custom_scores [elemtype ] = score_object
450+
371451 for stage in parsed ['stages' ]:
372452 for elemtype , elem_args in stage .items ():
373453 settings = build_cube_settings (elemtype .data , elem_args )
@@ -390,6 +470,8 @@ def revalidate_section(parsed, section):
390470 schemas = build_schema_for_cubes ()
391471 elif section == "regularizers" :
392472 schemas = build_schema_for_regs ()
473+ elif section == "scores" :
474+ schemas = build_schema_for_scores ()
393475 else :
394476 raise ValueError (f"Unknown section name '{ section } '" )
395477
@@ -398,7 +480,7 @@ def revalidate_section(parsed, section):
398480 name = list (stage .data )[0 ]
399481
400482 if name not in schemas :
401- raise ValueError (f"Unsupported stage ID : { name } at line { stage .start_line } " )
483+ raise ValueError (f"Unsupported { section } value : { name } at line { stage .start_line } " )
402484 local_schema = schemas [name ]
403485
404486 stage .revalidate (local_schema )
0 commit comments