Skip to content

Commit cf1fd64

Browse files
committed
remove graph
1 parent b1de6a0 commit cf1fd64

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

graph_net/torch/backend/range_decomposer_validator_backend.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,11 @@
99

1010

1111
class ComposedModel(nn.Module):
12-
def __init__(self, graph: nn.Module, subgraph: List[nn.Module]):
12+
def __init__(self, subgraph: List[nn.Module]):
1313
super().__init__()
14-
self.graph = graph
1514
self.subgraphs = nn.ModuleList(subgraph)
1615

1716
def forward(self, **kwargs):
18-
self.graph(**kwargs)
19-
2017
subgraph_intput = {
2118
key.replace("L", "l_l", 1): value
2219
for key, value in kwargs.items()
@@ -61,7 +58,6 @@ def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
6158
)
6259

6360
device = model.__class__.__graph_net_device__
64-
graph_instances = self._load_model_instance(model_dir, device)
6561
subgraph_instances = []
6662

6763
for path in subgraph_paths:
@@ -72,7 +68,7 @@ def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
7268
f"[RangeDecomposerValidatorBackend] Loaded and instantiated '{dir_name}'"
7369
)
7470

75-
composed_model = ComposedModel(graph_instances, subgraph_instances)
71+
composed_model = ComposedModel(subgraph_instances)
7672
return composed_model.eval()
7773

7874
def synchronize(self):

0 commit comments

Comments
 (0)