Skip to content

Commit 867d89e

Browse files
committed
Cleaned up some files
1 parent 1a75606 commit 867d89e

File tree

7 files changed

+63
-52
lines changed

7 files changed

+63
-52
lines changed

ngcsimlib/_src/compartment/compartment.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from ngcsimlib._src.global_state.manager import global_state_manager as gState
33
from ngcsimlib._src.logger import warn
44
import ast
5-
from typing import TypeVar, Union
5+
from typing import TypeVar, Union, Set
66
from ngcsimlib._src.operations.BaseOp import BaseOp
77
from ngcsimlib._src.context.context_manager import global_context_manager as gcm
88

99
T = TypeVar('T')
1010

11+
1112
class Compartment(metaclass=CompartmentMeta):
1213
"""
1314
Compartments exist as a layer between the global state of models and the
@@ -19,17 +20,29 @@ class Compartment(metaclass=CompartmentMeta):
1920
nothing will access it unless the user manually goes looking for it). The
2021
compartment should be reflected in the global state immediately after it is
2122
initialized. Compartments can be flagged as fixed which means that they
22-
exist in the global state and can be used in compiled methods but they can
23+
exist in the global state and can be used in compiled methods, but they can
2324
not be changed after creation.
2425
2526
Args
2627
initial_value: the initial value to set in the global state
2728
2829
fixed (default=False): sets the flag for if this compartment is fixed.
30+
31+
display_name (default=None): sets the display name of the compartment
32+
33+
units (default=None): sets the units of the compartment
34+
35+
plot_method (default=None): sets the plot method of the compartment,
36+
this method is to be used by the processes when monitoring this
37+
compartment to integrate with the plotting system.
2938
"""
30-
def __init__(self, initial_value: T, fixed: bool = False,
31-
display_name=None, units=None, plot_method=None):
32-
self._initial_value = initial_value
39+
def __init__(self, initial_value: T,
40+
fixed: bool = False,
41+
display_name: str | None = None,
42+
units: str | None = None,
43+
plot_method: str | None = None):
44+
45+
self._initial_value: T = initial_value
3346

3447
self.name = None
3548
self._root_target = None
@@ -41,7 +54,7 @@ def __init__(self, initial_value: T, fixed: bool = False,
4154
self.plot_method = plot_method
4255

4356
@property
44-
def root(self):
57+
def root(self) -> str | None:
4558
return self._root_target
4659

4760
@property
@@ -91,7 +104,11 @@ def get(self) -> T:
91104
"""
92105
return self._get_value()
93106

94-
def get_needed_keys(self):
107+
def get_needed_keys(self) -> Set[str]:
108+
"""
109+
Returns: Returns a set of compartment paths that are needed to compute
110+
the value of this compartment
111+
"""
95112
if isinstance(self.target, BaseOp):
96113
return self.target.get_needed_keys()
97114
return set(self.target)
@@ -130,7 +147,7 @@ def __rshift__(self, other):
130147
other.__rrshift__(self)
131148

132149
@property
133-
def target(self):
150+
def target(self) -> Union["BaseOp", str]:
134151
"""
135152
Returns: the current target of the compartment
136153
"""
@@ -156,4 +173,4 @@ def target(self, value: Union["Compartment", "BaseOp", str]):
156173
if isinstance(value, str):
157174
self._target = value
158175

159-
raise ValueError("Invalid compartment target ", value)
176+
raise ValueError("Invalid compartment target ", value)

ngcsimlib/_src/context/context.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class Context(object):
4040
with block. This means that in order to use any of the compiled methods or
4141
processes defined the with block must first be left.
4242
"""
43+
4344
def __new__(cls, name: str, *args, **kwargs):
4445
targetPath = gcm.append_path(addition=name)
4546
if gcm.exists(targetPath):
@@ -59,8 +60,7 @@ def __init__(self, name: str):
5960

6061
self.name = name
6162
self.objects = {}
62-
self._connections = {}
63-
63+
self._connections: Dict[str: Union["Compartment", "BaseOp"]] = {}
6464

6565
def __enter__(self):
6666
self.__previous_path = gcm.current_path
@@ -72,7 +72,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
7272
gcm.step_to(self.__previous_path)
7373
self.__previous_path = None
7474

75-
def recompile(self):
75+
def recompile(self) -> None:
7676
"""
7777
Recompiles all the context aware objects inside the context based on
7878
their priority. The higher the priority is, the sooner it will happen.
@@ -86,38 +86,39 @@ def recompile(self):
8686
"""
8787
priorities = {}
8888

89-
9089
for objectType in self.objects.keys():
9190
_objs = self.get_objects_by_type(objectType)
9291
for objName, obj in _objs.items():
9392
if getattr(obj, "_is_compilable", False):
94-
p = getattr(obj, "_priority", None)
95-
p = 0 if p is None else p
93+
p = getattr(obj, "_priority", None) or 0
9694

9795
if p not in priorities:
9896
priorities[p] = []
9997

10098
priorities[p].append(obj)
10199

102-
103100
keys = sorted(priorities.keys(), reverse=True)
104101
for key in keys:
105102
for obj in priorities[key]:
106103
obj.compile()
107104

108-
def registerObj(self, obj: "ContextAwareObjectMeta"):
105+
def registerObj(self, obj: "ContextAwareObjectMeta") -> bool:
109106
"""
110107
Registers an object in the context. The context automatically sorts the
111-
objects by type through the "_type" field set on the object/class. It
112-
expects to be given a "ContextObjectType" but it is not a requirement.
108+
objects by type through the "_type" field set on the object/class.
109+
Standard practice is to use the predefined decorators or superclasses
110+
found in this library to set this field, but it is not a requirement.
113111
If an unknown type is provided to the context it will still sort it
114112
into a bin with other objects of the same type. (Note: _type can be
115113
either a string or ContextObjectTypes.TYPE, both will be grouped
116114
together)
117115
118116
Args:
119-
obj: The object to register, requires the "_type" field to be
120-
defined and not null
117+
obj: The object to register in the context
118+
119+
Returns:
120+
boolean: marks if the object was successfully registered in the
121+
context
121122
"""
122123
_type = getattr(obj, "_type", None)
123124
if _type is None:
@@ -126,10 +127,12 @@ def registerObj(self, obj: "ContextAwareObjectMeta"):
126127
f"object will be limited. Please use one of the provided "
127128
f"context object types or define your own to ensure "
128129
f"compatability")
129-
return
130+
return False
130131

