Skip to content

Commit 48de38c

Browse files
authored
fix amp autocast warnings (#138)
* fix amp autocast warnings
* use autocast correctly
1 parent a555474 commit 48de38c

File tree

4 files changed: 18 additions (+18) and 22 deletions (-22).

clip_benchmark/metrics/image_caption_selection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def evaluate(model, dataloader, tokenizer, device, amp=True):
3434
3535
dict of accuracy metrics
3636
"""
37-
autocast = torch.cuda.amp.autocast if amp else suppress
3837
image_score = []
3938
text_score = []
4039
score = []
@@ -52,7 +51,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True):
5251
# tokenize all texts in the batch
5352
batch_texts_tok_ = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
5453
# compute the embedding of images and texts
55-
with torch.no_grad(), autocast():
54+
with torch.no_grad(), torch.autocast(device, enabled=amp):
5655
batch_images_emb = F.normalize(model.encode_image(batch_images_), dim=-1).view(B, nim, -1)
5756
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok_), dim=-1).view(B, nt, -1)
5857
gt = torch.arange(min(nim, nt)).to(device)

clip_benchmark/metrics/linear_probe.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __getitem__(self, i):
5656
return self.features[i], self.targets[i]
5757

5858

59-
def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autocast, device, seed):
59+
def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, amp, device, seed):
6060
torch.manual_seed(seed)
6161
model = torch.nn.Linear(input_shape, output_shape)
6262
devices = [x for x in range(torch.cuda.device_count())]
@@ -81,7 +81,7 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
8181
scheduler(step)
8282

8383
optimizer.zero_grad()
84-
with autocast():
84+
with torch.autocast(device, enabled=amp):
8585
pred = model(x)
8686
loss = criterion(pred, y)
8787

@@ -107,14 +107,14 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
107107
return model
108108

109109

110-
def infer(model, dataloader, autocast, device):
110+
def infer(model, dataloader, amp, device):
111111
true, pred = [], []
112112
with torch.no_grad():
113113
for x, y in tqdm(dataloader):
114114
x = x.to(device)
115115
y = y.to(device)
116116

117-
with autocast():
117+
with torch.autocast(device, enabled=amp):
118118
logits = model(x)
119119

120120
pred.append(logits.cpu())
@@ -125,12 +125,12 @@ def infer(model, dataloader, autocast, device):
125125
return logits, target
126126

127127

128-
def find_peak(wd_list, idxs, train_loader, val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed):
128+
def find_peak(wd_list, idxs, train_loader, val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed):
129129
best_wd_idx, max_acc = 0, 0
130130
for idx in idxs:
131131
weight_decay = wd_list[idx]
132-
model = train(train_loader, input_shape, output_shape, weight_decay, lr, epochs, autocast, device, seed)
133-
logits, target = infer(model, val_loader, autocast, device)
132+
model = train(train_loader, input_shape, output_shape, weight_decay, lr, epochs, amp, device, seed)
133+
logits, target = infer(model, val_loader, amp, device)
134134
acc1, = accuracy(logits.float(), target.float(), topk=(1,))
135135
if verbose:
136136
print(f"Valid accuracy with weight_decay {weight_decay}: {acc1}")
@@ -150,7 +150,6 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
150150
os.mkdir(feature_dir)
151151

152152
featurizer = Featurizer(model, normalize).cuda()
153-
autocast = torch.cuda.amp.autocast if amp else suppress
154153
if not os.path.exists(os.path.join(feature_dir, 'targets_train.pt')):
155154
# now we have to cache the features
156155
devices = [x for x in range(torch.cuda.device_count())]
@@ -168,7 +167,7 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
168167
for images, target in tqdm(loader):
169168
images = images.to(device)
170169

171-
with autocast():
170+
with torch.autocast(device, enabled=amp):
172171
feature = featurizer(images)
173172

174173
features.append(feature.cpu())
@@ -270,11 +269,11 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
270269
wd_list = np.logspace(-6, 2, num=97).tolist()
271270
wd_list_init = np.logspace(-6, 2, num=7).tolist()
272271
wd_init_idx = [i for i, val in enumerate(wd_list) if val in wd_list_init]
273-
peak_idx = find_peak(wd_list, wd_init_idx, feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed)
272+
peak_idx = find_peak(wd_list, wd_init_idx, feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed)
274273
step_span = 8
275274
while step_span > 0:
276275
left, right = max(peak_idx - step_span, 0), min(peak_idx + step_span, len(wd_list)-1)
277-
peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed)
276+
peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed)
278277
step_span //= 2
279278
best_wd = wd_list[peak_idx]
280279
if fewshot_k < 0:
@@ -288,8 +287,8 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
288287
best_wd = 0
289288
train_loader = feature_train_loader
290289

291-
final_model = train(train_loader, input_shape, output_shape, best_wd, lr, epochs, autocast, device, seed)
292-
logits, target = infer(final_model, feature_test_loader, autocast, device)
290+
final_model = train(train_loader, input_shape, output_shape, best_wd, lr, epochs, amp, device, seed)
291+
logits, target = infer(final_model, feature_test_loader, amp, device)
293292
pred = logits.argmax(axis=1)
294293

295294
# measure accuracy

clip_benchmark/metrics/zeroshot_classification.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sklearn.metrics import classification_report, balanced_accuracy_score
1313

1414

15+
1516
def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=True):
1617
"""
1718
This function returns zero-shot vectors for each class in order
@@ -36,8 +37,7 @@ def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=Tr
3637
torch.Tensor of shape (N,C) where N is the number
3738
of templates, and C is the number of classes.
3839
"""
39-
autocast = torch.cuda.amp.autocast if amp else suppress
40-
with torch.no_grad(), autocast():
40+
with torch.no_grad(), torch.autocast(device, enabled=amp):
4141
zeroshot_weights = []
4242
for classname in tqdm(classnames):
4343
if type(templates) == dict:
@@ -100,7 +100,6 @@ def run_classification(model, classifier, dataloader, device, amp=True):
100100
- pred (N, C) are the logits
101101
- true (N,) are the actual classes
102102
"""
103-
autocast = torch.cuda.amp.autocast if amp else suppress
104103
pred = []
105104
true = []
106105
nb = 0
@@ -109,7 +108,7 @@ def run_classification(model, classifier, dataloader, device, amp=True):
109108
images = images.to(device)
110109
target = target.to(device)
111110

112-
with autocast():
111+
with torch.autocast(device, enabled=amp):
113112
# predict
114113
image_features = model.encode_image(images)
115114
image_features = F.normalize(image_features, dim=-1)

clip_benchmark/metrics/zeroshot_retrieval.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
3939
batch_texts_emb_list = []
4040
# for each text, we collect the corresponding image index, as each image can have multiple corresponding texts
4141
texts_image_index = []
42-
dataloader = dataloader_with_indices(dataloader)
43-
autocast = torch.cuda.amp.autocast if amp else suppress
42+
dataloader = dataloader_with_indices(dataloader)
4443
for batch_images, batch_texts, inds in tqdm(dataloader):
4544
batch_images = batch_images.to(device)
4645
# tokenize all texts in the batch
@@ -49,7 +48,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
4948
batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]
5049

5150
# compute the embedding of images and texts
52-
with torch.no_grad(), autocast():
51+
with torch.no_grad(), torch.autocast(device, enabled=amp):
5352
batch_images_emb = F.normalize(model.encode_image(batch_images), dim=-1)
5453
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok), dim=-1)
5554

0 commit comments

Comments
 (0)