Skip to content

Commit 8577741

Browse files
committed
feat: support tensor composition
1 parent ccc08da commit 8577741

File tree

4 files changed

+68
-49
lines changed

4 files changed

+68
-49
lines changed

src/opvious/modeling/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
total,
2323
)
2424
from .definitions import (
25+
ComposeFunction,
2526
Constraint,
2627
Dimension,
2728
Image,
@@ -31,7 +32,9 @@
3132
TensorLike,
3233
Variable,
3334
alias,
35+
compose,
3436
constraint,
37+
infer_quantifiables,
3538
interval,
3639
objective,
3740
)
@@ -81,5 +84,8 @@
8184
# Fragments
8285
"fragments",
8386
# Utilities
87+
"ComposeFunction",
88+
"compose",
89+
"infer_quantifiables",
8490
"method_decorator",
8591
]

src/opvious/modeling/definitions.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,6 @@ def total_product_count(self):
214214
The number of arguments must match the tensor's quantification.
215215
"""
216216

217-
# TODO: Add map method, which appends to _mappers array of transformations.
218-
# Once implemented, remove the negate arguments to transformations.
219-
220217
def __init__(
221218
self,
222219
image: Image,
@@ -387,6 +384,46 @@ def render_statement(self, label: Label) -> str | None:
387384
return s
388385

389386

387+
type ComposeFunction = Callable[[*tuple[Expression, ...]], Expression]
388+
389+
390+
def compose(fn: ComposeFunction, *tensors: TensorLike) -> TensorLike:
391+
"""Maps a function over one or more tensors
392+
393+
This is particularly useful when creating model fragments which use
394+
tensor-likes as arguments. For example it can be used to flip the indicator
395+
in an :class:`~opvious.modeling.fragments.ActivatedVariable`.
396+
"""
397+
if not tensors:
398+
raise ValueError("No tensors")
399+
return _Composed(fn, tensors)
400+
401+
402+
@dataclasses.dataclass(frozen=True)
403+
class _Composed:
404+
function: ComposeFunction
405+
tensors: Sequence[TensorLike]
406+
407+
def __call__(self, *subscripts: ExpressionLike) -> Expression:
408+
exprs = [tensor(*subscripts) for tensor in self.tensors]
409+
return self.function(*exprs)
410+
411+
412+
def infer_quantifiables(tensor: TensorLike) -> tuple[Quantifiable, ...]:
413+
"""Infers a tensor-like's underlying quantification
414+
415+
Args:
416+
tensor: A tensor or composed tensor with only unary functions
417+
"""
418+
match tensor:
419+
case Tensor():
420+
return tensor.quantifiables()
421+
case _Composed(_, tensors) if len(tensors) == 1:
422+
return infer_quantifiables(tensors[0])
423+
case _:
424+
raise Exception(f"Unable to infer quantifiables from {tensor}")
425+
426+
390427
def _is_simple_domain(d: Domain) -> bool:
391428
if d.mask is not None:
392429
return False

src/opvious/modeling/fragments.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
Variable,
3333
alias,
3434
constraint,
35+
infer_quantifiables,
3536
interval,
3637
)
3738
from .identifiers import Name
@@ -157,9 +158,9 @@ def __init__(
157158
lower_bound: bool = True,
158159
upper_bound: bool = True,
159160
) -> None:
161+
if not quantifiables:
162+
quantifiables = infer_quantifiables(tensor)
160163
if isinstance(tensor, Tensor):
161-
if not quantifiables:
162-
quantifiables = tensor.quantifiables()
163164
if not image:
164165
image = tensor.image
165166
if lower_bound and tensor.image.lower_bound == 0:
@@ -246,9 +247,9 @@ class ActivationVariable(ModelFragment):
246247
used, if `False` no activation constraint will be added.
247248
lower_bound: Value of the lower bound used in the deactivation
248249
constraint. If `True` the variable's image's lower bound will be
249-
used, if `False` no deactivation constraint will be added.
250+
used, if `False` no deactivation constraint will be added. Can only
251+
be used if the tensor is non-negative.
250252
name: Name of the generated activation variable
251-
negate: Negate the returned indicator variable.
252253
projection: Mask used to project the variable's quantification. When
253254
this is set, the indicator variable will be set to 1 iff at least
254255
one of the projected tensor values is positive.
@@ -263,11 +264,10 @@ def __new__(
263264
upper_bound: ExpressionLike | TensorLike | bool = True,
264265
lower_bound: ExpressionLike | TensorLike | bool = False,
265266
name: Name | None = None,
266-
negate: bool = False,
267267
projection: Projection = -1,
268268
) -> ActivationVariable:
269-
if not quantifiables and isinstance(tensor, Tensor):
270-
quantifiables = tensor.quantifiables()
269+
if not quantifiables:
270+
quantifiables = infer_quantifiables(tensor)
271271
domains = tuple(domain(q) for q in quantifiables)
272272

273273
def quantification(
@@ -299,8 +299,7 @@ def activates(self) -> Quantified:
299299
bound = bound(*cp.lifted)
300300
elif bound is True:
301301
bound = tensor_image().upper_bound
302-
value = 1 - self.value(*cp) if negate else self.value(*cp)
303-
yield bound * value >= tensor(*cp.lifted)
302+
yield bound * self.value(*cp) >= tensor(*cp.lifted)
304303

305304
@constraint(disabled=lower_bound is False)
306305
def deactivates(self) -> Quantified:
@@ -317,8 +316,7 @@ def deactivates(self) -> Quantified:
317316
bound = bound(*cp)
318317
elif bound is True:
319318
bound = tensor_image().lower_bound
320-
value = 1 - self.value(*cp) if negate else self.value(*cp)
321-
yield bound * value <= term
319+
yield bound * self.value(*cp) <= term
322320

323321
return _Fragment()
324322

@@ -362,7 +360,6 @@ def activation_variable(
362360
upper_bound: ExpressionLike | TensorLike | bool = True,
363361
lower_bound: ExpressionLike | TensorLike | bool = False,
364362
name: Name | None = None,
365-
negate: bool = False,
366363
projection: Projection = -1,
367364
) -> Callable[[TensorLike], ActivationVariable]:
368365
"""Transforms a method into an :class:`ActivationVariable` fragment
@@ -379,7 +376,6 @@ def wrapper(fn: TensorLike) -> ActivationVariable:
379376
lower_bound=lower_bound,
380377
upper_bound=upper_bound,
381378
name=name,
382-
negate=negate,
383379
projection=projection,
384380
)
385381

