Skip to content

Commit 4b1a78a

Browse files
committed
Add CV/NLP SymDim Reifier matching error msg.
1 parent 8b97b17 commit 4b1a78a

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

graph_net/torch/sym_dim_reifiers/naive_cv_sym_dim_reifier.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ def match(self) -> bool:
1616
if self.dyn_dim_cstrs is None:
1717
return False
1818
sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str()
19-
return sym_shapes_str in self._get_map_cv_sym_shapes_str2reifier()
19+
if sym_shapes_str == "[]":
20+
return False
21+
elif sym_shapes_str not in self._get_map_cv_sym_shapes_str2reifier():
22+
print(
23+
f"[CV SymDim Reify] No reifier matched symbolic shapes:{sym_shapes_str} \nPlease add a reify rule to _get_map_cv_sym_shapes_str2reifier()"
24+
)
25+
return False
26+
return True
2027

2128
def reify(self):
2229
assert self.match()

graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ def match(self) -> bool:
1616
if self.dyn_dim_cstrs is None:
1717
return False
1818
sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str()
19-
return sym_shapes_str in self._get_map_nlp_sym_shapes_str2reifier()
19+
if sym_shapes_str == "[]":
20+
return False
21+
elif sym_shapes_str not in self._get_map_nlp_sym_shapes_str2reifier():
22+
print(
23+
f"[NLP SymDim Reify] No reifier matched symbolic shapes:{sym_shapes_str} \nPlease add a reify rule to _get_map_nlp_sym_shapes_str2reifier()"
24+
)
25+
return False
26+
return True
2027

2128
def reify(self):
2229
assert self.match()

0 commit comments

Comments
 (0)