Skip to content

Commit b4db598

Browse files
committed
Refactor modello helpers and extend tests
1 parent 0fa21c8 commit b4db598

File tree

2 files changed

+149
-58
lines changed

2 files changed

+149
-58
lines changed

modello.py

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class BoundInstanceDummy(InstanceDummy):
2929

3030

3131
class ModelloMetaNamespace(dict):
32-
"""This is so that Modello class definitions implicitly define symbols."""
32+
"""Namespace used when building :class:`Modello` subclasses."""
3333

3434
def __init__(self, name: str, bases: typing.Tuple[type, ...]) -> None:
3535
"""Create a namespace for a Modello class to use."""
@@ -98,6 +98,13 @@ def __setitem__(self, key: str, value: object) -> None:
9898
proxy_dummy = InstanceDummy(f"{key}_{class_dummy.name}", **class_dummy.assumptions0)
9999
setattr(proxy, attr_name, proxy_dummy)
100100
dummy_map[class_dummy] = proxy_dummy
101+
# include proxies for child models so expressions can be substituted
102+
for _nested_attr, (_, child_map) in model_cls._modello_nested_models.items():
103+
for child_dummy, child_proxy in child_map.items():
104+
proxy_dummy = InstanceDummy(
105+
f"{key}_{child_proxy.name}", **child_proxy.assumptions0
106+
)
107+
dummy_map[child_proxy] = proxy_dummy
101108
for attr_name, expr in model_cls._modello_namespace.attrs.items():
102109
if attr_name not in model_cls._modello_namespace.dummies:
103110
setattr(proxy, attr_name, expr.subs(dummy_map))
@@ -148,44 +155,56 @@ class Modello(ModelloSentinelClass, metaclass=ModelloMeta):
148155
# ------------------------------------------------------------------
149156
# Private helpers used by ``__init__``
150157
# ------------------------------------------------------------------
151-
def _parse_nested_values(self, value_map: typing.Dict[str, Basic]) -> typing.Dict[str, typing.Union["Modello", typing.Dict[str, Basic]]]:
152-
"""Extract nested model data from ``value_map``.
158+
@classmethod
159+
def _parse_nested_values(
160+
cls, value_map: typing.Mapping[str, Basic]
161+
) -> tuple[
162+
typing.Dict[str, typing.Union["Modello", typing.Dict[str, Basic]]],
163+
typing.Dict[str, Basic],
164+
]:
165+
"""Split ``value_map`` into nested and local values.
153166
154-
Unknown or ``None`` values are replaced with an empty mapping.
155-
The extracted values are removed from ``value_map``.
167+
Attributes matching ``cls._modello_nested_models`` are returned in the
168+
``nested`` dictionary. Unknown or ``None`` values produce an empty mapping
169+
entry. The returned ``values`` mapping contains only attributes local to
170+
this model.
156171
"""
157172

158-
nested_values: typing.Dict[str, typing.Union["Modello", typing.Dict[str, Basic]]] = {}
159-
for attr in self._modello_nested_models:
160-
raw = value_map.pop(attr, {})
161-
if isinstance(raw, Modello):
162-
nested_values[attr] = raw
163-
elif isinstance(raw, dict):
164-
nested_values[attr] = raw
173+
nested: dict[str, typing.Union["Modello", typing.Dict[str, Basic]]] = {}
174+
remaining: dict[str, Basic] = {}
175+
for key, val in value_map.items():
176+
if key in cls._modello_nested_models:
177+
if isinstance(val, Modello):
178+
nested[key] = val
179+
elif isinstance(val, dict):
180+
nested[key] = val
181+
else:
182+
nested[key] = {}
165183
else:
166-
nested_values[attr] = {}
167-
return nested_values
184+
remaining[key] = val
185+
return nested, remaining
168186

187+
@classmethod
169188
def _create_instance_dummies(
170-
self, name: str
189+
cls, name: str
171190
) -> tuple[
172191
typing.Dict[InstanceDummy, BoundInstanceDummy],
173192
typing.Dict[str, typing.Dict[InstanceDummy, BoundInstanceDummy]],
174193
]:
175194
"""Bind all ``InstanceDummy`` objects to this instance.
176195
177-
Returns a tuple of ``instance_dummies`` for this model and a mapping for
178-
each nested model that relates the child's class dummies to their bound
179-
counterparts.
196+
Returns a tuple containing a mapping of this model's dummies to their
197+
bound counterparts and a mapping for each nested model that relates the
198+
child's class dummies to their bound dummies.
180199
"""
181200

182201
instance_dummies: dict[InstanceDummy, BoundInstanceDummy] = {
183202
class_dummy: class_dummy.bound(name)
184-
for class_dummy in self._modello_namespace.dummies.values()
203+
for class_dummy in cls._modello_namespace.dummies.values()
185204
}
186205

187206
nested_dummy_map: dict[str, dict[InstanceDummy, BoundInstanceDummy]] = {}
188-
for attr, (_, mapping) in self._modello_nested_models.items():
207+
for attr, (_, mapping) in cls._modello_nested_models.items():
189208
dummy_map: dict[InstanceDummy, BoundInstanceDummy] = {}
190209
for class_dummy, proxy_dummy in mapping.items():
191210
bound = proxy_dummy.bound(name)
@@ -195,25 +214,24 @@ def _create_instance_dummies(
195214

196215
return instance_dummies, nested_dummy_map
197216

217+
@classmethod
198218
def _collect_instance_constraints(
199-
self,
200-
value_map: typing.Dict[str, Basic],
201-
nested_values: typing.Dict[str, typing.Union["Modello", typing.Dict[str, Basic]]],
219+
cls,
220+
value_map: typing.Mapping[str, Basic],
221+
nested_values: typing.Mapping[str, typing.Union["Modello", typing.Dict[str, Basic]]],
202222
instance_dummies: typing.Dict[InstanceDummy, BoundInstanceDummy],
203-
nested_dummy_map: typing.Dict[str, typing.Dict[InstanceDummy, BoundInstanceDummy]],
204223
) -> typing.Dict[BoundInstanceDummy, Basic]:
205224
"""Convert provided values into instance constraints."""
206225

207226
constraints: dict[BoundInstanceDummy, Basic] = {}
208227

209228
for attr, value in value_map.items():
210229
simplified = simplify(value).subs(instance_dummies)
211-
value_map[attr] = simplified
212-
class_dummy = getattr(self, attr)
230+
class_dummy = getattr(cls, attr)
213231
constraints[instance_dummies[class_dummy]] = simplified
214232

215233
for attr, data in nested_values.items():
216-
model_cls, mapping = self._modello_nested_models[attr]
234+
model_cls, mapping = cls._modello_nested_models[attr]
217235
if isinstance(data, Modello):
218236
for child_attr, class_dummy in model_cls._modello_namespace.dummies.items():
219237
proxy_dummy = mapping[class_dummy]
@@ -227,28 +245,28 @@ def _collect_instance_constraints(
227245

228246
return constraints
229247

248+
@classmethod
230249
def _build_constraints(
231-
self,
250+
cls,
232251
instance_dummies: typing.Dict[InstanceDummy, BoundInstanceDummy],
233252
instance_constraints: typing.Dict[BoundInstanceDummy, Basic],
234-
nested_dummy_map: typing.Dict[str, typing.Dict[InstanceDummy, BoundInstanceDummy]],
235253
) -> list[Eq]:
236254
"""Compile a list of equations representing this instance."""
237255

238256
constraints = [
239257
Eq(instance_dummies[class_dummy], value.subs(instance_dummies))
240-
for class_dummy, value in self._modello_class_constraints.items()
258+
for class_dummy, value in cls._modello_class_constraints.items()
241259
]
242-
243-
for attr, (model_cls, mapping) in self._modello_nested_models.items():
260+
for attr, (model_cls, mapping) in cls._modello_nested_models.items():
244261
for class_dummy, expr in model_cls._modello_class_constraints.items():
245262
proxy_dummy = mapping[class_dummy]
246263
constraints.append(Eq(instance_dummies[proxy_dummy], expr.subs(mapping).subs(instance_dummies)))
247264

248265
constraints.extend(Eq(d, v) for d, v in instance_constraints.items())
249266
return constraints
250267

251-
def _solve(self, constraints: list[Eq]) -> dict[BoundInstanceDummy, Basic]:
268+
@staticmethod
269+
def _solve(constraints: list[Eq]) -> dict[BoundInstanceDummy, Basic]:
252270
"""Solve ``constraints`` and return a mapping of dummies to values."""
253271

254272
if not constraints:
@@ -259,37 +277,41 @@ def _solve(self, constraints: list[Eq]) -> dict[BoundInstanceDummy, Basic]:
259277
raise ValueError(f"{len(solutions)} solutions")
260278
return solutions[0]
261279

280+
@classmethod
262281
def _assign_local_values(
263-
self,
282+
cls,
264283
solution: dict[BoundInstanceDummy, Basic],
265284
instance_dummies: typing.Dict[InstanceDummy, BoundInstanceDummy],
266285
instance_constraints: typing.Dict[BoundInstanceDummy, Basic],
267-
) -> None:
268-
"""Set resolved values on this instance's own attributes."""
286+
) -> typing.Dict[str, Basic]:
287+
"""Return resolved values for this model's own attributes."""
269288

