Skip to content

Commit cba4e86

Browse files
committed
backup code
1 parent 9a6716c commit cba4e86

File tree

2 files changed

+261
-0
lines changed

2 files changed

+261
-0
lines changed

graph_net/constraint_util.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
2+
from typing import Callable
3+
import copy
4+
5+
6+
def symbolize_data_input_dims(
7+
dyn_dim_cstr: DynamicDimConstraints,
8+
is_data_input: Callable[["input_var_name:str"], bool],
9+
is_input_shape_valid: Callable[[DynamicDimConstraints], bool],
10+
) -> DynamicDimConstraints | None:
11+
"""
12+
Symbolizes data input dimensions as much as possible.
13+
Returns new DynamicDimConstraints if success.
14+
Returns None if no symbolicable dim .
15+
"""
16+
unqiue_dims = set()
17+
18+
def dumpy_filter_fn(input_name, input_idx, axis, dim):
19+
unqiue_dims.add(dim)
20+
# No symbolization because of returning True
21+
return False
22+
23+
# Collect input dimensions into `unqiue_dims`
24+
assert dyn_dim_cstr.symbolize(dumpy_filter_fn) is None
25+
for picked_dim in unqiue_dims:
26+
tmp_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
27+
28+
def filter_fn(input_name, input_idx, axis, dim):
29+
return is_data_input(input_name) and dim == picked_dim
30+
31+
symbol = tmp_dyn_dim_cstr.symbolize(filter_fn)
32+
if symbol is None:
33+
continue
34+
sym2example_value = {symbol: picked_dim + 1}
35+
if not tmp_dyn_dim_cstr.check_delta_symbol2example_value(sym2example_value):
36+
continue
37+
tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
38+
if not is_input_shape_valid(tmp_dyn_dim_cstr):
39+
continue
40+
dyn_dim_cstr = tmp_dyn_dim_cstr
41+
return dyn_dim_cstr
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import sys
2+
import sympy
3+
import importlib.util as imp
4+
from dataclasses import dataclass
5+
import copy
6+
from typing import Callable
7+
from collections import namedtuple
8+
9+
10+
@dataclass
11+
class DynamicDimConstraints:
12+
kSymbolVarNamePrefix = "s"
13+
14+
symbols: list[sympy.Symbol]
15+
kSymbols = "dynamic_dim_constraint_symbols"
16+
17+
# len(symbol2example_value) == len(symbols)
18+
symbol2example_value: dict[sympy.Symbol, int]
19+
kSymbol2ExampleValue = "dynamic_dim_constraint_symbol2example_value"
20+
21+
relations: list[sympy.Rel]
22+
kRelations = "dynamic_dim_constraint_relations"
23+
24+
# len(input_shapes) equals number of Model.forward arguments
25+
input_shapes: list[("var-name", tuple[sympy.Expr | int])]
26+
kInputShapes = "dynamic_dim_constraint_input_shapes"
27+
28+
# len(input_shapes) equals number of Model.forward arguments
29+
input_max_values: list[("var-name", tuple[sympy.Expr | int | None])]
30+
kInputMaxValues = "dynamic_dim_constraint_input_max_values"
31+
32+
def symbolize(
33+
self,
34+
filter_fn=Callable[
35+
["input_name:str", "input_idx:int", "axis:int", "dim:int"], bool
36+
],
37+
) -> sympy.Symbol | None:
38+
"""
39+
Returns created symbol.
40+
"""
41+
InputDim = namedtuple("InputDim", ["input_idx", "axis", "dim"])
42+
input_dims = [
43+
InputDim(input_idx, axis, dim)
44+
for input_idx, namedshape in enumerate(self.input_shapes)
45+
for input_name, shape in [namedshape]
46+
for axis, dim in enumerate(shape)
47+
if isinstance(dim, int)
48+
if filter_fn(input_name, input_idx, axis, dim)
49+
]
50+
if len(input_dims) == 0:
51+
return None
52+
unique_dims = set(dim for input_dix, axis, dim in input_dims)
53+
assert len(unique_dims) == 1
54+
dim = list(unique_dims)[0]
55+
new_sym = self._new_symbol(example_value=dim)
56+
for input_dix, axis, _ in input_dims:
57+
self.input_shapes[input_dix][1][axis] = new_sym
58+
return new_sym
59+
60+
def _new_symbol(self, example_value):
61+
max_existed_seq_no = max(
62+
[
63+
-1,
64+
*(
65+
seq_no
66+
for symbol in self.symbols
67+
for seq_no in [int(symbol.name[1:])]
68+
),
69+
]
70+
)
71+
seq_no = max_existed_seq_no + 1
72+
symbol = sympy.Symbol(f"{self.kSymbolVarNamePrefix}{seq_no}")
73+
self.symbol2example_value[symbol] = example_value
74+
self.symbols.append(symbol)
75+
return symbol
76+
77+
def update_symbol2example_value(self, sym2example_value: dict):
78+
self.symbol2example_value = self.merge_symbol2example_value(sym2example_value)
79+
return self
80+
81+
def merge_symbol2example_value(self, sym2example_value: dict):
82+
return {
83+
k: v
84+
for k, v in [*self.symbol2example_value.items(), *sym2example_value.items()]
85+
}
86+
87+
def check_delta_symbol2example_value(self, sym2example_value: dict):
88+
if len(sym2example_value) == 0:
89+
return True
90+
91+
sym2example_value = self.merge_symbol2example_value(sym2example_value)
92+
93+
sym_exprs = [
94+
*self._get_sym_exprs_from_input_shapes(),
95+
*self._get_sym_exprs_from_input_max_values(),
96+
]
97+
98+
relations = [*self.relations, *(sym_expr > 0 for sym_expr in sym_exprs)]
99+
100+
return all(
101+
relation.subs(sym2example_value) == sympy.true for relation in relations
102+
)
103+
104+
def _get_sym_exprs_from_input_shapes(self):
105+
yield from (
106+
sym_dim
107+
for name, shape in self.input_shapes
108+
for sym_dim in shape
109+
if isinstance(sym_dim, sympy.Expr)
110+
)
111+
112+
def _get_sym_exprs_from_input_max_values(self):
113+
yield from (
114+
sym_value
115+
for name, sym_value in self.input_max_values
116+
if isinstance(sym_value, sympy.Expr)
117+
)
118+
119+
def serialize_to_py_str(self):
120+
symbols_definition = "\n".join(
121+
f"{name} = Symbol('{name}')"
122+
for symbol in self.symbols
123+
for name in [symbol.name]
124+
)
125+
return f"""
126+
from sympy import Symbol, Expr, Rel, Eq
127+
128+
{symbols_definition}
129+
130+
{self.kSymbols} = {self.symbols}
131+
132+
{self.kSymbol2ExampleValue} = {self.symbol2example_value}
133+
134+
{self.kRelations} = {self.relations}
135+
136+
{self.kInputShapes} = {self.input_shapes}
137+
138+
{self.kInputMaxValues} = {self.input_max_values}
139+
"""
140+
141+
@classmethod
142+
def unserialize_from_py_file(cls, filepath):
143+
module = cls.load_module(filepath)
144+
return cls(
145+
symbols=cls.module_symbols(module),
146+
symbol2example_value=cls.module_symbol2example_value(module),
147+
relations=cls.module_relations(module),
148+
input_shapes=cls.module_input_shapes(module),
149+
input_max_values=cls.module_input_max_values(module),
150+
)
151+
152+
@classmethod
153+
def module_symbols(cls, module):
154+
return cls.get_module_list_attr(module, cls.kSymbols)
155+
156+
@classmethod
157+
def module_symbol2example_value(cls, module):
158+
return cls.get_module_dict_attr(module, cls.kSymbol2ExampleValue)
159+
160+
@classmethod
161+
def module_relations(cls, module):
162+
return cls.get_module_list_attr(module, cls.kRelations)
163+
164+
@classmethod
165+
def module_input_shapes(cls, module):
166+
return cls.get_module_list_attr(module, cls.kInputShapes)
167+
168+
@classmethod
169+
def module_input_max_values(cls, module):
170+
return cls.get_module_list_attr(module, cls.kInputMaxValues)
171+
172+
@classmethod
173+
def get_module_list_attr(cls, module, attr):
174+
return cls.get_module_attr(module, attr, default=[])
175+
176+
@classmethod
177+
def get_module_dict_attr(cls, module, attr):
178+
return cls.get_module_attr(module, attr, default={})
179+
180+
@classmethod
181+
def get_module_attr(cls, module, attr, default):
182+
return getattr(module, attr) if hasattr(module, attr) else default
183+
184+
@classmethod
185+
def load_module(cls, path, name="unamed"):
186+
spec = imp.spec_from_file_location(name, path)
187+
module = imp.module_from_spec(spec)
188+
spec.loader.exec_module(module)
189+
return module
190+
191+
192+
if __name__ == "__main__":
193+
cstr_code = """
194+
import sympy
195+
196+
x = sympy.Symbol('x')
197+
y = sympy.Symbol('y')
198+
199+
dynamic_dim_constraint_symbol2example_value = [(x, 2)]
200+
201+
dynamic_dim_constraint_symbols = [x, y]
202+
203+
dynamic_dim_constraint_relations = [x > 0]
204+
"""
205+
206+
import tempfile
207+
208+
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", encoding="utf-8") as tmp:
209+
tmp.write(cstr_code)
210+
tmp.flush()
211+
cstr = DynamicDimConstraints.unserialize_from_py_file(tmp.name)
212+
print(cstr)
213+
print(cstr.serialize_to_py_str())
214+
215+
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", encoding="utf-8") as tmp:
216+
tmp.write(cstr.serialize_to_py_str())
217+
tmp.flush()
218+
cstr = DynamicDimConstraints.unserialize_from_py_file(tmp.name)
219+
print(cstr)
220+
print(cstr.serialize_to_py_str())

0 commit comments

Comments
 (0)