Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions modelcenter/SAM/APP/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import gradio as gr
import numpy as np
import cv2

import utils
from predict import build_predictor

ID_PHOTO_IMAGE_DEMO = "./images/cityscapes_demo.png"

generator = build_predictor()


def clear_image_all():
utils.delete_result()
return None, None, None, None


def get_id_photo_output(img):
"""
Get the special size and background photo.

Args:
img(numpy:ndarray): The image array.
size(str): The size user specified.
bg(str): The background color user specified.
download_size(str): The size for image saving.

"""
predictor = generator
masks = predictor.generate(img)
pred_result, pseudo_map = utils.masks2pseudomap(masks) # PIL Image
added_pseudo_map = utils.visualize(
img, pred_result, color_map=utils.get_color_map_list(256))
res_download = utils.download(pseudo_map)

return pseudo_map, added_pseudo_map, res_download


with gr.Blocks() as demo:
gr.Markdown("""# Segment Anything (PaddleSeg) """)
with gr.Tab("InputImage"):
image_in = gr.Image(value=ID_PHOTO_IMAGE_DEMO, label="Input image")

with gr.Row():
image_clear_btn = gr.Button("Clear")
image_submit_btn = gr.Button("Submit")

with gr.Row():
img_out1 = gr.Image(
label="Output image", interactive=False).style(height=300)
img_out2 = gr.Image(
label="Output image with mask",
interactive=False).style(height=300)
downloaded_img = gr.File(label='Image download').style(height=50)

image_clear_btn.click(
fn=clear_image_all,
inputs=None,
outputs=[image_in, img_out1, img_out2, downloaded_img])

image_submit_btn.click(
fn=get_id_photo_output,
inputs=[image_in, ],
outputs=[img_out1, img_out2, downloaded_img])

gr.Markdown(
"""<font color=Gray>Tips: You can try segment the default image OR upload any images you want to segment by click on the clear button first.</font>"""
)

gr.Markdown(
"""<font color=Gray>This is Segment Anything build with PaddlePaddle.
We refer to the [SAM](https://github.com/facebookresearch/segment-anything) for code strucure and model architecture.
If you have any question or feature request, welcome to raise issues on [GitHub](https://github.com/PaddlePaddle/PaddleSeg/issues). </font>"""
)

gr.Button.style(1)

demo.launch(server_name="0.0.0.0", server_port=8021, share=True)
11 changes: 11 additions & 0 deletions modelcenter/SAM/APP/app.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
【Segment-Anything-App-YAML】

APP_Info:
title: Segment-Anything-App
colorFrom: blue
colorTo: yellow
sdk: gradio
sdk_version: 3.4.1
app_file: app.py
license: apache-2.0
device: cpu
55 changes: 55 additions & 0 deletions modelcenter/SAM/APP/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import sys
import time

import requests
import zipfile

FLUSH_INTERVAL = 0.1
lasttime = time.time()


def progress(str, end=False):
global lasttime
if end:
str += "\n"
lasttime = 0
if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str)
lasttime = time.time()
sys.stdout.flush()


def download_file(url, savepath, print_progress=True):
if print_progress:
print("Connecting to {}".format(url))
r = requests.get(url, stream=True, timeout=15)
total_length = r.headers.get('content-length')

if total_length is None:
with open(savepath, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(savepath, 'wb') as f:
dl = 0
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % os.path.basename(savepath))
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * dl) / total_length))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)


def uncompress(path):
files = zipfile.ZipFile(path, 'r')
filelist = files.namelist()
rootpath = filelist[0]
for file in filelist:
files.extract(file, './')
Binary file added modelcenter/SAM/APP/images/cityscapes_demo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 25 additions & 0 deletions modelcenter/SAM/APP/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import paddle
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

model_link = {
'vit_h':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
'vit_l':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
'vit_b':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams"
}


def build_predictor():
print("Loading model...")

if paddle.is_compiled_with_cuda():
paddle.set_device("gpu")
else:
paddle.set_device("cpu")

sam = sam_model_registry["vit_b"](checkpoint=model_link["vit_b"])
generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")

return generator
7 changes: 7 additions & 0 deletions modelcenter/SAM/APP/requirement.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
gradio
paddlepaddle
opencv-python
pyyaml >= 5.1
PIL
numpy
time
24 changes: 24 additions & 0 deletions modelcenter/SAM/APP/segment_anything/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This implementation refers to: https://github.com/facebookresearch/segment-anything

from .build_sam import (
build_sam,
build_sam_vit_h,
build_sam_vit_l,
build_sam_vit_b,
sam_model_registry, )
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
Loading