Skip to content

Commit 27e45b1

Browse files
committed
support model-path-prefix in splitting positions
1 parent c013a64 commit 27e45b1

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

graph_net/test/typical_sequence_decomposer_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ model_list="$GRAPH_NET_ROOT/graph_net/test/dev_model_list/validation_error_model
1111

1212
python3 -m graph_net.torch.typical_sequence_split_points \
1313
--model-list "$model_list" \
14+
--model-path-prefix "$GRAPH_NET_ROOT" \
1415
--device "cuda" \
1516
--window-size 10 \
1617
--fold-policy default \

graph_net/torch/typical_sequence_split_points.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,14 @@ def get_len(tid):
150150
get_len(sym_id)
151151
return token2len
152152

153-
def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
153+
def analyze(
154+
self, model_path_prefix: str, model_paths_file: str, device: str
155+
) -> Dict[str, Dict]:
154156
input_file = Path(model_paths_file)
155157

156158
with open(input_file, "r") as f:
157159
model_paths = [
158-
Path(line.strip())
160+
Path(model_path_prefix) / line.strip()
159161
for line in f
160162
if line.strip() and not line.startswith("#")
161163
]
@@ -264,7 +266,7 @@ def main(args):
264266
fold_policy=args.fold_policy,
265267
fold_times=args.fold_times,
266268
)
267-
results = analyzer.analyze(args.model_list, args.device)
269+
results = analyzer.analyze(args.model_path_prefix, args.model_list, args.device)
268270
if args.output_json:
269271
with open(args.output_json, "w") as f:
270272
json.dump(results, f, indent=4)
@@ -280,6 +282,12 @@ def main(args):
280282
required=True,
281283
help="Path to a text file containing paths to models (one per line).",
282284
)
285+
parser.add_argument(
286+
"--model-path-prefix",
287+
type=str,
288+
default="./",
289+
help="Prefix to add to each model path in the list.",
290+
)
283291
parser.add_argument(
284292
"--device",
285293
type=str,

0 commit comments

Comments
 (0)