Skip to content

Commit 3a8a8fb

Browse files
committed
reorder symbol names
1 parent ebe9968 commit 3a8a8fb

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

graph_net/constraint_util.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,13 @@ def symbolize_data_input_dims(
143143
Returns new DynamicDimConstraints if success.
144144
Returns None if no symbolicable dim .
145145
"""
146-
unqiue_dims = set()
146+
unqiue_dims = []
147147

148148
def dumpy_filter_fn(input_name, input_idx, axis, dim):
149149
if is_data_input(input_name):
150150
print("data_input", input_name, input_idx, axis, dim)
151-
unqiue_dims.add(dim)
151+
if dim not in unqiue_dims:
152+
unqiue_dims.append(dim)
152153
# No symbolization because of returning True
153154
return False
154155

graph_net/tensor_meta.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
@dataclass
77
class TensorMeta:
8+
record_class_name: str
89
name: str
910
original_name: str | None
1011
shape: list[int]
@@ -20,6 +21,7 @@ class TensorMeta:
2021
def unserialize_from_py_file(cls, file_path) -> list["TensorMeta"]:
2122
return [
2223
TensorMeta(
24+
record_class_name=attrs.get("record_class_name"),
2325
name=attrs.get("name"),
2426
original_name=attrs.get("original_name", None),
2527
shape=attrs.get("shape", []),
@@ -37,11 +39,13 @@ def unserialize_from_py_file(cls, file_path) -> list["TensorMeta"]:
3739

3840
@classmethod
3941
def _convert_cls_to_attrs(cls, tensor_meta_cls):
40-
return {
42+
attrs = {
4143
k: v
4244
for k, v in tensor_meta_cls.__dict__.items()
4345
if not k.startswith("__") and not callable(v)
4446
}
47+
attrs["record_class_name"] = tensor_meta_cls.__name__
48+
return attrs
4549

4650
@classmethod
4751
def _get_classes(cls, file_path, name="unamed"):
@@ -50,10 +54,9 @@ def _get_classes(cls, file_path, name="unamed"):
5054
spec.loader.exec_module(unnamed)
5155
yield from inspect.getmembers(unnamed, inspect.isclass)
5256

53-
def serialize_to_py_str(self, cls_name_prefix: str = "example_input") -> str:
54-
uid = f"{cls_name_prefix}_tensor_meta_{self.name}"
57+
def serialize_to_py_str(self) -> str:
5558
lines = [
56-
(f"class {uid}:"),
59+
(f"class {self.record_class_name}:"),
5760
(f'\tname = "{self.name}"'),
5861
(f"\tshape = {self.shape}"),
5962
(f'\tdtype = "{self.dtype}"'),

samples/timm/resnet18/input_tensor_constraints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
dynamic_dim_constraint_symbols = [S0, S1]
77

8-
dynamic_dim_constraint_symbol2example_value = {S0: 224, S1: 1}
8+
dynamic_dim_constraint_symbol2example_value = {S0: 1, S1: 224}
99

1010
dynamic_dim_constraint_relations = []
1111

@@ -205,7 +205,7 @@
205205
[512, 512, 3, 3],
206206
"L_self_modules_layer4_modules_1_modules_conv2_parameters_weight_",
207207
),
208-
([S1, 3, S0, S0], "L_x_"),
208+
([S0, 3, S1, S1], "L_x_"),
209209
([], "s1"),
210210
]
211211

0 commit comments

Comments
 (0)