@@ -53,6 +53,7 @@ def __init__(self, num_iters: int = 5)
5353
5454from .cubes import PerplexityStrategy , GreedyStrategy
5555from .model_constructor import init_simple_default_model , create_default_topics
56+ from .rel_toolbox_lite import count_vocab_size , handle_regularizer
5657
5758import artm
5859
@@ -201,7 +202,11 @@ def build_schema_for_regs():
201202 for elem in artm .regularizers .__all__ :
202203 if "Regularizer" in elem :
203204 class_of_object = getattr (artm .regularizers , elem )
204- res = wrap_in_map (build_schema_from_signature (class_of_object ))
205+ res = build_schema_from_signature (class_of_object )
206+ if elem in ["SmoothSparseThetaRegularizer" , "SmoothSparsePhiRegularizer" ,
207+ "DecorrelatorPhiRegularizer" ]:
208+ res [Optional ("relative" , default = None )] = Bool ()
209+ res = wrap_in_map (res )
205210
206211 specific_schema = Map ({class_of_object .__name__ : res })
207212 schemas [class_of_object .__name__ ] = specific_schema
@@ -392,11 +397,44 @@ def build_cube_settings(elemtype, elem_args):
392397 "selection" : elem_args ['selection' ].data }
393398
394399
395- def parse (yaml_string ):
400+ def _add_parsed_scores (parsed , topic_model ):
401+ """ """
402+ for score in parsed .data .get ('scores' , []):
403+ for elemtype , elem_args in score .items ():
404+ is_artm_score = elemtype in artm .scores .__all__
405+ score_object = build_score (elemtype , elem_args , is_artm_score )
406+ if is_artm_score :
407+ topic_model ._model .scores .add (score_object , overwrite = True )
408+ else :
409+ topic_model .custom_scores [elemtype ] = score_object
410+
411+
412+ def _add_parsed_regularizers (
413+ parsed , model , specific_topic_names , background_topic_names , data_stats
414+ ):
415+ """ """
416+ regularizers = []
417+ for stage in parsed .data ['regularizers' ]:
418+ for elemtype , elem_args in stage .items ():
419+ should_be_relative = None
420+ if "relative" in elem_args :
421+ should_be_relative = elem_args ["relative" ]
422+ elem_args .pop ("relative" )
423+
424+ regularizer_object = build_regularizer (
425+ elemtype , elem_args , specific_topic_names , background_topic_names
426+ )
427+ handle_regularizer (should_be_relative , model , regularizer_object , data_stats )
428+ regularizers .append (model .regularizers [regularizer_object .name ])
429+ return regularizers
430+
431+
432+ def parse (yaml_string , force_single_thread = False ):
396433 """
397434 Parameters
398435 ----------
399436 yaml_string : str
437+ force_single_thread : bool
400438
401439 Returns
402440 -------
@@ -418,39 +456,30 @@ def parse(yaml_string):
418456 revalidate_section (parsed , "scores" )
419457
420458 cube_settings = []
421- regularizers = []
422459
423460 dataset = Dataset (parsed .data ["model" ]["dataset_path" ])
461+ modalities_to_use = parsed .data ["model" ]["modalities_to_use" ]
462+
463+ data_stats = count_vocab_size (dataset .get_dictionary (), modalities_to_use )
424464 model = init_simple_default_model (
425465 dataset = dataset ,
426- modalities_to_use = parsed . data [ "model" ][ " modalities_to_use" ] ,
466+ modalities_to_use = modalities_to_use ,
427467 main_modality = parsed .data ["model" ]["main_modality" ],
428468 specific_topics = parsed .data ["topics" ]["specific_topics" ],
429469 background_topics = parsed .data ["topics" ]["background_topics" ],
430470 )
431- for stage in parsed .data ['regularizers' ]:
432- for elemtype , elem_args in stage .items ():
433-
434- regularizer_object = build_regularizer (
435- elemtype , elem_args , specific_topic_names , background_topic_names
436- )
437- regularizers .append (regularizer_object )
438- model .regularizers .add (regularizer_object , overwrite = True )
439471
472+ regularizers = _add_parsed_regularizers (
473+ parsed , model , specific_topic_names , background_topic_names , data_stats
474+ )
440475 topic_model = TopicModel (model )
441-
442- for score in parsed .data .get ('scores' , []):
443- for elemtype , elem_args in score .items ():
444- is_artm_score = elemtype in artm .scores .__all__
445- score_object = build_score (elemtype , elem_args , is_artm_score )
446- if is_artm_score :
447- model .scores .add (score_object , overwrite = True )
448- else :
449- topic_model .custom_scores [elemtype ] = score_object
476+ _add_parsed_scores (parsed , topic_model )
450477
451478 for stage in parsed ['stages' ]:
452479 for elemtype , elem_args in stage .items ():
453480 settings = build_cube_settings (elemtype .data , elem_args )
481+ if force_single_thread :
482+ settings [elemtype ]["separate_thread" ] = False
454483 cube_settings .append (settings )
455484
456485 return cube_settings , regularizers , topic_model , dataset
@@ -486,8 +515,9 @@ def revalidate_section(parsed, section):
486515 stage .revalidate (local_schema )
487516
488517
489- def build_experiment_environment_from_yaml_config (yaml_string , experiment_id , save_path ):
490- settings , regs , model , dataset = parse (yaml_string )
518+ def build_experiment_environment_from_yaml_config (yaml_string , experiment_id ,
519+ save_path , force_single_thread = False ):
520+ settings , regs , model , dataset = parse (yaml_string , force_single_thread )
491521 # TODO: handle dynamic addition of regularizers
492522 experiment = Experiment (experiment_id = experiment_id , save_path = save_path , topic_model = model )
493523 experiment .build (settings )
0 commit comments