Commit 246233f

update

1 parent d163b15 commit 246233f

File tree

13 files changed: +1374 -0 lines changed

IndexSearch.py

Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
import os
import argparse

import faiss
import torch
import numpy as np
import pandas as pd
from PIL import Image
from geopy.distance import geodesic
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from utils.utils import MP16Dataset, im2gps3kDataset, yfcc4kDataset


def encode_images(model, images):
    # Produce the three L2-normalized projections of the CLIP vision output
    # (image, image-text, image-location) and concatenate them into a single
    # 768*3-dim embedding. Shared by build_index and search_index.
    vision_output = model.vision_model(images)[1]

    image_embeds = model.vision_projection(vision_output)
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

    image_text_embeds = model.vision_projection_else_1(model.vision_projection(vision_output))
    image_text_embeds = image_text_embeds / image_text_embeds.norm(p=2, dim=-1, keepdim=True)

    image_location_embeds = model.vision_projection_else_2(model.vision_projection(vision_output))
    image_location_embeds = image_location_embeds / image_location_embeds.norm(p=2, dim=-1, keepdim=True)

    return torch.cat([image_embeds, image_text_embeds, image_location_embeds], dim=1)


def build_index(args):
    if args.index == 'g3':
        model = torch.load('./checkpoints/g3.pth', map_location='cuda:0')
        model.requires_grad_(False)
        dataset = MP16Dataset(vision_processor=model.vision_processor, text_processor=None)
        # one flat inner-product index over the concatenated 768*3-dim embeddings
        index_flat = faiss.IndexFlatIP(768 * 3)
        dataloader = DataLoader(dataset, batch_size=1024, shuffle=False, num_workers=16, pin_memory=True, prefetch_factor=3)
        model.eval()
        for images, texts, longitude, latitude in tqdm(dataloader):
            images = images.to(args.device)
            image_embeds = encode_images(model, images)
            index_flat.add(image_embeds.cpu().detach().numpy())

        faiss.write_index(index_flat, f'./index/{args.index}.index')


def search_index(args, index, topk):
    print('start searching...')
    test_dataset_cls = {'im2gps3k': im2gps3kDataset, 'yfcc4k': yfcc4kDataset}[args.dataset]
    if args.index == 'g3':
        model = torch.load('./checkpoints/g3.pth', map_location='cuda:0')
        model.requires_grad_(False)
        dataset = test_dataset_cls(vision_processor=model.vision_processor, text_processor=None)
        dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True, prefetch_factor=5)
        test_images_embeds = np.empty((0, 768 * 3))
        model.eval()
        print('generating embeds...')
        for images, texts, longitude, latitude in tqdm(dataloader):
            images = images.to(args.device)
            image_embeds = encode_images(model, images)
            test_images_embeds = np.concatenate([test_images_embeds, image_embeds.cpu().detach().numpy()], axis=0)
        print(test_images_embeds.shape)
        test_images_embeds = test_images_embeds.reshape(-1, 768 * 3).astype('float32')  # faiss expects float32
        print('start searching NN...')
        D, I = index.search(test_images_embeds, topk)
        return D, I


class GeoImageDataset(Dataset):
    def __init__(self, dataframe, img_folder, topn, vision_processor, database_df, I):
        self.dataframe = dataframe
        self.img_folder = img_folder
        self.topn = topn
        self.vision_processor = vision_processor
        self.database_df = database_df
        self.I = I

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_path = f'{self.img_folder}/{self.dataframe.loc[idx, "IMG_ID"]}'
        image = Image.open(img_path).convert('RGB')
        image = self.vision_processor(images=image, return_tensors='pt')['pixel_values'].reshape(3, 224, 224)

        # Candidate coordinates per query: RAG predictions with 5/10/15
        # references, zero-shot predictions, and the top-1 retrieved neighbor.
        gps_data = []
        search_top1_latitude, search_top1_longitude = self.database_df.loc[self.I[idx][0], ['LAT', 'LON']].values
        for j in range(self.topn):
            gps_data.extend([
                float(self.dataframe.loc[idx, f'5_rag_{j}_latitude']),
                float(self.dataframe.loc[idx, f'5_rag_{j}_longitude']),
                float(self.dataframe.loc[idx, f'10_rag_{j}_latitude']),
                float(self.dataframe.loc[idx, f'10_rag_{j}_longitude']),
                float(self.dataframe.loc[idx, f'15_rag_{j}_latitude']),
                float(self.dataframe.loc[idx, f'15_rag_{j}_longitude']),
                float(self.dataframe.loc[idx, f'zs_{j}_latitude']),
                float(self.dataframe.loc[idx, f'zs_{j}_longitude']),
                search_top1_latitude,
                search_top1_longitude
            ])

        gps_data = np.array(gps_data).reshape(-1, 2)
        return image, gps_data, idx


def evaluate(args, I):
    print('start evaluation')
    if args.database == 'mp16':
        database = args.database_df
    df = args.dataset_df
    # retrieval-only baseline: predict the coordinates of the nearest neighbor
    df['NN_idx'] = I[:, 0]
    df['LAT_pred'] = df.apply(lambda x: database.loc[x['NN_idx'], 'LAT'], axis=1)
    df['LON_pred'] = df.apply(lambda x: database.loc[x['NN_idx'], 'LON'], axis=1)

    df_llm = pd.read_csv(f'./data/{args.dataset}/{args.dataset}_prediction.csv')
    model = torch.load('./checkpoints/g3.pth', map_location='cuda:0')
    topn = 5  # number of candidates per prediction source

    dataset = GeoImageDataset(df_llm, f'./data/{args.dataset}/images', topn, vision_processor=model.vision_processor, database_df=database, I=I)
    data_loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True)

    for images, gps_batch, indices in tqdm(data_loader):
        images = images.to(args.device)
        image_embeds = model.vision_projection_else_2(model.vision_projection(model.vision_model(images)[1]))
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)  # b, 768

        gps_batch = gps_batch.to(args.device)
        gps_input = gps_batch.clone().detach()
        b, c, _ = gps_input.shape
        gps_input = gps_input.reshape(b * c, 2)
        location_embeds = model.location_encoder(gps_input)
        location_embeds = model.location_projection_else(location_embeds.reshape(b * c, -1))
        location_embeds = location_embeds / location_embeds.norm(p=2, dim=-1, keepdim=True)
        location_embeds = location_embeds.reshape(b, c, -1)  # b, c, 768

        # pick the candidate whose location embedding best matches the image
        similarity = torch.matmul(image_embeds.unsqueeze(1), location_embeds.permute(0, 2, 1))  # b, 1, c
        similarity = similarity.squeeze(1).cpu().detach().numpy()
        max_idxs = np.argmax(similarity, axis=1)

        # write the verified predictions back into the DataFrame
        for i, max_idx in enumerate(max_idxs):
            final_idx = indices[i].item()
            final_latitude, final_longitude = gps_batch[i][max_idx]
            final_latitude, final_longitude = final_latitude.item(), final_longitude.item()
            # clamp malformed LLM outputs to (0, 0)
            if final_latitude < -90 or final_latitude > 90:
                final_latitude = 0
            if final_longitude < -180 or final_longitude > 180:
                final_longitude = 0
            df.loc[final_idx, 'LAT_pred'] = final_latitude
            df.loc[final_idx, 'LON_pred'] = final_longitude

    df['geodesic'] = df.apply(lambda x: geodesic((x['LAT'], x['LON']), (x['LAT_pred'], x['LON_pred'])).km, axis=1)
    print(df.head())
    df.to_csv(f'./data/{args.dataset}_{args.index}_results.csv', index=False)

    # accuracy at the 1, 25, 200, 750, 2500 km thresholds
    print('2500km level: ', df[df['geodesic'] < 2500].shape[0] / df.shape[0])
    print('750km level: ', df[df['geodesic'] < 750].shape[0] / df.shape[0])
    print('200km level: ', df[df['geodesic'] < 200].shape[0] / df.shape[0])
    print('25km level: ', df[df['geodesic'] < 25].shape[0] / df.shape[0])
    print('1km level: ', df[df['geodesic'] < 1].shape[0] / df.shape[0])


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--index', type=str, default='g3')
    parser.add_argument('--dataset', type=str, default='im2gps3k')
    parser.add_argument('--database', type=str, default='mp16')
    args = parser.parse_args()
    if args.dataset == 'im2gps3k':
        args.dataset_df = pd.read_csv('./data/im2gps3k/im2gps3k_places365.csv')
    elif args.dataset == 'yfcc4k':
        args.dataset_df = pd.read_csv('./data/yfcc4k/yfcc4k_places365.csv')

    if args.database == 'mp16':
        args.database_df = pd.read_csv('./data/MP16_Pro_filtered.csv')

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if not os.path.exists('./index'):
        os.makedirs('./index')
    if not os.path.exists(f'./index/{args.index}.index'):
        build_index(args)
    else:
        # res = faiss.StandardGpuResources()
        # gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index)
        if not os.path.exists(f'./index/I_{args.index}_{args.dataset}.npy'):
            index = faiss.read_index(f'./index/{args.index}.index')
            print('read index success')
            D, I = search_index(args, index, 20)
            np.save(f'./index/D_{args.index}_{args.dataset}.npy', D)
            np.save(f'./index/I_{args.index}_{args.dataset}.npy', I)
        else:
            D = np.load(f'./index/D_{args.index}_{args.dataset}.npy')
            I = np.load(f'./index/I_{args.index}_{args.dataset}.npy')
        evaluate(args, I)
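For reference, a minimal sketch of querying the saved index outside this script, assuming the default paths above (the random `queries` matrix is a hypothetical stand-in for embeddings produced the same way as in `search_index`):

```python
import faiss
import numpy as np

# load the flat inner-product index written by build_index
index = faiss.read_index('./index/g3.index')

# hypothetical query matrix: N x (768*3), same layout as the index entries
queries = np.random.rand(4, 768 * 3).astype('float32')

D, I = index.search(queries, 20)  # inner-product scores and database row ids
print(I[0])  # row indices into MP16_Pro_filtered.csv for the first query
```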

README.md

Lines changed: 36 additions & 0 deletions
@@ -4,10 +4,46 @@ This is the code repository for paper "G3: An Effective and Adaptive Framework f
You can download the images and metadata of MP16-Pro from huggingface: [Jia-py/MP16-Pro](https://huggingface.co/datasets/Jia-py/MP16-Pro/tree/main)

# Data

IM2GPS3K: [images](http://www.mediafire.com/file/7ht7sn78q27o9we/im2gps3ktest.zip) | [metadata](https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/meta/im2gps3k_places365.csv)

YFCC4K: [images](http://www.mediafire.com/file/3og8y3o6c9de3ye/yfcc4k.zip) | [metadata](https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/yfcc25600_places365.csv)

# Environment Setting

```bash
# tested on CUDA 12.0
conda create -n g3 python=3.9
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121
pip install transformers accelerate huggingface_hub pandas
```

# Running samples

1. Geo-alignment

You can run `python run_G3.py` to train the model.

2. Geo-diversification

First, build the index file with `python IndexSearch.py`.

Parameters in IndexSearch.py:
- `--index` --> which model to use for embedding
- `--dataset` --> im2gps3k or yfcc4k
- `--database` --> default mp16

Then, construct an index of negative samples by negating the embeddings (`images_embeds` -> `-1 * images_embeds`) before adding them to the index; see the sketch below.
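A minimal sketch of that negative-sample index, assuming the same embedding loop as `build_index` in `IndexSearch.py` (the `g3_negative.index` filename is illustrative, not from the repository):

```python
import faiss

negative_index = faiss.IndexFlatIP(768 * 3)

# inside the build_index batch loop, after image_embeds is computed:
# negating the embeddings makes inner-product search return the least
# similar database images, i.e. negative samples for the query
negative_index.add((-1 * image_embeds).cpu().detach().numpy())

# after the loop:
faiss.write_index(negative_index, './index/g3_negative.index')
```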
Then, run `llm_predict_hf.py` or `llm_predict.py` to generate the LLM predictions.

After that, run `aggregate_llm_predictions.py` to aggregate the predictions.

3. Geo-verification

Run `python IndexSearch.py --index=g3 --dataset=im2gps3k` (or `--dataset=yfcc4k`) to verify the predictions and evaluate.
# Citation

```bib

aggregate_llm_predictions.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import re
import ast

import pandas as pd
from tqdm import tqdm

df_raw = pd.read_csv('./data/im2gps3k/im2gps3k_places365.csv')
zs_df = pd.read_csv('./data/im2gps3k/llm_predict_results_zs.csv')
rag_5_df = pd.read_csv('./data/im2gps3k/5_llm_predict_results_rag.csv')
rag_10_df = pd.read_csv('./data/im2gps3k/10_llm_predict_results_rag.csv')
rag_15_df = pd.read_csv('./data/im2gps3k/15_llm_predict_results_rag.csv')

# signed decimal numbers, e.g. "48.8566" and "2.3522"
pattern = r'[-+]?\d+\.\d+'

# (source dataframe, response column, output column prefix);
# all prediction files are row-aligned with df_raw
sources = [
    (zs_df, 'response', 'zs'),
    (rag_5_df, 'rag_response', '5_rag'),
    (rag_10_df, 'rag_response', '10_rag'),
    (rag_15_df, 'rag_response', '15_rag'),
]

for source_df, column, prefix in sources:
    for i in tqdm(range(df_raw.shape[0])):
        # each cell holds a stringified list of LLM responses
        response = ast.literal_eval(source_df.loc[i, column])
        for idx, content in enumerate(response):
            try:
                # the first two decimals in the response are taken as (lat, lon)
                match = re.findall(pattern, content)
                latitude, longitude = match[0], match[1]
            except IndexError:
                latitude, longitude = '0.0', '0.0'
            df_raw.loc[i, f'{prefix}_{idx}_latitude'] = latitude
            df_raw.loc[i, f'{prefix}_{idx}_longitude'] = longitude

df_raw.to_csv('./data/im2gps3k/im2gps3k_prediction.csv', index=False)
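As a quick sanity check of the coordinate regex used above (the sample response string is illustrative):

```python
import re

pattern = r'[-+]?\d+\.\d+'
content = 'The image was probably taken near latitude 48.8566, longitude 2.3522.'
match = re.findall(pattern, content)
print(match[0], match[1])  # 48.8566 2.3522 -> stored as latitude, longitude
```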

data/im2gps3k/put_data_here.txt

Whitespace-only changes.

data/yfcc4k/put_data_here.txt

Whitespace-only changes.

index/put_index_here.txt

Whitespace-only changes.