@@ -416,8 +412,8 @@ def __init__(
416412
raise NotImplementedError() # TODO: Implement.
417413

418414
self._tensor = tensor
419-
if not quantifiables and isinstance(tensor, Tensor):
420-
quantifiables = tensor.quantifiables()
415+
if not quantifiables:
416+
quantifiables = infer_quantifiables(tensor)
421417
self._domains = tuple(domain(q) for q in quantifiables)
422418

423419
self._piece_count = Parameter.discrete(
@@ -528,7 +524,6 @@ class ActivatedVariable(ModelFragment):
528524
subscripts
529525
upper_bound: Tensor upper bound, can be omitted if `tensor` is a
530526
:class:`~opvious.modeling.Tensor` instance
531-
negate: Negate the input indicator
532527
force_activation: Add constraint to ensure that the derived variable is
533528
at least equal to `tensor` when `indicator` is non-zero. You may
534529
choose to omit this if the variable is already pushed up via other
@@ -549,13 +544,11 @@ def __init__(
549544
upper_bound: ExpressionLike | None = None,
550545
force_activation: bool = True,
551546
force_deactivation: bool = True,
552-
negate: bool = False,
553547
name: Name | None = None,
554548
) -> None:
555549
self._tensor = tensor
556550
self._indicator = indicator
557551
self._indicator_projection = indicator_projection
558-
self._negate = negate
559552
self._force_activation = force_activation
560553
self._force_deactivation = force_deactivation
561554

@@ -593,12 +586,9 @@ def deactivates(self) -> Quantified:
593586
This constraint will be omitted if `force_deactivation` is false.
594587
"""
595588
for cp in self._quantification():
596-
toggle = (
597-
1 - self._indicator(*cp)
598-
if self._negate
599-
else self._indicator(*cp)
600-
)
601-
yield self.value(*cp.lifted) <= self._upper_bound * toggle
589+
yield self.value(
590+
*cp.lifted
591+
) <= self._upper_bound * self._indicator(*cp)
602592

603593
@constraint(lambda init, self: init(disabled=not self._force_activation))
604594
def activates(self) -> Quantified:
@@ -607,14 +597,10 @@ def activates(self) -> Quantified:
607597
This constraint will be omitted if `force_activation` is false.
608598
"""
609599
for cp in self._quantification():
610-
toggle = (
611-
self._indicator(*cp)
612-
if self._negate
613-
else 1 - self._indicator(*cp)
614-
)
615600
yield (
616601
self.value(*cp.lifted)
617-
>= self._tensor(*cp.lifted) - self._upper_bound * toggle
602+
>= self._tensor(*cp.lifted)
603+
- self._upper_bound * (1 - self._indicator(*cp))
618604
)
619605

620606

@@ -624,7 +610,6 @@ def activated_variable(
624610
indicator: Tensor,
625611
indicator_projection: Projection = -1,
626612
upper_bound: ExpressionLike | None = None,
627-
negate: bool = False,
628613
force_activation: bool = True,
629614
force_deactivation: bool = True,
630615
name: Name | None = None,
@@ -641,7 +626,6 @@ def wrapper(fn: TensorLike) -> ActivatedVariable:
641626
indicator=indicator,
642627
indicator_projection=indicator_projection,
643628
upper_bound=upper_bound,
644-
negate=negate,
645629
force_activation=force_activation,
646630
force_deactivation=force_deactivation,
647631
name=name,

tests/modeling_test.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -259,29 +259,21 @@ def child_starts_after_parent_ends(self):
259259
@om.fragments.activation_variable(
260260
lambda init, self: init(
261261
self.competing_tasks,
262-
negate=True,
263262
upper_bound=self.duration.total(),
264263
)
265264
)
266-
def must_start_after(self, t1, t2):
265+
# t2 ends after t1 starts => 1
266+
# 0 => t2 ends before t1 starts
267+
def may_start_before_end(self, t1, t2):
267268
return self.task_end(t2) - self.task_start(t1)
268269

269-
@om.fragments.activation_variable(
270-
lambda init, self: init(
271-
self.competing_tasks,
272-
negate=True,
273-
upper_bound=self.duration.total(),
274-
)
275-
)
276-
def must_end_before(self, t1, t2):
277-
return self.task_end(t1) - self.task_start(t2)
278-
279270
@om.constraint
280271
def one_active_task_per_machine(self):
281272
for t1, t2 in self.competing_tasks:
282273
yield (
283-
self.must_end_before(t1, t2) + self.must_start_after(t1, t2)
284-
>= 1
274+
self.may_start_before_end(t1, t2) +
275+
self.may_start_before_end(t2, t1)
276+
<= 1
285277
)
286278

287279
@om.objective

0 commit comments

Comments
 (0)