-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_estimate_fault_deform.py
More file actions
113 lines (87 loc) · 4.29 KB
/
inference_estimate_fault_deform.py
File metadata and controls
113 lines (87 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import argparse
import torch
import numpy as np
import time
import random
from src.models.model_loader import load_model
from src.training_inference.evaluate import evaluate_inference_from_estimates
from src.dataloaders.data_loader import get_inference_dataloader
from src.metrics_and_losses.photometric_metrics import realEPE
from src.metrics_and_losses.gradient_metrics import L2Smoothness
from src.configs.load_config import ConfigArgumentParser, YamlConfigAction
def fix_seed(seed):
    """
    Seed all random number generators used by the pipeline for reproducibility.

    Covers PyTorch (CPU and all CUDA devices), NumPy, and Python's built-in
    ``random`` module. ``torch.cuda.manual_seed_all`` is a no-op on machines
    without CUDA, so this is safe on CPU-only hosts.
    """
    seeders = (
        torch.manual_seed,
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def parse_args():
    """
    Parse command line arguments for the inference script.

    Builds a ``ConfigArgumentParser``, registers inference / data-loader /
    edge / saving options, then resolves ``--train_config_name`` into a YAML
    config file that is loaded via the ``--config`` action (so YAML values
    act as defaults that CLI flags can override).

    Returns:
        argparse.Namespace: the fully resolved arguments.
    """
    # NOTE(fix): the previous version created a throwaway argparse.ArgumentParser()
    # that was immediately overwritten; only the ConfigArgumentParser is needed.
    parser = ConfigArgumentParser('inference fault deform')
    parser.add_argument('--train_config_name', default='microflow')
    # Inference specific parameters
    parser.add_argument('--seed', default=1, type=int, help='Seed used for the random number generators')
    parser.add_argument('--pretrained_model_filename', help='Filename containing the pretrained weights associated with the config')
    # data loader parameters
    parser.add_argument('--split_name', type=str, default="test", help='name of the split')
    parser.add_argument('--split_start_idx', type=int, default=0, help='id of the first example')
    parser.add_argument('--split_count', type=int, default=10000, help='number of split examples')
    parser.add_argument('--split_scaling_factors', nargs='+', default=[0, 1, 2, 3], help="scaling factors")
    parser.add_argument('--split_dir', type=str, help='directory containing the dataset splits')
    parser.add_argument('--estimation_root_dir', type=str, help='directory containing the estimates')
    parser.add_argument('--estimate_model_name', type=str, help='name of the model used to compute the estimates')
    parser.add_argument('--dataset_dir', help='path to dataset dir')
    parser.add_argument('--dataset_name', default='faultdeform', help='dataset name')
    parser.add_argument('--image_size', default=512, type=int, help='Image size')
    parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size')
    # edge parameters
    parser.add_argument('--near_fault_only', action='store_true', help='')
    parser.add_argument('--fault_boundary_dir', default=None, help="dir containing the boundary faults images")
    parser.add_argument('--fault_boundary_disk', type=int, default=1, help='number of val examples')
    # saving parameters
    parser.add_argument('--save_metrics', action='store_true', help='')
    parser.add_argument('--eval_crop_size', type=int, default=0, help='when > 0 it is mainly for aligning with cosi-corr')
    parser.add_argument('--save_images', action='store_true', help='')
    parser.add_argument('--metric_filename', default=None, help="filename where to save the metric results")
    parser.add_argument('--save_dir', default=None, help="dir where to save the metrics")
    # First pass: only resolve --train_config_name so we can locate the YAML
    # config; unknown args are ignored here and parsed in the final pass.
    namespace, _ = parser.parse_known_args()
    config_filename = os.path.join("data/configs/train_fault_deform/", f'{namespace.train_config_name}.yaml')
    parser.add_argument('--config', action=YamlConfigAction, default=[config_filename])
    return parser.parse_args()
def run_inference(args):
    """
    Run the inference pipeline.

    Seeds all RNGs, builds the inference data loader, then evaluates the
    pre-computed flow estimates with photometric (EPE) and smoothness (L2)
    metrics, timing the whole evaluation.

    Args:
        args (argparse.Namespace): parsed configuration from ``parse_args``.
    """
    # Fix the seed for reproducibility
    fix_seed(args.seed)
    # Define device. NOTE(fix): previously hardcoded to 'cuda', which crashed
    # on CPU-only machines; fall back to CPU when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Get dataloaders and per-frame metadata (names, CRS, geo transforms)
    loader, frame_names, crs_meta_datas, transform_meta_datas = get_inference_dataloader(args)
    # Define metrics
    photometric_metrics = [realEPE]
    smoothness_metrics = [L2Smoothness()]
    # Evaluate and time the full inference loop
    t_start = time.time()
    evaluate_inference_from_estimates(
        args=args,
        device=device,
        loader=loader,
        photometric_metrics=photometric_metrics,
        smoothness_metrics=smoothness_metrics,
        frame_names=frame_names,
        crs_meta_datas=crs_meta_datas,
        transform_meta_datas=transform_meta_datas
    )
    t_end = time.time()
    print(f'\n## Inference on {args.split_name}##')
    print(f"total time {t_end - t_start}")
if __name__ == '__main__':
    # Entry point: parse the CLI/config arguments and launch inference.
    run_inference(parse_args())