-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdemo_PixelEmb_train.py
More file actions
35 lines (30 loc) · 1.01 KB
/
demo_PixelEmb_train.py
File metadata and controls
35 lines (30 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import instSeg
from skimage.io import imread, imsave
import numpy as np
import os
import csv
import time
import cv2
from random import shuffle
from tensorflow import keras
import csv
# Parallel instSeg configuration (image_channel=3). Each option below is set
# on the config object in the order listed.
config = instSeg.ConfigParallel(image_channel=3)
for _option, _setting in (
    ('modules', ['embedding', 'edt']),
    ('loss_embedding', 'cos'),
    ('edt_loss', 'mse'),
    ('embedding_include_bg', True),
    ('train_learning_rate', 1e-4),
    ('lr_decay_rate', 0.9),
    ('lr_decay_period', 20000),
    ('backbone', 'uNet'),
):
    setattr(config, _option, _setting)
# X_train: input images, N x H x W x C
# y_train: instance masks; a trailing channel axis is appended below, giving
#          N x H x W x 1 (assumes the incoming masks are N x H x W -- TODO confirm)
# val_split: number of samples at the end of the data held out for validation
if not 0 < val_split < len(X_train):
    # val_split == 0 would make X_train[:-0] an empty training set, and a
    # value >= len(X_train) would leave nothing to train on either.
    raise ValueError(
        'val_split must be between 1 and len(X_train) - 1, got {!r}'.format(val_split)
    )

# Slice both subsets from the untouched inputs *before* rebinding X_train /
# y_train. Taking the validation slice from the already-truncated training
# arrays would reuse training samples for validation, drop the real hold-out
# samples, and give the validation masks a second trailing axis.
X_val = np.array(X_train[-val_split:])
y_val = np.expand_dims(np.array(y_train[-val_split:]), axis=-1)
X_train = np.array(X_train[:-val_split])
y_train = np.expand_dims(np.array(y_train[:-val_split]), axis=-1)

ds_train = {'image': X_train, 'instance': y_train}
ds_val = {'image': X_val, 'instance': y_val}
# Instantiate the parallel instSeg model from the config, then train it on
# ds_train with ds_val passed alongside (batch size 2, augmentation disabled).
model = instSeg.InstSegParallel(model_dir=model_dir, config=config)
model.train(
    ds_train,
    ds_val,
    epochs=epoches,
    batch_size=2,
    augmentation=False,
)