Skip to content

Commit 53c1d7c

Browse files
committed
update TopicNet to v0.4.0
1 parent 22dd19b commit 53c1d7c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+46858
-490
lines changed

topicnet/.flake8

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[flake8]
2+
max-complexity = 10
3+
max-line-length = 100
4+
exclude = __init__.py
5+

topicnet/bitbucket-pipelines.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
2+
pipelines:
3+
pull-requests:
4+
'**':
5+
- step:
6+
name: Lint by Flake8
7+
image: python:3.6.0
8+
caches:
9+
- pip
10+
script:
11+
- pip install flake8 mccabe
12+
- flake8 --max-complexity=10 --max-line-length=100 .
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .dataset import Dataset
2+
from .dataset import BaseDataset
23
from .experiment import Experiment
34
from .model_constructor import *

topicnet/cooking_machine/config_parser.py

Lines changed: 95 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,20 @@ def __init__(self, num_iters: int = 5)
4545
but it's a work-in-progress currently.
4646
4747
""" # noqa: W291
48-
from .models.topic_model import TopicModel
4948
from .cubes import RegularizersModifierCube, CubeCreator
5049
from .experiment import Experiment
5150
from .dataset import Dataset
52-
51+
from .models import scores as tnscores
52+
from .models import TopicModel
5353

5454
from .cubes import PerplexityStrategy, GreedyStrategy
55-
from .model_constructor import init_simple_default_model
55+
from .model_constructor import init_simple_default_model, create_default_topics
5656

5757
import artm
5858

5959
from inspect import signature, Parameter
60-
from strictyaml import Map, Str, Int, Seq, Any, Optional, Float, EmptyNone, Bool
60+
from strictyaml import Map, Str, Int, Seq, Float, Bool
61+
from strictyaml import Any, Optional, EmptyDict, EmptyNone, EmptyList
6162
from strictyaml import dirty_load
6263

6364
# TODO: use stackoverflow.com/questions/37929851/parse-numpydoc-docstring-and-access-components
@@ -80,15 +81,16 @@ def __init__(self, num_iters: int = 5)
8081
element = Any()
8182
base_schema = Map({
8283
'regularizers': Seq(element),
84+
Optional('scores'): Seq(element),
8385
'stages': Seq(element),
8486
'model': Map({
8587
"dataset_path": Str(),
8688
"modalities_to_use": Seq(Str()),
8789
"main_modality": Str()
8890
}),
8991
'topics': Map({
90-
"background_topics": Seq(Str()),
91-
"specific_topics": Seq(Str()),
92+
"background_topics": Seq(Str()) | Int() | EmptyList(),
93+
"specific_topics": Seq(Str()) | Int() | EmptyList(),
9294
})
9395
})
9496
SUPPORTED_CUBES = [CubeCreator, RegularizersModifierCube]
@@ -150,6 +152,44 @@ def build_schema_from_signature(class_of_object, use_optional=True):
150152
if param.name != 'self'}
151153

152154

155+
def wrap_in_map(dictionary):
156+
could_be_empty = all(isinstance(key, Optional) for key in dictionary)
157+
if could_be_empty:
158+
return Map(dictionary) | EmptyDict()
159+
return Map(dictionary)
160+
161+
162+
def build_schema_for_scores():
163+
"""
164+
Returns
165+
-------
166+
strictyaml.Map
167+
schema used for validation and type-coercion
168+
"""
169+
schemas = {}
170+
for elem in artm.scores.__all__:
171+
if "Score" in elem:
172+
class_of_object = getattr(artm.scores, elem)
173+
# TODO: check if every key is Optional. If it is, then "| EmptyDict()"
174+
# otherwise, just Map()
175+
res = wrap_in_map(build_schema_from_signature(class_of_object))
176+
177+
specific_schema = Map({class_of_object.__name__: res})
178+
schemas[class_of_object.__name__] = specific_schema
179+
180+
for elem in tnscores.__all__:
181+
if "Score" in elem:
182+
class_of_object = getattr(tnscores, elem)
183+
res = build_schema_from_signature(class_of_object)
184+
# res["name"] = Str() # TODO: support custom names
185+
res = wrap_in_map(res)
186+
187+
specific_schema = Map({class_of_object.__name__: res})
188+
schemas[class_of_object.__name__] = specific_schema
189+
190+
return schemas
191+
192+
153193
def build_schema_for_regs():
154194
"""
155195
Returns
@@ -161,7 +201,7 @@ def build_schema_for_regs():
161201
for elem in artm.regularizers.__all__:
162202
if "Regularizer" in elem:
163203
class_of_object = getattr(artm.regularizers, elem)
164-
res = Map(build_schema_from_signature(class_of_object))
204+
res = wrap_in_map(build_schema_from_signature(class_of_object))
165205

