Commit 09d84d5

Main Logic Setup + Test Scripts + Workflow

1 parent 5839b20 commit 09d84d5

File tree

17 files changed: +293 -1 lines

.github/workflows/test-fewshot.yml

Lines changed: 42 additions & 0 deletions

name: Run FewShotLib Tests

on:
  workflow_dispatch:
  push:
    paths:
      - '**.py'
      - '.github/workflows/test-fewshot.yml'
  pull_request:
    paths:
      - '**.py'
      - '.github/workflows/test-fewshot.yml'

jobs:
  run-tests:
    runs-on: ubuntu-latest

    defaults:
      run:
        shell: bash

    steps:
      - name: Checkout Repository
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12.7'

      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install torch torchvision pillow

      - name: Run Feature Extraction Test
        run: |
          python testing/test1.py

      - name: Run Prediction Test
        run: |
          python testing/test2.py
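The test scripts themselves (testing/test1.py and testing/test2.py) are not shown in this diff. As a rough, hypothetical sketch only, a feature-extraction check of the kind the first workflow step runs could call get_encoder from fewshotlib.py on a dummy input; everything below beyond that function name is an assumption, not the committed script.

# Hypothetical sketch, not the actual testing/test1.py: feeds one fake
# RGB image (at the default 224x224 transform size) through an encoder
# and checks that a single feature vector/map comes back.
import torch
from fewshotlib import get_encoder

encoder = get_encoder("resnet18", "RGB")
dummy = torch.randn(1, 3, 224, 224)  # batch of one fake RGB image
with torch.no_grad():
    features = encoder(dummy)
assert features.shape[0] == 1, "expected a batch of one"
print("feature shape:", tuple(features.shape))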

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -178,4 +178,5 @@ cython_debug/
 # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
 # refer to https://docs.cursor.com/context/ignore-files
 .cursorignore
-.cursorindexingignore
+.cursorindexingignore
+note.txt

__init__.py

Whitespace-only changes.

fewshotlib.py

Lines changed: 190 additions & 0 deletions

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import os


class FewShotClassifier:
    def __init__(self, model_path, transform=None, USE_GPU=True):
        if USE_GPU and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at path: {model_path}")

        try:
            checkpoint = torch.load(model_path, map_location=self.device)
        except Exception as e:
            raise RuntimeError(f"Failed to load the model: {e}")

        if "backbone" not in checkpoint:
            raise KeyError("Missing 'backbone' in the checkpoint.")
        if "prototypes" not in checkpoint:
            raise KeyError("Missing 'prototypes' in the checkpoint.")

        self.backbone = checkpoint["backbone"]
        self.image_format = checkpoint.get("image_format", "RGB")
        self.encoder = get_encoder(self.backbone, self.image_format).to(self.device)

        self.prototypes = checkpoint["prototypes"].to(self.device)
        self.labels = checkpoint.get("labels", None)

        # Prefer an explicitly passed transform, then one stored in the
        # checkpoint, then the library default for the image format.
        self.transform = transform
        if self.transform is None:
            if "transform" in checkpoint:
                try:
                    self.transform = checkpoint["transform"]
                except Exception as e:
                    raise RuntimeError(f"Failed to load transform from checkpoint: {e}")
            else:
                self.transform = get_default_transform(self.image_format)

    def _load_and_preprocess(self, img_path):
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        try:
            image = Image.open(img_path).convert(self.image_format)
        except Exception as e:
            raise RuntimeError(f"Failed to load image: {img_path}. Error: {e}")

        try:
            img_tensor = self.transform(image)
        except Exception as e:
            raise RuntimeError(f"Transform failed for image {img_path}: {e}")

        # Reconcile channel count with the expected image format, e.g.
        # handle grayscale when the model expects 3 channels.
        if img_tensor.ndim == 2:
            img_tensor = img_tensor.unsqueeze(0)
        elif img_tensor.ndim == 3 and img_tensor.shape[0] == 1 and self.image_format == "RGB":
            img_tensor = img_tensor.repeat(3, 1, 1)
        elif img_tensor.ndim == 3 and img_tensor.shape[0] == 3 and self.image_format == "L":
            img_tensor = img_tensor[0:1]

        return img_tensor

    def predict(self, img_paths):
        single_input = False
        if isinstance(img_paths, str):
            img_paths = [img_paths]
            single_input = True
        elif not isinstance(img_paths, list):
            raise ValueError("img_paths must be a string or a list of strings")

        try:
            imgs = [self._load_and_preprocess(p) for p in img_paths]
        except Exception as e:
            raise RuntimeError(f"Image preprocessing failed: {e}")

        batch = torch.stack(imgs).to(self.device)

        with torch.no_grad():
            try:
                features = self.encoder(batch)
            except Exception as e:
                raise RuntimeError(f"Encoder inference failed: {e}")

            # Collapse spatial dimensions when the backbone returns a
            # feature map rather than a flat vector.
            if features.ndim == 4:
                features = torch.nn.functional.adaptive_avg_pool2d(features, (1, 1))
            features = features.view(features.size(0), -1)

            # Cosine similarity: normalize features and prototypes, then
            # pick the class whose prototype has the highest dot product.
            features = torch.nn.functional.normalize(features, dim=1)
            prototypes = torch.nn.functional.normalize(self.prototypes, dim=1)

            sim = torch.matmul(features, prototypes.T)
            preds = sim.argmax(dim=1).tolist()

        results = []
        for i, idx in enumerate(preds):
            label = self.labels[idx] if self.labels and idx < len(self.labels) else idx
            results.append({
                "file": img_paths[i],
                "index": idx,
                "label": label
            })

        return results[0] if single_input else results


def get_default_transform(image_format):
    if image_format == "L":
        return T.Compose([
            T.Resize((224, 224)),
            T.Grayscale(num_output_channels=1),
            T.ToTensor(),
        ])
    else:
        return T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
        ])


def get_encoder(backbone_name, image_format):
    # Each branch loads ImageNet weights, swaps the first conv for a
    # single-channel version when grayscale input is expected, and strips
    # the classification head so the model outputs raw features.
    if backbone_name == 'resnet18':
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        model = nn.Sequential(*list(model.children())[:-1])

    elif backbone_name == 'resnet34':
        model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        model = nn.Sequential(*list(model.children())[:-1])

    elif backbone_name == 'resnet50':
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        model = nn.Sequential(*list(model.children())[:-1])

    elif backbone_name == 'mobilenet_v2':
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        model = model.features

    elif backbone_name == 'mobilenet_v3_small':
        model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
        model = model.features

    elif backbone_name == 'mobilenet_v3_large':
        model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
        model = model.features

    elif backbone_name == 'efficientnet_b0':
        model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        model = model.features

    elif backbone_name == 'efficientnet_b1':
        model = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        model = model.features

    elif backbone_name == 'densenet121':
        model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.features.conv0 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        model = nn.Sequential(*list(model.features.children()))

    elif backbone_name == 'densenet169':
        model = models.densenet169(weights=models.DenseNet169_Weights.IMAGENET1K_V1)
        if image_format == "L":
            model.features.conv0 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        model = nn.Sequential(*list(model.features.children()))

    else:
        raise ValueError(f"Backbone '{backbone_name}' not supported!")

    return model.eval()
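For context, here is a minimal usage sketch (not part of the commit): it builds a toy checkpoint containing the two keys the loader requires, "backbone" and "prototypes", plus the optional "labels" entry, then runs a single prediction. The model path, class names, and random prototype values are placeholder assumptions.

# Hypothetical usage sketch: "model.pt", the labels, and the prototype
# values are made-up placeholders, not artifacts from this commit.
import torch
from fewshotlib import FewShotClassifier

# Only 'backbone' and 'prototypes' are required by __init__; 'labels',
# 'image_format', and 'transform' are optional extras.
checkpoint = {
    "backbone": "resnet18",             # any name handled by get_encoder
    "prototypes": torch.randn(3, 512),  # one row per class; 512 = resnet18 feature dim
    "labels": ["cat", "dog", "bird"],   # optional readable names for predictions
}
torch.save(checkpoint, "model.pt")

clf = FewShotClassifier("model.pt", USE_GPU=False)

# A single path returns one dict; a list of paths returns a list of dicts.
result = clf.predict("testing/dataset/1.jpg")
print(result["file"], result["index"], result["label"])

Note that the prototype width must match the pooled feature width of the chosen backbone (512 for resnet18/resnet34, 2048 for resnet50), since predict takes a matrix product between normalized features and prototypes.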

testing/__init__.py

Whitespace-only changes.

testing/dataset/1.jpg (1.99 MB)

testing/dataset/2.jpg (3.5 MB)

testing/dataset/3.jpg (8.3 MB)

testing/dataset/4.jpg (3.34 MB)

testing/dataset/5.jpg (1010 KB)
