-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathurl_list_infer.py
More file actions
77 lines (62 loc) · 2.4 KB
/
url_list_infer.py
File metadata and controls
77 lines (62 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import requests
import numpy as np
import pathlib
import pandas as pd
from io import BytesIO
from PIL import Image
from torch.utils.data import Dataset
from transformers import pipeline
from tqdm import tqdm
from multiprocessing import Manager
from typing import List
from metadata_utils import get_metadata_img
# Hugging Face Hub model id for the LADI v2 disaster-imagery classifier.
MODEL_NAME = 'MITLL/LADI-v2-classifier-small'
# Label names kept when post-processing pipeline output; any label the
# model emits that is not in this list is dropped by postprocess_output.
labels = ['trees_any',
'water_any',
'trees_damage',
'debris_any',
'roads_any',
'flooding_any',
'buildings_any',
'buildings_affected_or_greater',
'bridges_any',
'flooding_structures',
'roads_damage']
class URLListDataset(Dataset):
    """Map-style torch Dataset that downloads images from a list of URLs.

    Metadata extracted from each image at download time is stored in a
    multiprocessing ``Manager`` dict keyed by URL, so DataLoader worker
    processes can share it with the parent via ``metadata_map``.
    """

    # Seconds before an HTTP request is abandoned. BUG FIX: the original
    # requests.get call had no timeout, so a single unresponsive server
    # could hang a DataLoader worker forever.
    REQUEST_TIMEOUT = 30

    def __init__(self, urls: List[str]):
        self.urls = urls
        # Manager-backed dict is process-safe, unlike a plain dict, which
        # would be silently copied (not shared) by DataLoader workers.
        self.manager = Manager()
        self.metadata_map = self.manager.dict()

    def __len__(self) -> int:
        return len(self.urls)

    def __getitem__(self, idx: int):
        """Download the image at ``urls[idx]``; record its metadata.

        Raises:
            Exception: if the server responds with a non-200 status.
        """
        url = self.urls[idx]
        response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
        if response.status_code == 200:
            img = Image.open(BytesIO(response.content))
            # Stash metadata before returning — presumably consumed later
            # through ds.metadata_map by the caller; TODO confirm.
            self.metadata_map[url] = get_metadata_img(img)
            return img
        else:
            raise Exception(f"Failed to download image at: {url}")
def postprocess_output(infer_output, valid_labels=None):
    """Filter pipeline classification output down to the known labels.

    Args:
        infer_output: iterable of dicts with ``'label'`` and ``'score'``
            keys, as emitted by a transformers image-classification pipeline.
        valid_labels: labels to keep. Defaults to the module-level
            ``labels`` list, preserving the original behavior.

    Returns:
        dict mapping label -> score, sorted by label name.
    """
    # Set membership is O(1); also lets callers pass any iterable.
    keep = set(labels if valid_labels is None else valid_labels)
    scores = {r['label']: r['score'] for r in infer_output if r['label'] in keep}
    return dict(sorted(scores.items()))
if __name__ == "__main__":
    # Multi-label classification: 'sigmoid' yields independent per-label
    # probabilities rather than a softmax distribution.
    # NOTE(review): device=0 assumes a CUDA GPU is present — confirm;
    # num_workers=40 spawns 40 DataLoader workers to download in parallel.
    pipe = pipeline(model=MODEL_NAME,
                    task='image-classification',
                    function_to_apply='sigmoid',
                    device=0,
                    num_workers=40)
    # Two sample FEMA CAP aerial images hosted on public S3.
    urls = ["http://s3.amazonaws.com/fema-cap-imagery/Images/CAP_-_Spring_Storms_2024/Source/24-1-5100_OKWG/A0003_AerialOblique/2415100A0003_Marietta_Area__310.JPG", "http://s3.amazonaws.com/fema-cap-imagery/Images/CAP_-_Spring_Storms_2024/Source/24-1-5100_OKWG/A0003_AerialOblique/2415100A0003_Marietta_Area__301.JPG"]
    ds = URLListDataset(urls)
    outputs = []
    # top_k=20 asks for more labels than the 11 we keep so none are cut
    # off; postprocess_output filters down to the known label set.
    # NOTE(review): pairing urls[i] with the i-th pipeline output relies on
    # the pipeline preserving dataset order — verify (no shuffling is used).
    for i, output in tqdm(enumerate(pipe(ds, batch_size=12, top_k=20))):
        classes = postprocess_output(output)
        curr_filename = urls[i]
        # Metadata was stored by URLListDataset.__getitem__ at download time.
        img_metadata = ds.metadata_map[curr_filename]
        outputs.append({'file_path': curr_filename, **classes, **img_metadata})
    # One CSV row per image: file path, per-label scores, image metadata.
    df = pd.DataFrame(data=outputs)
    df.to_csv('outputs.csv', index=False)