Skip to content

Commit bc2a6f5

Browse files
authored
fix issue #125 (#139)
1 parent 0be8d15 commit bc2a6f5

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

clip_benchmark/metrics/linear_probe.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,13 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
277277
peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed)
278278
step_span //= 2
279279
best_wd = wd_list[peak_idx]
280-
train_loader = feature_train_val_loader
280+
if fewshot_k < 0:
281+
# if we are doing full training, we use the full training set (train+val)
282+
train_loader = feature_train_val_loader
283+
else:
284+
# if we are doing few-shot learning, we use the few-shot training set only
285+
# as adding the validation set will train on more data than intended
286+
train_loader = feature_train_loader
281287
else:
282288
best_wd = 0
283289
train_loader = feature_train_loader

0 commit comments

Comments
 (0)