Skip to content

Commit 8a97b71

Browse files
committed
updated version 0.4.1
1 parent d78d856 commit 8a97b71

File tree

12 files changed

+2566
-53
lines changed

12 files changed

+2566
-53
lines changed

topicnet/cooking_machine/config_parser.py

Lines changed: 53 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ def __init__(self, num_iters: int = 5)
5353

5454
from .cubes import PerplexityStrategy, GreedyStrategy
5555
from .model_constructor import init_simple_default_model, create_default_topics
56+
from .rel_toolbox_lite import count_vocab_size, handle_regularizer
5657

5758
import artm
5859

@@ -201,7 +202,11 @@ def build_schema_for_regs():
201202
for elem in artm.regularizers.__all__:
202203
if "Regularizer" in elem:
203204
class_of_object = getattr(artm.regularizers, elem)
204-
res = wrap_in_map(build_schema_from_signature(class_of_object))
205+
res = build_schema_from_signature(class_of_object)
206+
if elem in ["SmoothSparseThetaRegularizer", "SmoothSparsePhiRegularizer",
207+
"DecorrelatorPhiRegularizer"]:
208+
res[Optional("relative", default=None)] = Bool()
209+
res = wrap_in_map(res)
205210

206211
specific_schema = Map({class_of_object.__name__: res})
207212
schemas[class_of_object.__name__] = specific_schema
@@ -392,11 +397,44 @@ def build_cube_settings(elemtype, elem_args):
392397
"selection": elem_args['selection'].data}
393398

394399

