-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_fault_deform.py
More file actions
115 lines (88 loc) · 4.17 KB
/
inference_fault_deform.py
File metadata and controls
115 lines (88 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import argparse
import torch
import numpy as np
import time
import random
from src.models.model_loader import load_model
from src.training_inference.evaluate import evaluate_inference
from src.dataloaders.data_loader import get_inference_dataloader
from src.metrics_and_losses.photometric_metrics import realEPE
from src.metrics_and_losses.gradient_metrics import L2Smoothness
from src.configs.load_config import ConfigArgumentParser, YamlConfigAction
def fix_seed(seed):
"""
Fix the random seed for reproducibility across PyTorch, NumPy, and Python's random module.
"""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
def parse_args():
"""
Parse command line arguments for the inference script.
"""
parser = argparse.ArgumentParser()
parser = ConfigArgumentParser('inference fault deform')
parser.add_argument('--train_config_name', default='microflow')
# Inference specific parameters
parser.add_argument('--seed', default=1, type=int, help='Seed used for the random number generators')
parser.add_argument('--pretrained_model_filename', help='Filename containing the pretrained weights associated witht the config')
# data loader parameters
parser.add_argument('--split_name', type=str, default="test", help='name of the split')
parser.add_argument('--split_start_idx', type=int, default=0, help='id of the first example')
parser.add_argument('--split_count', type=int, default=10000, help='number of split examples')
parser.add_argument('--split_scaling_factors', nargs='+', default=[0, 1, 2, 3], help="scaling factors")
parser.add_argument('--split_dir', type=str, help='directory containing the dataset splits')
parser.add_argument('--dataset_dir', help='path to dataset dir')
parser.add_argument('--dataset_name', default='faultdeform', help='dataset name')
parser.add_argument('--image_size', default=512, type=int, help='Image size')
parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size')
# edge parameters
parser.add_argument('--near_fault_only', action='store_true', help='')
parser.add_argument('--fault_boundary_dir', default=None, help="dir containing the boundary faults images")
parser.add_argument('--fault_boundary_disk', type=int, default=1, help='number of val examples')
# saving parameters
parser.add_argument('--save_metrics', action='store_true', help='')
parser.add_argument('--eval_crop_size', type=int, default=0, help='when > 0 it is mainly for aligning with cosi-corr')
parser.add_argument('--save_images', action='store_true', help='')
parser.add_argument('--metric_filename', default=None, help="filename where to save the metric results")
parser.add_argument('--save_dir', default=None, help="dir where to save the estimated images")
namespace, _ = parser.parse_known_args()
config_filename = os.path.join("data/configs/train_fault_deform/", f'{namespace.train_config_name.lower()}.yaml')
parser.add_argument('--config', action=YamlConfigAction, default=[config_filename])
return parser.parse_args()
def run_inference(args):
"""
Run the inference pipeline.
"""
# Fix the seed
fix_seed(args.seed)
# Define device
device = torch.device('cuda')
# Load models
optical_flow_model = load_model(args, device)
# Get dataloaders
loader, frame_names, crs_meta_datas, transform_meta_datas = get_inference_dataloader(args)
# Define metrics
photometric_metrics = [realEPE]
smoothness_metrics = [L2Smoothness()]
t_start = time.time()
evaluate_inference(
args=args,
device=device,
model=optical_flow_model,
loader=loader,
photometric_metrics=photometric_metrics,
smoothness_metrics=smoothness_metrics,
frame_names=frame_names,
crs_meta_datas=crs_meta_datas,
transform_meta_datas=transform_meta_datas
)
t_end = time.time()
print(f'\n## Inference on {args.split_name}##')
print(f"total time {t_end - t_start}")
if __name__ == '__main__':
# Load parameters
args = parse_args()
run_inference(args)