Skip to content

Commit eeeda67

Browse files
committed
Add RenamedDataInputPredicator to support generate unittest for renamed subgraphs.
1 parent 96f2658 commit eeeda67

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

graph_net/torch/constraint_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ def __call__(self, model_path, input_var_name: str) -> bool:
1616
)
1717

1818

19+
class RenamedDataInputPredicator:
20+
def __init__(self, config):
21+
self.config = config
22+
23+
def __call__(self, model_path, input_var_name: str) -> bool:
24+
return input_var_name.startswith("in_")
25+
26+
1927
class ModelRunnablePredicator:
2028
def __init__(self, config):
2129
if config is None:

graph_net/torch/sample_passes/agent_unittest_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def __init__(
143143

144144
def generate(self):
145145
model_name = "".join(
146-
word.capitalize() for word in re.split(r"[_-]", self.model_path.name)
146+
word.capitalize() for word in re.split(r"[_.-]", self.model_path.name)
147147
)
148148
graph_module = load_class_from_file(
149149
self.model_path / "model.py", class_name="GraphModule"

0 commit comments

Comments
 (0)