Skip to content

Commit 47ab6ab

Browse files
committed
Add cuda_gc in typical_sequence_split_points
1 parent fb7ecd6 commit 47ab6ab

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

graph_net/torch/typical_sequence_split_points.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
from typing import Any, Dict, List
66
import torch
77
import torch.nn as nn
8+
89
from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
910
from graph_net.torch.rp_expr.rp_expr_util import (
1011
MakeNestedIndexRangeFromLetsListTokenRpExpr,
1112
)
1213
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
1314
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module_without_varify
15+
from graph_net.torch.decompose_util import cuda_gc
1416

1517

1618
class TypicalSequenceExtractor:
@@ -85,10 +87,12 @@ def _get_output_path(self, rel_model_path: str):
8587

8688
def _extract_ops(self, model_path: str) -> List[str]:
8789
extractor = TypicalSequenceExtractor()
88-
model, inputs = get_torch_module_and_inputs(model_path)
89-
compiled_model, _ = parse_sole_graph_module_without_varify(model, inputs)
90-
extractor.extract_compiler(compiled_model, inputs)
91-
ops_info = extractor.extract_node
90+
with cuda_gc():
91+
model, inputs = get_torch_module_and_inputs(model_path)
92+
compiled_model, _ = parse_sole_graph_module_without_varify(model, inputs)
93+
extractor.extract_compiler(compiled_model, inputs)
94+
ops_info = extractor.extract_node
95+
del model, inputs, compiled_model
9296

9397
return [op["target_name"] for op in ops_info]
9498

0 commit comments

Comments
 (0)