270-
for attr, class_dummy in self._modello_namespace.dummies.items():
289+
values: dict[str, Basic] = {}
290+
for attr, class_dummy in cls._modello_namespace.dummies.items():
271291
instance_dummy = instance_dummies[class_dummy]
272292
if instance_dummy in solution:
273293
value = solution[instance_dummy]
274294
elif instance_dummy in instance_constraints:
275295
value = instance_constraints[instance_dummy]
276-
elif class_dummy in self._modello_class_constraints:
277-
value = self._modello_class_constraints[class_dummy].subs(instance_dummies)
296+
elif class_dummy in cls._modello_class_constraints:
297+
value = cls._modello_class_constraints[class_dummy].subs(instance_dummies)
278298
else:
279299
value = instance_dummy
280-
setattr(self, attr, value)
300+
values[attr] = value
301+
return values
281302

303+
@classmethod
282304
def _instantiate_nested_models(
283-
self,
305+
cls,
284306
name: str,
285307
solution: dict[BoundInstanceDummy, Basic],
286308
instance_dummies: typing.Dict[InstanceDummy, BoundInstanceDummy],
287309
instance_constraints: typing.Dict[BoundInstanceDummy, Basic],
288-
nested_dummy_map: typing.Dict[str, typing.Dict[InstanceDummy, BoundInstanceDummy]],
289-
) -> None:
290-
"""Create nested model instances using solved values."""
310+
) -> typing.Dict[str, "Modello"]:
311+
"""Instantiate nested models and return them."""
291312