166206
specific_schema = Map({class_of_object.__name__: res})
167207
schemas[class_of_object.__name__] = specific_schema
@@ -280,7 +320,28 @@ def handle_special_cases(elem_args, kwargs):
280320
kwargs["strategy"] = strategy # or None if failed to identify it
281321

282322

283-
def build_regularizer(elemtype, elem_args, parsed):
323+
def build_score(elemtype, elem_args, is_artm_score):
324+
"""
325+
Parameters
326+
----------
327+
elemtype : str
328+
name of score
329+
elem_args: dict
330+
is_artm_score: bool
331+
332+
Returns
333+
-------
334+
instance of artm.Score or topicnet.BaseScore
335+
"""
336+
module = artm.scores if is_artm_score else tnscores
337+
class_of_object = getattr(module, elemtype)
338+
kwargs = {name: value
339+
for name, value in elem_args.items()}
340+
341+
return class_of_object(**kwargs)
342+
343+
344+
def build_regularizer(elemtype, elem_args, specific_topic_names, background_topic_names):
284345
"""
285346
Parameters
286347
----------
@@ -299,9 +360,9 @@ def build_regularizer(elemtype, elem_args, parsed):
299360
# special case: shortcut for topic_names
300361
if "topic_names" in kwargs:
301362
if kwargs["topic_names"] == "background_topics":
302-
kwargs["topic_names"] = parsed.data["topics"]["background_topics"]
363+
kwargs["topic_names"] = background_topic_names
303364
if kwargs["topic_names"] == "specific_topics":
304-
kwargs["topic_names"] = parsed.data["topics"]["specific_topics"]
365+
kwargs["topic_names"] = specific_topic_names
305366

306367
return class_of_object(**kwargs)
307368

@@ -345,8 +406,16 @@ def parse(yaml_string):
345406
dataset: Dataset
346407
"""
347408
parsed = dirty_load(yaml_string, base_schema, allow_flow_style=True)
409+
410+
specific_topic_names, background_topic_names = create_default_topics(
411+
parsed.data["topics"]["specific_topics"],
412+
parsed.data["topics"]["background_topics"]
413+
)
414+
348415
revalidate_section(parsed, "stages")
349416
revalidate_section(parsed, "regularizers")
417+
if "scores" in parsed:
418+
revalidate_section(parsed, "scores")
350419

351420
cube_settings = []
352421
regularizers = []
@@ -362,12 +431,23 @@ def parse(yaml_string):
362431
for stage in parsed.data['regularizers']:
363432
for elemtype, elem_args in stage.items():
364433

365-
regularizer_object = build_regularizer(elemtype, elem_args, parsed)
366-
434+
regularizer_object = build_regularizer(
435+
elemtype, elem_args, specific_topic_names, background_topic_names
436+
)
367437
regularizers.append(regularizer_object)
368438
model.regularizers.add(regularizer_object, overwrite=True)
369439

370440
topic_model = TopicModel(model)
441+
442+
for score in parsed.data.get('scores', []):
443+
for elemtype, elem_args in score.items():
444+
is_artm_score = elemtype in artm.scores.__all__
445+
score_object = build_score(elemtype, elem_args, is_artm_score)
446+
if is_artm_score:
447+
model.scores.add(score_object, overwrite=True)
448+
else:
449+
topic_model.custom_scores[elemtype] = score_object
450+
371451
for stage in parsed['stages']:
372452
for elemtype, elem_args in stage.items():
373453
settings = build_cube_settings(elemtype.data, elem_args)
@@ -390,6 +470,8 @@ def revalidate_section(parsed, section):
390470
schemas = build_schema_for_cubes()
391471
elif section == "regularizers":
392472
schemas = build_schema_for_regs()
473+
elif section == "scores":
474+
schemas = build_schema_for_scores()
393475
else:
394476
raise ValueError(f"Unknown section name '{section}'")
395477

@@ -398,7 +480,7 @@ def revalidate_section(parsed, section):
398480
name = list(stage.data)[0]
399481

400482
if name not in schemas:
401-
raise ValueError(f"Unsupported stage ID: {name} at line {stage.start_line}")
483+
raise ValueError(f"Unsupported {section} value: {name} at line {stage.start_line}")
402484
local_schema = schemas[name]
403485

404486
stage.revalidate(local_schema)

topicnet/cooking_machine/cubes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base_cube import BaseCube, retrieve_score_for_strategy
22
from .regularizer_cube import RegularizersModifierCube
3+
from .controller_cube import RegularizationControllerCube
34
from .cube_creator import CubeCreator
45
from .perplexity_strategy import PerplexityStrategy
56
from .greedy_strategy import GreedyStrategy

0 commit comments

Comments
 (0)