Skip to content

Commit 206b32d

Browse files
guyleafmarkowanga
andauthored
chore: Merge latest upstream changes (#11)
* Fix PadToSize impl to follow Transform API after torchvision 0.21 (lyuwenyu#629) (duplicate #5) * Add update config params for pytorch_v2 all tool scripts (lyuwenyu#633) --------- Co-authored-by: Marcin Wątroba <[email protected]>
1 parent 9839f94 commit 206b32d

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

rtdetrv2_pytorch/tools/export_onnx.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,27 @@
77

88
import torch
99
import torch.nn as nn
10-
from rtdetrv2.core import YAMLConfig
10+
from rtdetrv2.core import YAMLConfig, yaml_utils
1111
from rtdetrv2.misc import import_modules
1212

1313

1414
def main(
1515
args,
1616
):
1717
"""main"""
18-
cfg = YAMLConfig(args.config, resume=args.resume)
18+
update_dict = yaml_utils.parse_cli(args.update) if args.update else {}
19+
update_dict.update(
20+
{
21+
k: v
22+
for k, v in args.__dict__.items()
23+
if k
24+
not in [
25+
"update",
26+
]
27+
and v is not None
28+
}
29+
)
30+
cfg = YAMLConfig(args.config, **update_dict)
1931

2032
if args.resume:
2133
checkpoint = torch.load(args.resume, map_location="cpu")
@@ -125,6 +137,7 @@ def forward(self, images, orig_target_sizes):
125137
default=[],
126138
help="preload modules before execution",
127139
)
140+
parser.add_argument("--update", "-u", nargs="+", help="update yaml config")
128141

129142
args = parser.parse_args()
130143

rtdetrv2_pytorch/tools/run_profile.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
1212
from typing import Any, Dict, List, Optional
1313

14-
from rtdetrv2.core import YAMLConfig
14+
from rtdetrv2.core import YAMLConfig, yaml_utils
1515
from rtdetrv2.misc import import_modules
1616

1717
__all__ = ["profile_stats"]
@@ -98,11 +98,26 @@ def trace_handler(prof):
9898
default=[],
9999
help="preload modules before execution",
100100
)
101+
parser.add_argument(
102+
"-u", "--update", nargs="+", help="Update yaml config from command line."
103+
)
101104
args = parser.parse_args()
102105

103106
import_modules(args.preloads)
104107

105-
cfg = YAMLConfig(args.config, device=args.device)
108+
update_dict = yaml_utils.parse_cli(args.update) if args.update else {}
109+
update_dict.update(
110+
{
111+
k: v
112+
for k, v in args.__dict__.items()
113+
if k
114+
not in [
115+
"update",
116+
]
117+
and v is not None
118+
}
119+
)
120+
cfg = YAMLConfig(args.config, **update_dict)
106121
model = cfg.model.to(args.device)
107122

108123
profile_stats(model, verbose=True)

0 commit comments

Comments
 (0)