292-
for attr, (model_cls, mapping) in self._modello_nested_models.items():
313+
nested_instances: dict[str, Modello] = {}
314+
for attr, (model_cls, mapping) in cls._modello_nested_models.items():
293315
value_kwargs: dict[str, Basic] = {}
294316
for child_attr, class_dummy in model_cls._modello_namespace.dummies.items():
295317
proxy_dummy = mapping[class_dummy]
@@ -303,30 +325,35 @@ def _instantiate_nested_models(
303325
else:
304326
val = inst_dummy
305327
value_kwargs[child_attr] = val
306-
nested_instance = model_cls(f"{name}_{attr}", **value_kwargs)
307-
setattr(self, attr, nested_instance)
328+
nested_instances[attr] = model_cls(f"{name}_{attr}", **value_kwargs)
329+
return nested_instances
308330

309331
# ------------------------------------------------------------------
310332
# ``__init__`` orchestrates the above helpers
311333
# ------------------------------------------------------------------
312334
def __init__(self, name: str, **value_map: Basic) -> None:
313335
"""Initialise a model instance and solve for all attributes."""
314336

315-
nested_values = self._parse_nested_values(value_map)
337+
nested_values, local_values = self._parse_nested_values(value_map)
316338
instance_dummies, nested_dummy_map = self._create_instance_dummies(name)
317339
self._modello_instance_dummies = instance_dummies
318340

319341
instance_constraints = self._collect_instance_constraints(
320-
value_map, nested_values, instance_dummies, nested_dummy_map
342+
local_values, nested_values, instance_dummies
321343
)
322344
self._modello_instance_constraints = instance_constraints
323345

324-
constraints = self._build_constraints(instance_dummies, instance_constraints, nested_dummy_map)
346+
constraints = self._build_constraints(instance_dummies, instance_constraints)
325347
self._modello_constraints = constraints
326348

327349
solution = self._solve(constraints)
328350

329-
self._assign_local_values(solution, instance_dummies, instance_constraints)
330-
self._instantiate_nested_models(
331-
name, solution, instance_dummies, instance_constraints, nested_dummy_map
351+
values = self._assign_local_values(solution, instance_dummies, instance_constraints)
352+
for attr, val in values.items():
353+
setattr(self, attr, val)
354+
355+
nested_instances = self._instantiate_nested_models(
356+
name, solution, instance_dummies, instance_constraints
332357
)
358+
for attr, instance in nested_instances.items():
359+
setattr(self, attr, instance)

test_modello.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,10 @@ class Parent(Modello):
106106
child = Child
107107
y = InstanceDummy("y")
108108

109-
instance = object.__new__(Parent)
110109
values = {"child": {"x": 5}, "y": 3}
111-
nested = instance._parse_nested_values(values)
110+
nested, remaining = Parent._parse_nested_values(values)
112111
assert nested == {"child": {"x": 5}}
113-
assert values == {"y": 3}
112+
assert remaining == {"y": 3}
114113

115114

116115
def test_helper_create_instance_dummies():
@@ -123,9 +122,74 @@ class Parent(Modello):
123122
child = Child
124123
y = InstanceDummy("y")
125124

126-
instance = object.__new__(Parent)
127-
dummies, nested = instance._create_instance_dummies("X")
125+
dummies, nested = Parent._create_instance_dummies("X")
128126
assert dummies[Parent._modello_namespace.dummies["y"]].name.startswith("X_")
129127
child_dummy = Child._modello_namespace.dummies["x"]
130128
assert child_dummy in nested["child"]
131129
assert nested["child"][child_dummy].name.startswith("X_")
130+
131+
132+
def test_helper_collect_instance_constraints():
133+
"""_collect_instance_constraints handles nested dictionaries."""
134+
135+
class Child(Modello):
136+
a = InstanceDummy("a")
137+
b = InstanceDummy("b")
138+
139+
class Parent(Modello):
140+
child = Child
141+
c = InstanceDummy("c")
142+
143+
dummies, nested_map = Parent._create_instance_dummies("P")
144+
constraints = Parent._collect_instance_constraints(
145+
{"c": 4}, {"child": {"a": 1}}, dummies
146+
)
147+
assert constraints[dummies[Parent._modello_namespace.dummies["c"]]] == 4
148+
child_dummy = Child._modello_namespace.dummies["a"]
149+
proxy_dummy = nested_map["child"][child_dummy]
150+
assert constraints[proxy_dummy] == 1
151+
152+
153+
def test_meta_setitem_creates_proxy():
154+
"""Assigning a model class creates proxy dummies."""
155+
156+
class Child(Modello):
157+
x = InstanceDummy("x")
158+
159+
class Parent(Modello):
160+
child = Child
161+
162+
proxy = Parent._modello_namespace.other_attrs["child"]
163+
assert type(proxy).__name__ == "ChildProxy"
164+
assert Parent._modello_namespace.nested_models["child"][0] is Child
165+
for dummy in Child._modello_namespace.dummies.values():
166+
proxy_dummy = Parent._modello_namespace.nested_models["child"][1][dummy]
167+
assert getattr(proxy, dummy.name) is proxy_dummy
168+
169+
170+
def test_nested_multiple_levels():
171+
"""Nested models work across more than one level."""
172+
173+
class Leaf(Modello):
174+
a = InstanceDummy("a")
175+
b = InstanceDummy("b")
176+
total = a + b
177+
178+
class Branch(Modello):
179+
leaf = Leaf
180+
offset = InstanceDummy("offset")
181+
total = leaf.total + offset
182+
183+
class Tree(Modello):
184+
left = Branch
185+
right = Branch
186+
whole = left.total + right.total
187+
188+
left = Branch("L", leaf={"a": 1, "b": 2}, offset=1)
189+
right = Branch("R", leaf={"a": 2, "b": 3}, offset=2)
190+
tree = Tree("T", left=left, right=right)
191+
assert tree.left.leaf.total == 3
192+
assert tree.left.total == 4
193+
assert tree.right.leaf.total == 5
194+
assert tree.right.total == 7
195+
assert tree.whole == 11

0 commit comments

Comments
 (0)