Commit 3107256

style: format inference.py with ruff
- Remove trailing whitespace
- Fix inconsistent indentation
- Adjust line spacing
- Format quotes to double quotes
- Add missing blank line at end of file
1 parent c3e876b commit 3107256
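
For reference, a pass like this one can typically be reproduced with ruff's formatter; the exact invocation and the repo's ruff settings (e.g. a line-length override, which would explain the long single-line calls in the diff below) are assumptions, not part of the commit:

    ruff format src/inference.py

ruff's formatter normalizes strings to double quotes, strips trailing whitespace, and enforces a final newline, which lines up with the change list above.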

File tree: 2 files changed, +47 -50 lines changed


src/inference.py

Lines changed: 36 additions & 43 deletions
@@ -34,28 +34,27 @@ def forward(self, x):
         return x_res.squeeze(0)
 
 
-
 def download_checkpoint_from_wandb(artifact_path, project_name="ghost-irim"):
     print(f"Downloading checkpoint from W&B: {artifact_path}")
-
+
     wandb_api_key = os.environ.get("WANDB_API_KEY")
     if wandb_api_key:
         wandb.login(key=wandb_api_key)
-
+
     run = wandb.init(project=project_name, job_type="inference")
-
+
     artifact = run.use_artifact(artifact_path, type="model")
     artifact_dir = artifact.download()
-
+
     artifact_path_obj = Path(artifact_dir)
     checkpoint_files = list(artifact_path_obj.glob("*.ckpt"))
-
+
     if not checkpoint_files:
         raise FileNotFoundError(f"No .ckpt file found in artifact directory: {artifact_dir}")
-
+
     checkpoint_path = checkpoint_files[0]
     print(f"Checkpoint downloaded to: {checkpoint_path}")
-
+
     return checkpoint_path
 
 
@@ -69,7 +68,7 @@ def main():
     else:
         device = "cpu"
     print(f"Using device: {device}")
-
+
     model_name = config.model.name
     mask_size = config.inference.get("mask_size", 224)
     image_size = 299 if model_name == "inception_v3" else 224
@@ -84,25 +83,19 @@ def main():
     test_data = dataset["test"]
     test_dataset = ForestDataset(test_data["paths"], test_data["labels"], transform=transforms)
 
-    test_loader = torch.utils.data.DataLoader(
-        test_dataset, batch_size=1, shuffle=False, num_workers=2
-    )
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)
 
     num_classes = len(label_map)
 
     # =========================== MODEL LOADING ==================================== #
     wandb_artifact = config.inference.get("wandb_artifact", None)
-
+
     if wandb_artifact:
         wandb_project = config.inference.get("wandb_project", "ghost-irim")
         checkpoint_path = download_checkpoint_from_wandb(wandb_artifact, wandb_project)
     else:
-        raise FileNotFoundError(
-            f"Checkpoint not found at {checkpoint_path}. "
-            "Please set 'wandb_artifact' in config.yaml to download from W&B, "
-            "or ensure the local checkpoint exists."
-        )
-
+        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Please set 'wandb_artifact' in config.yaml to download from W&B, or ensure the local checkpoint exists.")
+
     print(f"Loading model from: {checkpoint_path}")
 
     classifier = ClassifierModule.load_from_checkpoint(
@@ -116,15 +109,14 @@ def main():
     norm_std = [0.5, 0.5, 0.5]
 
     seg_model = SegmentationWrapper(
-        classifier,
+        classifier,
         mask_size=mask_size,
-        mean=None, # TODO: fix
-        std=None, # TODO: fix
-        input_rescale=True # Expects 0-255 input, scales to 0-1 internally
+        mean=None,  # TODO: fix
+        std=None,  # TODO: fix
+        input_rescale=True,  # Expects 0-255 input, scales to 0-1 internally
     ).to(device)
     seg_model.eval()
 
-
     # =========================== EXPORT TO ONNX =================================== #
     if config.inference.get("export_onnx", False):
         dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
@@ -140,25 +132,25 @@ def main():
             do_constant_folding=True,
         )
         print(f"Exported model to {onnx_path.resolve()}")
-
+
         # Add metadata
         model_onnx = onnx.load(onnx_path)
-
+
        class_names = {v: k for k, v in label_map.items()}
-
+
        def add_meta(key, value):
-            meta = model_onnx.metadata_props.add()
-            meta.key = key
-            meta.value = json.dumps(value)
+            meta = model_onnx.metadata_props.add()
+            meta.key = key
+            meta.value = json.dumps(value)
 
-        add_meta('model_type', 'Segmentor')
-        add_meta('class_names', class_names)
-        add_meta('resolution', 20)
-        add_meta('tiles_size', image_size)
-        add_meta('tiles_overlap', 0)
+        add_meta("model_type", "Segmentor")
+        add_meta("class_names", class_names)
+        add_meta("resolution", 20)
+        add_meta("tiles_size", image_size)
+        add_meta("tiles_overlap", 0)
 
         onnx.save(model_onnx, onnx_path)
-
+
         if wandb.run is not None:
             onnx_artifact = wandb.Artifact(
                 name=f"segmentation-model-{model_name}",
@@ -170,29 +162,29 @@ def add_meta(key, value):
                     "image_size": image_size,
                     "format": "onnx",
                     "opset_version": 17,
-                }
+                },
             )
             onnx_artifact.add_file(str(onnx_path))
             wandb.log_artifact(onnx_artifact)
             print(f"ONNX model uploaded to W&B artifacts as 'segmentation-model-{model_name}'")
         else:
             print("Warning: W&B run not initialized. ONNX model not uploaded to artifacts.")
-
+
     # =========================== INFERENCE LOOP =================================== #
     print(f"Running inference on {len(test_loader)} samples...")
     all_preds = []
     all_targets = []
 
     with torch.no_grad():
-        for i, batch in enumerate(tqdm(test_loader)):
+        for batch in tqdm(test_loader):
             imgs, labels = batch
             imgs = imgs.to(device)
             labels = labels.to(device)
 
             masks = seg_model(imgs)
-
+
             probs = masks[:, :, 0, 0]
-
+
             all_preds.append(probs)
             all_targets.append(labels)
 
@@ -202,7 +194,7 @@ def add_meta(key, value):
     # =========================== METRICS & LOGGING ================================ #
     if wandb.run is not None:
         print("Calculating and logging metrics...")
-
+
         metrics_per_experiment = count_metrics(all_targets, all_preds)
         print(f"Test Metrics: {metrics_per_experiment}")
         for key, value in metrics_per_experiment.items():
@@ -222,7 +214,7 @@ def add_meta(key, value):
 
         plots_dir = Path("src/plots")
         plots_dir.mkdir(exist_ok=True, parents=True)
-
+
         get_confusion_matrix(all_preds, all_targets, class_names=list(label_map.keys()))
         get_roc_auc_curve(all_preds, all_targets, class_names=list(label_map.keys()))
         get_precision_recall_curve(all_preds, all_targets, class_names=list(label_map.keys()))
@@ -234,5 +226,6 @@ def add_meta(key, value):
     else:
         print("W&B run not active. Skipping metrics logging.")
 
+
 if __name__ == "__main__":
     main()
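
Since add_meta (in the diff above) stores every value via json.dumps in the model's metadata_props, a consumer has to decode symmetrically. A minimal sketch, assuming a hypothetical model.onnx path (the script derives the real onnx_path elsewhere):

    import json

    import onnx

    # Hypothetical path; the exported file name comes from config in the script.
    model_onnx = onnx.load("model.onnx")

    # metadata_props is a sequence of key/value string pairs; add_meta wrote the
    # values with json.dumps, so json.loads restores the original types.
    metadata = {prop.key: json.loads(prop.value) for prop in model_onnx.metadata_props}

    print(metadata["model_type"])   # "Segmentor"
    print(metadata["class_names"])  # index -> label name (keys become strings after the JSON round-trip)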

src/models/segmentation_wrapper.py

Lines changed: 11 additions & 7 deletions
@@ -1,34 +1,38 @@
 import torch
 import torch.nn as nn
 
+
 class SegmentationWrapper(nn.Module):
+    mean: torch.Tensor
+    std: torch.Tensor
+
     def __init__(self, classifier: nn.Module, mask_size: int = 224, mean=None, std=None, input_rescale=False):
         super().__init__()
         self.classifier = classifier.eval()
         self.mask_size = mask_size
         self.input_rescale = input_rescale
-
+
         if mean is None:
             mean = [0.0, 0.0, 0.0]
         if std is None:
             std = [1.0, 1.0, 1.0]
-
+
         self.register_buffer("mean", torch.tensor(mean).view(1, 3, 1, 1))
         self.register_buffer("std", torch.tensor(std).view(1, 3, 1, 1))
 
     @torch.no_grad()
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.input_rescale:
             x = x / 255.0
-
+
         x = (x - self.mean) / self.std
-
+
         logits = self.classifier(x)
-
+
         # Ensure logits is 2D [batch, num_classes]
         if logits.dim() == 1:
             logits = logits.unsqueeze(0)
-
+
         probs = torch.softmax(logits, dim=1)
-        mask = probs[:, :, None, None].expand(-1, -1, self.mask_size, self.mask_size) # (B, C, H, W)
+        mask = probs[:, :, None, None].expand(-1, -1, self.mask_size, self.mask_size)  # (B, C, H, W)
         return mask
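
As a usage sketch of the wrapper's contract (the toy classifier, import path, and shapes are assumptions for illustration, not part of the repo): the wrapper broadcasts image-level class probabilities into a constant per-class mask.

    import torch
    import torch.nn as nn

    from src.models.segmentation_wrapper import SegmentationWrapper  # import path assumed from the file location

    # Stand-in for the real ClassifierModule: any module mapping
    # (B, 3, H, W) images to (B, num_classes) logits will do.
    toy_classifier = nn.Sequential(nn.Flatten(), nn.LazyLinear(5))

    wrapper = SegmentationWrapper(toy_classifier, mask_size=224, input_rescale=True)

    imgs = torch.randint(0, 256, (2, 3, 224, 224)).float()  # 0-255 input, as input_rescale expects
    masks = wrapper(imgs)
    print(masks.shape)  # torch.Size([2, 5, 224, 224]): softmax probs expanded over the mask

Every pixel of a given channel carries the same class probability, which is why the inference loop in inference.py can read the prediction from a single location with probs = masks[:, :, 0, 0].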
