Skip to content

Commit 0cbadd7

Browse files
authored
[MaskRCNN/PyT] Update AMP API for inference (#810)
1 parent 2badf6e commit 0cbadd7

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

PyTorch/Segmentation/MaskRCNN/pytorch/tools/test_net.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,14 @@ def main():
9696
model = build_detection_model(cfg)
9797
model.to(cfg.MODEL.DEVICE)
9898

99-
# Initialize mixed-precision if necessary
99+
# Initialize mixed-precision
100100
if args.fp16:
101101
use_mixed_precision = True
102102
else:
103103
use_mixed_precision = cfg.DTYPE == "float16"
104-
amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)
105-
104+
amp_opt_level = 'O1' if use_mixed_precision else 'O0'
105+
model = amp.initialize(model, opt_level=amp_opt_level)
106+
106107
output_dir = cfg.OUTPUT_DIR
107108
checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
108109
_ = checkpointer.load(cfg.MODEL.WEIGHT)

0 commit comments

Comments
 (0)