395-
def parse(yaml_string):
400+
def _add_parsed_scores(parsed, topic_model):
401+
""" """
402+
for score in parsed.data.get('scores', []):
403+
for elemtype, elem_args in score.items():
404+
is_artm_score = elemtype in artm.scores.__all__
405+
score_object = build_score(elemtype, elem_args, is_artm_score)
406+
if is_artm_score:
407+
topic_model._model.scores.add(score_object, overwrite=True)
408+
else:
409+
topic_model.custom_scores[elemtype] = score_object
410+
411+
412+
def _add_parsed_regularizers(
413+
parsed, model, specific_topic_names, background_topic_names, data_stats
414+
):
415+
""" """
416+
regularizers = []
417+
for stage in parsed.data['regularizers']:
418+
for elemtype, elem_args in stage.items():
419+
should_be_relative = None
420+
if "relative" in elem_args:
421+
should_be_relative = elem_args["relative"]
422+
elem_args.pop("relative")
423+
424+
regularizer_object = build_regularizer(
425+
elemtype, elem_args, specific_topic_names, background_topic_names
426+
)
427+
handle_regularizer(should_be_relative, model, regularizer_object, data_stats)
428+
regularizers.append(model.regularizers[regularizer_object.name])
429+
return regularizers
430+
431+
432+
def parse(yaml_string, force_single_thread=False):
396433
"""
397434
Parameters
398435
----------
399436
yaml_string : str
437+
force_single_thread : bool
400438
401439
Returns
402440
-------
@@ -418,39 +456,30 @@ def parse(yaml_string):
418456
revalidate_section(parsed, "scores")
419457

420458
cube_settings = []
421-
regularizers = []
422459

423460
dataset = Dataset(parsed.data["model"]["dataset_path"])
461+
modalities_to_use = parsed.data["model"]["modalities_to_use"]
462+
463+
data_stats = count_vocab_size(dataset.get_dictionary(), modalities_to_use)
424464
model = init_simple_default_model(
425465
dataset=dataset,
426-
modalities_to_use=parsed.data["model"]["modalities_to_use"],
466+
modalities_to_use=modalities_to_use,
427467
main_modality=parsed.data["model"]["main_modality"],
428468
specific_topics=parsed.data["topics"]["specific_topics"],
429469
background_topics=parsed.data["topics"]["background_topics"],
430470
)
431-
for stage in parsed.data['regularizers']:
432-
for elemtype, elem_args in stage.items():
433-
434-
regularizer_object = build_regularizer(
435-
elemtype, elem_args, specific_topic_names, background_topic_names
436-
)
437-
regularizers.append(regularizer_object)
438-
model.regularizers.add(regularizer_object, overwrite=True)
439471

472+
regularizers = _add_parsed_regularizers(
473+
parsed, model, specific_topic_names, background_topic_names, data_stats
474+
)
440475
topic_model = TopicModel(model)
441-
442-
for score in parsed.data.get('scores', []):
443-
for elemtype, elem_args in score.items():
444-
is_artm_score = elemtype in artm.scores.__all__
445-
score_object = build_score(elemtype, elem_args, is_artm_score)
446-
if is_artm_score:
447-
model.scores.add(score_object, overwrite=True)
448-
else:
449-
topic_model.custom_scores[elemtype] = score_object
476+
_add_parsed_scores(parsed, topic_model)
450477

451478
for stage in parsed['stages']:
452479
for elemtype, elem_args in stage.items():
453480
settings = build_cube_settings(elemtype.data, elem_args)
481+
if force_single_thread:
482+
settings[elemtype]["separate_thread"] = False
454483
cube_settings.append(settings)
455484

456485
return cube_settings, regularizers, topic_model, dataset
@@ -486,8 +515,9 @@ def revalidate_section(parsed, section):
486515
stage.revalidate(local_schema)
487516

488517

489-
def build_experiment_environment_from_yaml_config(yaml_string, experiment_id, save_path):
490-
settings, regs, model, dataset = parse(yaml_string)
518+
def build_experiment_environment_from_yaml_config(yaml_string, experiment_id,
519+
save_path, force_single_thread=False):
520+
settings, regs, model, dataset = parse(yaml_string, force_single_thread)
491521
# TODO: handle dynamic addition of regularizers
492522
experiment = Experiment(experiment_id=experiment_id, save_path=save_path, topic_model=model)
493523
experiment.build(settings)

topicnet/cooking_machine/cubes/controller_cube.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,6 @@ def apply(self, topic_model, one_model_parameter, dictionary=None, model_id=None
371371
handle_regularizer(
372372
self._relative,
373373
new_model,
374-
modalities,
375374
new_regularizer,
376375
self.data_stats,
377376
)

topicnet/cooking_machine/cubes/regularizer_cube.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,14 @@ def apply(self, topic_model, one_model_parameter, dictionary=None, model_id=None
144144
regularizer_type = str(type(regularizer))
145145
if isinstance(regularizer, dict):
146146
if regularizer['name'] in new_model.regularizers.data:
147-
setattr(new_model.regularizers[regularizer['name']],
148-
field_name,
149-
params)
147+
new_regularizer = deepcopy(new_model.regularizers[regularizer['name']])
148+
new_regularizer._tau = params
149+
handle_regularizer(
150+
self._relative,
151+
new_model,
152+
new_regularizer,
153+
self.data_stats,
154+
)
150155
else:
151156
error_msg = (f"Regularizer {regularizer['name']} does not exist. "
152157
f"Cannot be modified.")
@@ -157,7 +162,6 @@ def apply(self, topic_model, one_model_parameter, dictionary=None, model_id=None
157162
handle_regularizer(
158163
self._relative,
159164
new_model,
160-
modalities,
161165
new_regularizer,
162166
self.data_stats,
163167
)

topicnet/cooking_machine/models/topic_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,17 @@ def regularizers(self):
574574
def class_ids(self):
575575
""" """
576576
return self._model.class_ids
577+
578+
def describe_scores(self):
579+
data = []
580+
for score_name, score in self.scores.items():
581+
data.append([self.model_id, score_name, score[-1]])
582+
result = pd.DataFrame(columns=["model_id", "score_name", "last_value"], data=data)
583+
return result.set_index(["model_id", "score_name"])
584+
585+
def describe_regularizers(self):
586+
data = []
587+
for reg_name, reg in self.regularizers._data.items():
588+
data.append([self.model_id, reg_name, reg.tau, reg.gamma])
589+
result = pd.DataFrame(columns=["model_id", "regularizer_name", "tau", "gamma"], data=data)
590+
return result.set_index(["model_id", "regularizer_name"])

topicnet/cooking_machine/recipes/ARTM_baseline.yml

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,44 +24,34 @@ regularizers:
2424
topic_names: background_topics
2525
class_ids: {modality_list}
2626
tau: 0.1
27+
relative: true
2728
- SmoothSparseThetaRegularizer:
2829
name: smooth_theta_bcg
2930
topic_names: background_topics
3031
tau: 0.1
32+
relative: true
3133
scores:
3234
- BleiLaffertyScore:
33-
num_top_tokens: 15
35+
num_top_tokens: 30
3436
model:
3537
dataset_path: {dataset_path}
3638
modalities_to_use: {modality_list}
3739
main_modality: '{main_modality}'
3840

3941
stages:
40-
- RegularizersModifierCube:
41-
num_iter: 1
42-
reg_search: grid
43-
regularizer_parameters:
44-
- name: smooth_phi_bcg
45-
tau_grid: [0.1]
46-
- name: smooth_theta_bcg
47-
tau_grid: [0.1]
48-
selection:
49-
- COLLECT 1
50-
verbose: false
51-
relative_coefficients: True
5242
- RegularizersModifierCube:
5343
num_iter: 20
5444
reg_search: add
5545
regularizer_parameters:
5646
name: decorrelation_phi
5747
selection:
58-
- PerplexityScore@all < 1.01 * MINIMUM(PerplexityScore@all) and BleiLaffertyScore -> max
48+
- PerplexityScore@all < 1.05 * MINIMUM(PerplexityScore@all) and BleiLaffertyScore -> max
5949
strategy: PerplexityStrategy
6050
# parameters of this strategy are intended for revision
6151
strategy_params:
6252
start_point: 0
63-
step: 1000
64-
max_len: 10000
53+
step: 0.01
54+
max_len: 50
6555
tracked_score_function: PerplexityScore@all
6656
verbose: false
67-
relative_coefficients: false
57+
relative_coefficients: true

topicnet/cooking_machine/rel_toolbox_lite.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def modality_weight_rel2abs(tokens_data, weights, default_modality):
160160
return taus
161161

162162

163-
def handle_regularizer(use_relative_coefficients, model, modalities, regularizer, data_stats):
163+
def handle_regularizer(use_relative_coefficients, model, regularizer, data_stats):
164164
"""
165165
Handles the case of various regularizers that
166166
contain 'Regularizer' in their name, namely all artm regularizers
@@ -171,8 +171,6 @@ def handle_regularizer(use_relative_coefficients, model, modalities, regularizer
171171
indicates whether regularizer should be altered
172172
model : TopicModel or artm.ARTM
173173
to be changed in place
174-
modalities : dict
175-
modalities used in the model
176174
regularizer : an instance of Regularizer from artm library
177175
data_stats : dict
178176
collection-specific data
@@ -195,7 +193,7 @@ def handle_regularizer(use_relative_coefficients, model, modalities, regularizer
195193
regularizer = transform_regularizer(
196194
data_stats,
197195
regularizer,
198-
modalities,
196+
model.class_ids,
199197
n_topics,
200198
)
201199

0 commit comments

Comments
 (0)