Skip to content

Commit bfa57c0

Browse files
committed
Added in hyperparams to components
1 parent 918ced6 commit bfa57c0

File tree

3 files changed

+132
-22
lines changed

3 files changed

+132
-22
lines changed

ngcsimlib/commands/seed.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from ngcsimlib.commands.command import Command
2+
from ngcsimlib.utils import extract_args
3+
4+
5+
class Seed(Command):
6+
"""
7+
In many models there is a need to seed the randomness of a model. While many
8+
components will take seeds in at construction these are not always serializable
9+
or there might be a need to reseed the model after initialization. To solve
10+
this problem ngcsimlib offers the Seed command. This command will simply go
11+
through all the provided components and call see with the specified value.
12+
"""
13+
14+
def __init__(self, components=None, seed_name=None,
15+
command_name=None, **kwargs):
16+
"""
17+
Required calls on Components: ['seed', 'name']
18+
19+
Args:
20+
components: a list of components to call seed on
21+
22+
seed_name: a keyword to bind the input for this command do
23+
24+
command_name: the name of the command on the controller
25+
26+
"""
27+
super().__init__(components=components, command_name=command_name,
28+
required_calls=['seed'])
29+
if seed_name is None:
30+
raise RuntimeError(
31+
self.name + " requires a \'seed_name\' to bind to for construction")
32+
33+
self.seed_name = seed_name
34+
35+
def __call__(self, *args, **kwargs):
36+
try:
37+
vals = extract_args([self.seed_name], *args, **kwargs)
38+
except RuntimeError:
39+
raise RuntimeError(self.name + ", " + str(
40+
self.seed_name) + " is missing from keyword arguments or a positional "
41+
"arguments can be provided")
42+
43+
for component in self.components:
44+
self.components[component].seed(vals[self.seed_name])
45+

ngcsimlib/controller.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,21 @@ def make_components(self, path_to_components_file, custom_file_dir=None):
115115
used. (Default: None)
116116
"""
117117
with open(path_to_components_file, 'r') as file:
118-
components = json.load(file)
118+
componentsConfig = json.load(file)
119+
120+
parameterMap = {}
121+
components = componentsConfig["components"]
122+
if "hyperparameters" in componentsConfig.keys():
123+
for component in components:
124+
for pKey, pValue in componentsConfig["hyperparameters"].items():
125+
for cKey, cValue in component.items():
126+
if pKey == cValue:
127+
component[cKey] = pValue
128+
parameterMap[cKey] = pKey
129+
119130
for component in components:
120-
self.add_component(**component, directory=custom_file_dir)
131+
self.add_component(**component, directory=custom_file_dir,
132+
parameterMap=parameterMap)
121133

122134
def make_steps(self, path_to_steps_file):
123135
"""
@@ -197,13 +209,11 @@ def add_component(self, component_type, match_case=False, absolute_path=False, *
197209
raise RuntimeError("Given component type " + str(component_type)
198210
+ " is not callable")
199211

200-
201-
202-
count = call.__code__.co_argcount - 1
203-
named_args = call.__code__.co_varnames[1:count]
204212
try:
205213
component = Component_class(**kwargs)
206214
except TypeError as E:
215+
count = call.__code__.co_argcount - 1
216+
named_args = call.__code__.co_varnames[1:count]
207217
print(E)
208218
raise RuntimeError(str(E) + "\nProvided keyword arguments:\t" + str(list(kwargs.keys())) +
209219
"\nRequired keyword arguments:\t" + str(list(named_args)))
@@ -217,6 +227,9 @@ def add_component(self, component_type, match_case=False, absolute_path=False, *
217227
del obj[key]
218228
print("Failed to serialize \"" + str(key) + "\" in " + component.name)
219229

230+
if "directory" in obj.keys():
231+
del obj["directory"]
232+
220233
self._json_objects['components'].append(obj)
221234

222235
return component
@@ -347,7 +360,45 @@ def save_to_json(self, directory, model_name=None, custom_save=True):
347360
json.dump(self._json_objects['steps'], fp, indent=4)
348361

349362
with open(path + "/components.json", 'w') as fp:
350-
json.dump(self._json_objects['components'], fp, indent=4)
363+
hyperparameters = {}
364+
365+
for idx, component in enumerate(self._json_objects['components']):
366+
if component.get('parameterMap', None) is not None:
367+
for cKey, pKey in component['parameterMap'].items():
368+
pVal = component[cKey]
369+
if pKey not in hyperparameters.keys():
370+
hyperparameters[pKey] = []
371+
hyperparameters[pKey].append((idx, cKey, pVal))
372+
373+
hp = {}
374+
for param in hyperparameters.keys():
375+
matched = True
376+
hp[param] = None
377+
for _, _, pVal in hyperparameters[param]:
378+
if hp[param] is None:
379+
hp[param] = pVal
380+
elif hp[param] != pVal:
381+
del hp[param]
382+
matched = False
383+
break
384+
385+
for idx, cKey, _ in hyperparameters[param]:
386+
if matched:
387+
self._json_objects['components'][idx][cKey] = param
388+
389+
else:
390+
warnings.warn("Unable to extract hyperparameter " + str(param) +
391+
" as it is mismatched between components. Parameter will not be extracted")
392+
393+
for component in self._json_objects['components']:
394+
if "parameterMap" in component.keys():
395+
del component["parameterMap"]
396+
397+
obj = {"components": self._json_objects['components']}
398+
if len(hp.keys()) != 0:
399+
obj["hyperparameters"] = hp
400+
401+
json.dump(obj, fp, indent=4)
351402

352403
with open(path + "/connections.json", 'w') as fp:
353404
json.dump(self._json_objects['connections'], fp, indent=4)
Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
11
{
2-
title: "Components",
3-
description: "A list of all components",
4-
type: "array",
5-
items: {
6-
description: "The two required values for creating a components, and then all other keyword arguments needed for to build the components",
7-
type: "object",
8-
required: ["component_type", "name"],
9-
properties: {
10-
component_type: {
11-
type: "string"
12-
},
13-
name: {
14-
type: "string"
15-
},
2+
"$schema": "https://json-schema.org/draft/2020-12/schema",
3+
"title": "Components",
4+
"description": "A collection of components for the model, and the hyperparameters needed to build them",
5+
"type": "object",
6+
"properties": {
7+
"hyperparameters": {
8+
"description": "A mapping of parameter keys to values. To use these simply put the key of the value parameter in for the value and it will automatically be picked up",
9+
"type": "object"
10+
},
11+
"components": {
12+
"description": "A list of all components",
13+
"type": "array",
14+
"items": {
15+
"description": "The two required values for creating a components, and then all other keyword arguments needed for to build the components",
16+
"type": "object",
17+
"required": [
18+
"component_type",
19+
"name"
20+
],
21+
"properties": {
22+
"component_type": {
23+
"type": "string"
24+
},
25+
"name": {
26+
"type": "string"
27+
},
28+
}
29+
}
1630
}
1731
}
18-
}
32+
}

0 commit comments

Comments
 (0)