Skip to content

Commit c07d9c5

Browse files
committed
Migrated to U2NETP and onnx
1 parent 0398ce0 commit c07d9c5

File tree

11 files changed

+120
-102
lines changed

11 files changed

+120
-102
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Star reduction in deep sky images
22

3-
Starrem2k13 is a simple tool for removing stars from astronomical images. Starrem2k13 uses a GAN trained on augmented data. It's code was inspired from a [sample at Tensorflow's website](https://www.tensorflow.org/tutorials/generative/pix2pix). The training data consists of only three base images.
3+
Starrem2k13 is a simple tool for removing stars from astronomical images. Starrem2k13 uses a U2NETP model trained on augmented data. The training data consists of only three base images.
44

55
Below are examples of what it can do:
66

@@ -81,4 +81,11 @@ Url: [https://commons.wikimedia.org/wiki/File:The_star_cluster_NGC_3572_and_its_
8181
Direct Link: [https://upload.wikimedia.org/wikipedia/commons/9/95/The_star_cluster_NGC_3572_and_its_dramatic_surroundings.jpg](https://upload.wikimedia.org/wikipedia/commons/9/95/The_star_cluster_NGC_3572_and_its_dramatic_surroundings.jpg)
8282

8383

84-
84+
@InProceedings{Qin_2020_PR,
85+
title = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
86+
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin},
87+
journal = {Pattern Recognition},
88+
volume = {106},
89+
pages = {107404},
90+
year = {2020}
91+
}

export_to_onnx.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import tensorflow as tf
2+
import tf2onnx
3+
import model # your model.py
4+
5+
# Build and load weights
6+
G2 = model.Generator()
7+
G2.load_weights("weights/generator_epoch_1000.weights.h5")
8+
9+
# Optional: Run once to build the model (especially if using Functional API or Subclassed model)
10+
G2.predict(tf.random.normal([1, 512, 512])) # Adjust input shape as needed
11+
12+
# Convert to ONNX
13+
spec = (tf.TensorSpec((1, 512, 512), tf.float32),) # shape should match model input
14+
output_path = "generator.onnx"
15+
16+
model_proto, _ = tf2onnx.convert.from_keras(G2, input_signature=spec, opset=13, output_path=output_path)
17+
print(f"ONNX model saved to {output_path}")

model.py

Lines changed: 62 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,62 @@
1-
import tensorflow as tf
2-
3-
IMG_SIZE = 512
4-
OUTPUT_CHANNELS = 1
5-
6-
def downsample(filters, size, apply_batchnorm=True,strides = 2,name=''):
7-
initializer = tf.random_normal_initializer(0., 0.02)
8-
result = tf.keras.Sequential(name=name)
9-
result.add(tf.keras.layers.Conv2D(filters, size, strides=strides, padding='same',kernel_initializer=initializer, use_bias=False))
10-
11-
if apply_batchnorm:
12-
result.add(tf.keras.layers.BatchNormalization())
13-
14-
result.add(tf.keras.layers.LeakyReLU())
15-
return result
16-
17-
def upsample(filters, size, apply_dropout=False,strides = 2,name=''):
18-
initializer = tf.random_normal_initializer(0., 0.02)
19-
result = tf.keras.Sequential(name=name)
20-
result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=strides,padding='same',
21-
kernel_initializer=initializer,use_bias=False))
22-
23-
result.add(tf.keras.layers.BatchNormalization())
24-
if apply_dropout:
25-
result.add(tf.keras.layers.Dropout(0.5))
26-
result.add(tf.keras.layers.ReLU())
27-
return result
28-
29-
def Generator():
30-
inputs = tf.keras.layers.Input(shape=[IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS])
31-
32-
down_stack = [
33-
downsample(16, 5, apply_batchnorm=False,strides = 1, name='gd_1'), # (batch_size, 128, 128, 64)
34-
downsample(32, 5,name='gd_2'), # (batch_size, 64, 64, 128)
35-
downsample(64, 5,name='gd_3'), # (batch_size, 32, 32, 256)
36-
downsample(128, 5,name='gd_4'), # (batch_size, 16, 16, 512)
37-
downsample(256, 5,name='gd_5'), # (batch_size, 8, 8, 512)
38-
downsample(256, 5,name='gd_6'), # (batch_size, 4, 4, 512)
39-
downsample(512, 5,name='gd_7'), # (batch_size, 2, 2, 512)
40-
downsample(512, 5,name='gd_8'), # (batch_size, 1, 1, 512)
41-
]
42-
43-
up_stack = [
44-
upsample(512, 5, apply_dropout=True,name='gu_1'), # (batch_size, 2, 2, 1024)
45-
upsample(256, 5, apply_dropout=True,name='gu_2'), # (batch_size, 4, 4, 1024)
46-
upsample(256, 5, apply_dropout=True,name='gu_3'), # (batch_size, 8, 8, 1024)
47-
upsample(128, 5,name='gu_4'), # (batch_size, 16, 16, 1024)
48-
upsample(64, 5,name='gu_5'), # (batch_size, 32, 32, 512)
49-
upsample(32, 5,name='gu_6'), # (batch_size, 64, 64, 256)
50-
upsample(16, 5,name='gu_7'), # (batch_size, 128, 128, 128)
51-
]
52-
53-
initializer = tf.random_normal_initializer(0., 0.02)
54-
last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,strides=1,padding='same',
55-
kernel_initializer=initializer,activation='relu') # (batch_size, 256, 256, 3)
56-
57-
x = inputs
58-
# Downsampling through the model
59-
skips = []
60-
for down in down_stack:
61-
x = down(x)
62-
skips.append(x)
63-
64-
skips = reversed(skips[:-1])
65-
66-
for up, skip in zip(up_stack, skips):
67-
x = up(x)
68-
x = tf.keras.layers.Concatenate()([x, skip])
69-
70-
x = last(x)
71-
72-
return tf.keras.Model(inputs=inputs, outputs=x)
1+
import tensorflow as tf
2+
from tensorflow import keras
3+
from tensorflow.keras import layers
4+
5+
def REBNCONV(x, filters, dilation_rate=1, name_prefix="rebn"):
6+
x = layers.Conv2D(filters, 3, padding='same', dilation_rate=dilation_rate,
7+
kernel_initializer='he_normal', name=f"{name_prefix}_conv")(x)
8+
x = layers.BatchNormalization(name=f"{name_prefix}_bn")(x)
9+
x = layers.ReLU(name=f"{name_prefix}_relu")(x)
10+
return x
11+
12+
def RSU4(x_input, in_filters, mid_filters, out_filters, name_prefix="rsu4"):
13+
hxin = REBNCONV(x_input, out_filters, name_prefix=f"{name_prefix}_in")
14+
15+
hx1 = REBNCONV(hxin, mid_filters, name_prefix=f"{name_prefix}_conv1")
16+
pool1 = layers.MaxPooling2D(pool_size=2)(hx1)
17+
18+
hx2 = REBNCONV(pool1, mid_filters, name_prefix=f"{name_prefix}_conv2")
19+
pool2 = layers.MaxPooling2D(pool_size=2)(hx2)
20+
21+
hx3 = REBNCONV(pool2, mid_filters, name_prefix=f"{name_prefix}_conv3")
22+
hx4 = REBNCONV(hx3, mid_filters, dilation_rate=2, name_prefix=f"{name_prefix}_conv4")
23+
24+
hx3d = REBNCONV(layers.Concatenate()([hx4, hx3]), mid_filters, name_prefix=f"{name_prefix}_conv3d")
25+
hx3d_up = layers.UpSampling2D(size=2, interpolation='bilinear')(hx3d)
26+
27+
hx2d = REBNCONV(layers.Concatenate()([hx3d_up, hx2]), mid_filters, name_prefix=f"{name_prefix}_conv2d")
28+
hx2d_up = layers.UpSampling2D(size=2, interpolation='bilinear')(hx2d)
29+
30+
hx1d = REBNCONV(layers.Concatenate()([hx2d_up, hx1]), out_filters, name_prefix=f"{name_prefix}_conv1d")
31+
32+
return layers.Add(name=f"{name_prefix}_add")([hx1d, hxin])
33+
34+
def Generator(input_shape=(512, 512, 1)):
35+
inputs = keras.Input(shape=input_shape)
36+
37+
stage1 = RSU4(inputs, 1, 16, 64, name_prefix="stage1")
38+
pool12 = layers.MaxPooling2D(pool_size=2)(stage1)
39+
40+
stage2 = RSU4(pool12, 64, 16, 64, name_prefix="stage2")
41+
pool23 = layers.MaxPooling2D(pool_size=2)(stage2)
42+
43+
stage3 = RSU4(pool23, 64, 16, 64, name_prefix="stage3")
44+
pool34 = layers.MaxPooling2D(pool_size=2)(stage3)
45+
46+
stage4 = RSU4(pool34, 64, 16, 64, name_prefix="stage4")
47+
48+
stage3d = RSU4(layers.Concatenate()([
49+
layers.UpSampling2D(size=2, interpolation='bilinear')(stage4), stage3
50+
]), 128, 16, 64, name_prefix="stage3d")
51+
52+
stage2d = RSU4(layers.Concatenate()([
53+
layers.UpSampling2D(size=2, interpolation='bilinear')(stage3d), stage2
54+
]), 128, 16, 64, name_prefix="stage2d")
55+
56+
stage1d = RSU4(layers.Concatenate()([
57+
layers.UpSampling2D(size=2, interpolation='bilinear')(stage2d), stage1
58+
]), 128, 16, 64, name_prefix="stage1d")
59+
60+
output = layers.Conv2D(1, 1, padding='same', activation='sigmoid', name="final_output")(stage1d)
61+
62+
return keras.Model(inputs, output, name="U2NETP_Gray2Gray")

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
tensorflow
21
numpy
32
pillow
43
tqdm
54
tensorflowjs
6-
pyinstaller
5+
pyinstaller
6+
tf2onnx
7+
onnxruntime

