Skip to content

Commit cf7d738

Browse files
Update sgn detection training
1 parent 8dc5a4c commit cf7d738

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

scripts/la-vision/train_sgn_detection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_paths(split):
5252

5353
def train():
5454

55-
model_name = "sgn-low-res-detection-v5"
55+
model_name = "sgn-low-res-detection-v6"
5656

5757
train_paths, train_label_paths = get_paths("train")
5858
val_paths, val_label_paths = get_paths("val")
@@ -78,6 +78,10 @@ def train():
7878
f, indent=2, sort_keys=True
7979
)
8080

81+
# For marmoset model
82+
sigma = (0.6, 3, 3)
83+
# For mouse model
84+
# sigma = (1, 4, 4)
8185
supervised_training(
8286
name=model_name,
8387
train_paths=train_paths,
@@ -92,7 +96,7 @@ def train():
9296
out_channels=1,
9397
augmentations=None,
9498
eps=1e-5,
95-
sigma=(1, 4, 4),
99+
sigma=sigma,
96100
lower_bound=None,
97101
upper_bound=None,
98102
test_paths=test_paths,

0 commit comments

Comments
 (0)