|
5 | 5 | from random import choice, shuffle
|
6 | 6 | from typing import List, Union
|
7 | 7 | from warnings import warn
|
| 8 | +from copy import deepcopy |
8 | 9 |
|
9 | 10 | import numpy as np
|
10 | 11 | from constraint import (
|
@@ -92,10 +93,11 @@ def __init__(
|
92 | 93 | self._tensorspace_param_config_structure = []
|
93 | 94 | self._map_tensor_to_param = {}
|
94 | 95 | self._map_param_to_tensor = {}
|
95 |
| - self.restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions |
96 |
| - self.original_restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions |
| 96 | + restrictions = list(restrictions) if not isinstance(restrictions, (list, tuple)) else restrictions |
| 97 | + self.restrictions = deepcopy(restrictions) |
| 98 | + self.original_restrictions = deepcopy(restrictions) # keep the original restrictions, so that the searchspace can be modified later |
97 | 99 | # the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
|
98 |
| - self._modified_restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions |
| 100 | + self._modified_restrictions = deepcopy(restrictions) |
99 | 101 | self.param_names = list(self.tune_params.keys())
|
100 | 102 | self.params_values = tuple(tuple(param_vals) for param_vals in self.tune_params.values())
|
101 | 103 | self.params_values_indices = None
|
@@ -479,8 +481,9 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
|
479 | 481 |
|
480 | 482 | def __add_restrictions(self, parameter_space: Problem) -> Problem:
|
481 | 483 | """Add the user-specified restrictions as constraints on the parameter space."""
|
482 |
| - if isinstance(self.restrictions, list): |
483 |
| - for restriction in self.restrictions: |
| 484 | + restrictions = deepcopy(self.restrictions) |
| 485 | + if isinstance(restrictions, list): |
| 486 | + for restriction in restrictions: |
484 | 487 | required_params = self.param_names
|
485 | 488 |
|
486 | 489 | # (un)wrap where necessary
|
@@ -510,14 +513,14 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
|
510 | 513 | raise ValueError(f"Unrecognized restriction type {type(restriction)} ({restriction})")
|
511 | 514 |
|
512 | 515 | # if the restrictions are the old monolithic function, apply them directly (only for backwards compatibility, likely slower than well-specified constraints!)
|
513 |
| - elif callable(self.restrictions): |
| 516 | + elif callable(restrictions): |
514 | 517 |
|
515 | 518 | def restrictions_wrapper(*args):
|
516 |
| - return check_instance_restrictions(self.restrictions, dict(zip(self.param_names, args)), False) |
| 519 | + return check_instance_restrictions(restrictions, dict(zip(self.param_names, args)), False) |
517 | 520 |
|
518 | 521 | parameter_space.addConstraint(FunctionConstraint(restrictions_wrapper), self.param_names)
|
519 |
| - elif self.restrictions is not None: |
520 |
| - raise ValueError(f"The restrictions are of unsupported type {type(self.restrictions)}") |
| 522 | + elif restrictions is not None: |
| 523 | + raise ValueError(f"The restrictions are of unsupported type {type(restrictions)}") |
521 | 524 | return parameter_space
|
522 | 525 |
|
523 | 526 | def __parse_restrictions_pysmt(self, restrictions: list, tune_params: dict, symbols: dict):
|
|
0 commit comments