131-
if not isinstance(_type, ContextObjectTypes) and not (
132-
isinstance(_type, str) and _type in ContextObjectTypes.__members__):
132+
if (not isinstance(_type, ContextObjectTypes) and
133+
_type not in self.objects.keys() and
134+
not (isinstance(_type, str) and _type in ContextObjectTypes.__members__)
135+
):
133136
warn(
134137
f"Context object type {_type} is not known to this context. It will "
135138
f"be stored and tracked but some functionality will be "
@@ -146,9 +149,10 @@ def registerObj(self, obj: "ContextAwareObjectMeta"):
146149
warn(f"Trying to register context object with the same name "
147150
f"({obj.name}) as another object in this context. Aborting "
148151
f"registration!")
149-
return
152+
return False
150153

151154
self.objects[_type][obj.name] = obj
155+
return True
152156

153157
def get_objects_by_type(self, objectType: ContextObjectTypes | str) -> Dict[
154158
str, "ContextAwareObjectMeta"]:
@@ -213,20 +217,18 @@ def get_objects(self, *object_names: str,
213217
return _objs[0]
214218
return _objs
215219

216-
217220
def get_components(self, *component_names: str, unwrap: bool = True) -> \
218221
Union[None, "ContextAwareObjectMeta", List[Union[
219222
"ContextAwareObjectMeta", None]]]:
220223
return self.get_objects(*component_names,
221224
objectType=ContextObjectTypes.component,
222225
unwrap=unwrap)
223226

224-
def add_connection(self, source, destination):
227+
def add_connection(self, source: Union["Compartment", "BaseOp"], destination: "Compartment"):
225228
self._connections[destination.root] = source
226229

227-
228230
def save_to_json(self, directory: str, model_name: Union[str, None] = None,
229-
custom_save: bool = True, overwrite: bool = False):
231+
custom_save: bool = True, overwrite: bool = False) -> None:
230232
"""
231233
Saves the context to a collection fo JSON files.
232234
@@ -256,7 +258,6 @@ def save_to_json(self, directory: str, model_name: Union[str, None] = None,
256258
print('Failed to delete %s. Reason: %s' % (file_path, e))
257259
shutil.rmtree(directory + "/" + model_name)
258260

259-
260261
path = make_unique_path(directory, model_name)
261262

262263
contextMeta = {"types": list(self.objects.keys()),
@@ -276,7 +277,8 @@ def save_to_json(self, directory: str, model_name: Union[str, None] = None,
276277

277278
for obj_name, obj in _objs.items():
278279
objData = {}
279-
if hasattr(obj, "to_json") and callable(getattr(obj, "to_json")):
280+
if hasattr(obj, "to_json") and callable(
281+
getattr(obj, "to_json")):
280282
objData.update(obj.to_json())
281283

282284
objData["modulePath"] = modManager.resolve_public_import(obj)
@@ -303,10 +305,8 @@ def save_to_json(self, directory: str, model_name: Union[str, None] = None,
303305
with open(f"{path}/connections.json", "w") as fp:
304306
json.dump(connections, fp, indent=4)
305307

306-
307-
308308
@staticmethod
309-
def load(directory: str, module_name: str):
309+
def load(directory: str, module_name: str) -> "Context":
310310
if gcm.exists(gcm.append_path(module_name)):
311311
warn("Trying to load a context that already exists, returning "
312312
"existing context")
@@ -330,17 +330,20 @@ def load(directory: str, module_name: str):
330330
kwargs = objData["kwargs"]
331331
newObj = objKlass(*args, **kwargs)
332332

333-
delayed_load.append((getattr(newObj, "_priority", 0), newObj, objData, type_path))
333+
delayed_load.append((
334+
getattr(newObj, "_priority", 0), newObj,
335+
objData, type_path))
334336

335-
delayed_load = sorted(delayed_load, key=lambda x: x[0], reverse=True)
337+
delayed_load = sorted(delayed_load, key=lambda x: x[0],
338+
reverse=True)
336339
for _, obj, data, type_path in delayed_load:
337-
if hasattr(obj, "from_json") and callable(getattr(obj, "from_json")):
340+
if hasattr(obj, "from_json") and callable(
341+
getattr(obj, "from_json")):
338342
obj.from_json(objData)
339343

340344
if hasattr(obj, "load") and callable(getattr(obj, "load")):
341345
obj.load(f"{type_path}/custom")
342346

343-
344347
with open(f"{path}/connections.json", "r") as fp:
345348
connectionData = json.load(fp)
346349
for connectionRoot, target in connectionData.items():
@@ -350,8 +353,4 @@ def load(directory: str, module_name: str):
350353
else:
351354
dest.target = BaseOp.load_op(target)
352355

353-
354356
return ctx
355-
356-
357-

ngcsimlib/_src/context/contextAwareObject.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def to_json(self) -> Dict[str, Any]:
4444
"kwargs": safe_kwargs}
4545
return data
4646

47-
def compile(self):
47+
def compile(self) -> None:
4848
"""
4949
A wrapper to compile this object, unless a custom compiler is being used
5050
do not modify this method.

ngcsimlib/_src/context/contextAwareObjectMeta.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .context_manager import global_context_manager as gcm
55
from collections.abc import Iterable
66

7+
78
def extract_name(cls, args, kwargs):
89
init = cls.__init__
910
sig = inspect.signature(init)
@@ -14,6 +15,7 @@ def extract_name(cls, args, kwargs):
1415
return bound.arguments["name"]
1516
return None
1617

18+
1719
class ContextAwareObjectMeta(type):
1820
def __new__(cls, name, bases, attrs):
1921
if '__enter__' not in attrs:

ngcsimlib/_src/context/contextObjectDecorators.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,6 @@ def process(cls):
1515
cls._type = ContextObjectTypes.process
1616
return cls
1717

18-
# @staticmethod
19-
# def _operation(cls):
20-
# cls._type = ContextObjectTypes.operation
21-
# return cls
22-
23-
2418

2519
component = ContextObjectDecorators.component
2620
process = ContextObjectDecorators.process
27-
# operation = ContextObjectDecorators._operation

ngcsimlib/_src/parser/contextTransformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def visit_Call(self, node):
106106
subAttr = getattr(attr, node.func.attr)
107107
if not hasattr(subAttr, "compiled"):
108108
error("Attempting to use a method of a subcomponent that is not compiled/compilable")
109-
109+
110110
method_id = f"{attr.context_path.replace(':', '_')}_{node.func.attr}"
111111
subAst = subAttr.compiled.ast
112112
subAst.body[0].body = subAst.body[0].body[:-1]

ngcsimlib/_src/parser/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class CompiledMethod:
2020
def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals):
2121
self._fn = fn
2222
self._fn_ast = fn_ast
23-
self._auxiliary_ast = auxiliary_ast if auxiliary_ast is not None else {}
23+
self._auxiliary_ast = auxiliary_ast or {}
2424
self._namespace = namespace
2525
self._extra_globals = extra_globals
2626

0 commit comments

Comments
 (0)