samples_starless/app.py

Whitespace-only changes.

samples_starless/omega_nebula.jpg

52.2 KB
Loading

samples_starless/rim_nebula.jpg

116 KB
Loading

samples_starless/sample1.jpg

-737 KB
Loading

samples_starless/veil_nebula.jpg

-125 KB
Loading

starrem2k13.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
#!/usr/bin/python3
22

3-
import os
4-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
5-
import tensorflow as tf
6-
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
7-
8-
import model
93
from PIL import Image
104
import numpy as np
115
import sys
126
from tqdm import tqdm
137
import math
8+
import onnxruntime as ort
149

1510

1611
IMG_SIZE = 512
@@ -27,38 +22,46 @@
2722
print ('Argument List:', str(sys.argv))
2823
exit(1)
2924

30-
def process_tile(channel,i,j,pad_width,model,output_image):
31-
corp_rect = (i*IMG_SIZE,j*IMG_SIZE,i*IMG_SIZE+IMG_SIZE,j*IMG_SIZE+IMG_SIZE)
32-
current_tile = channel.crop(corp_rect)
33-
current_tile = current_tile.convert('L')
34-
blank_image = current_tile.copy()
35-
current_tile = current_tile.resize((IMG_SIZE-pad_width*2,IMG_SIZE-pad_width*2))
36-
blank_image.paste(current_tile,(pad_width,pad_width))
37-
blank_image = blank_image.resize((MODEL_SIZE,MODEL_SIZE))
38-
blank_image = np.asarray(blank_image,dtype="float32").reshape(1,MODEL_SIZE,MODEL_SIZE)/512
39-
predicted_section = model.predict(blank_image,verbose=0)
40-
predicted_section = predicted_section.reshape(MODEL_SIZE,MODEL_SIZE)*512
41-
predicted_section = Image.fromarray(predicted_section).convert('L')
42-
predicted_section = predicted_section.resize((IMG_SIZE,IMG_SIZE))
43-
predicted_section = predicted_section.crop((pad_width,pad_width,IMG_SIZE-pad_width,IMG_SIZE-pad_width))
44-
predicted_section = predicted_section.resize((IMG_SIZE,IMG_SIZE))
45-
output_image.paste(predicted_section, (i*IMG_SIZE,j*IMG_SIZE), mask=None)
25+
def process_tile(channel,i,j,pad_width,output_image):
26+
corp_rect = (i * IMG_SIZE, j * IMG_SIZE, i * IMG_SIZE + IMG_SIZE, j * IMG_SIZE + IMG_SIZE)
27+
current_tile = channel.crop(corp_rect).convert('L')
28+
blank_image = current_tile.copy()
29+
30+
# Resize to remove padding, then paste into padded blank
31+
current_tile = current_tile.resize((IMG_SIZE - pad_width * 2, IMG_SIZE - pad_width * 2))
32+
blank_image.paste(current_tile, (pad_width, pad_width))
33+
34+
# Resize to match model input
35+
blank_image = blank_image.resize((MODEL_SIZE, MODEL_SIZE))
36+
input_array = np.asarray(blank_image, dtype="float32").reshape(1, MODEL_SIZE, MODEL_SIZE)
37+
input_array = input_array /382 # Normalize same as training
38+
39+
# ONNX Inference
40+
output_array = onnx_session.run(None, {input_name: input_array})[0]
41+
output_array = output_array.reshape(MODEL_SIZE, MODEL_SIZE) * 382 # De-normalize
42+
43+
# Convert to image and paste
44+
predicted_section = Image.fromarray(output_array.astype(np.uint8)).convert('L')
45+
predicted_section = predicted_section.resize((IMG_SIZE, IMG_SIZE))
46+
predicted_section = predicted_section.crop((pad_width, pad_width, IMG_SIZE - pad_width, IMG_SIZE - pad_width))
47+
predicted_section = predicted_section.resize((IMG_SIZE, IMG_SIZE))
48+
49+
output_image.paste(predicted_section, (i * IMG_SIZE, j * IMG_SIZE), mask=None)
4650

4751
def process_channel(channel,pad_width,input_image_size):
4852
global progress_bar,step_size,current_progress
4953
output_image = Image.new('L', input_image_size)
5054
for i in range(0,int(channel.size[0]/IMG_SIZE)):
5155
for j in range(0,int(channel.size[1]/IMG_SIZE)):
52-
process_tile(channel,i,j,pad_width,G2,output_image)
56+
process_tile(channel,i,j,pad_width,output_image)
5357
current_progress = current_progress + step_size
5458
if current_progress <= 100:
5559
progress_bar.update(round(step_size,2))
5660
return output_image
5761

5862
try:
59-
G2 = model.Generator()
60-
G2.load_weights("weights/weights")
61-
63+
onnx_session = ort.InferenceSession("weights/model.onnx", providers=["CPUExecutionProvider"])
64+
input_name = onnx_session.get_inputs()[0].name
6265
source_image = Image.open(args[1])
6366
mode = source_image.mode
6467
size = source_image.size

0 commit comments

Comments
 (0)