172 changes: 9 additions & 163 deletions environment.yml
@@ -3,177 +3,23 @@ channels:
- pytorch
- nvidia
- conda-forge
- defaults
- nodefaults
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=5.1
- anyio=3.6.2
- blas=1.0
- brotli=1.0.9
- brotli-bin=1.0.9
- brotlipy=0.7.0
- bzip2=1.0.8
- c-ares=1.19.0
- ca-certificates=2023.08.22
- certifi=2023.7.22
- cffi=1.15.1
- charset-normalizer=2.0.4
- click=8.0.4
- coloredlogs=15.0.1
- contourpy=1.0.5
- cryptography=39.0.1
- cuda-cudart=11.8.89
- cuda-cupti=11.8.87
- cuda-libraries=11.8.0
- cuda-nvrtc=11.8.89
- cuda-nvtx=11.8.86
- cuda-runtime=11.8.0
- cudatoolkit=11.8.0
- cudnn=8.9.2.26
- cycler=0.11.0
- dbus=1.13.18
- expat=2.4.9
- loguru=0.5.3
- uvicorn=0.20.0
- fastapi=0.95.1
- ffmpeg=4.3
- filelock=3.9.0
- flask=2.2.2
- fontconfig=2.14.1
- fonttools=4.25.0
- freetype=2.12.1
- giflib=5.2.1
- glib=2.69.1
- gmp=6.2.1
- gmpy2=2.1.2
- gnutls=3.6.15
- gst-plugins-base=1.14.1
- gstreamer=1.14.1
- h11=0.12.0
- humanfriendly=10.0
- icu=58.2
- idna=3.4
- intel-openmp=2021.4.0
- itsdangerous=2.0.1
- jinja2=3.1.2
- python-multipart=0.0.6
- janus=1.0.0
- jpeg=9e
- kiwisolver=1.4.4
- krb5=1.19.4
- lame=3.100
- lcms2=2.12
- ld_impl_linux-64=2.38
- lerc=3.0
- libbrotlicommon=1.0.9
- libbrotlidec=1.0.9
- libbrotlienc=1.0.9
- libclang=14.0.6
- libclang13=14.0.6
- libcublas=11.11.3.6
- libcufft=10.9.0.58
- libcufile=1.6.1.9
- libcurand=10.3.2.106
- libcurl=7.88.1
- libcusolver=11.4.1.48
- libcusparse=11.7.5.86
- libdeflate=1.17
- libedit=3.1.20221030
- libev=4.33
- libevent=2.1.12
- libffi=3.4.2
- libgcc-ng=11.2.0
- libgomp=11.2.0
- libiconv=1.16
- libidn2=2.3.2
- libllvm14=14.0.6
- libnghttp2=1.46.0
- libnpp=11.8.0.86
- libnvjpeg=11.9.0.86
- libpng=1.6.39
- libpq=12.9
- libprotobuf=3.20.3
- libssh2=1.10.0
- libstdcxx-ng=11.2.0
- libtasn1=4.19.0
- libtiff=4.5.0
- libunistring=0.9.10
- libuuid=1.41.5
- libwebp=1.2.4
- libwebp-base=1.2.4
- libxcb=1.15
- libxkbcommon=1.0.1
- libxml2=2.10.3
- libxslt=1.1.37
- loguru=0.5.3
- lz4-c=1.9.4
- markupsafe=2.1.1
- matplotlib=3.7.1
- matplotlib-base=3.7.1
- mkl=2021.4.0
- mkl-service=2.4.0
- mkl_fft=1.3.1
- mkl_random=1.2.2
- mpc=1.1.0
- mpfr=4.0.2
- munkres=1.1.4
- ncurses=6.4
- nettle=3.7.3
- networkx=2.8.4
- nspr=4.33
- nss=3.74
- numpy=1.23.5
- numpy-base=1.23.5
- onnx=1.13.0
- onnxruntime=1.12.1
- openh264=2.1.1
- openssl=1.1.1w
- pandas=2.1.4
- packaging=23.0
- pcre=8.45
- pillow=9.4.0
- pip=23.0.1
- ply=3.11
- protobuf=3.20.3
- pycparser=2.21
- pydantic=1.10.2
- pyopenssl=23.0.0
- pyparsing=3.0.9
- pyqt=5.15.7
- pysocks=1.7.1
- python=3.10.11
- python-dateutil=2.8.2
- python-flatbuffers=2.0
- python-multipart=0.0.6
- pytorch=2.0.0
- pytorch-cuda=11.8
- pytorch-mutex=1.0
- qt-main=5.15.2
- qt-webengine=5.15.9
- qtwebkit=5.212
- re2=2022.04.01
- readline=8.2
- requests=2.28.1
- setuptools=66.0.0
- sip=6.6.2
- six=1.16.0
- sniffio=1.3.0
- sqlite=3.41.2
- starlette=0.26.1
- sympy=1.11.1
- tk=8.6.12
- toml=0.10.2
- torchaudio=2.0.0
- torchtriton=2.0.0
- mkl=2021.4.0
- torchvision=0.15.0
- tornado=6.2
- typing-extensions=4.5.0
- typing_extensions=4.5.0
- tzdata=2023c
- urllib3=1.26.15
- uvicorn=0.20.0
- werkzeug=2.2.3
- wheel=0.38.4
- xz=5.2.10
- zlib=1.2.13
- zstd=1.5.5
- onnx=1.13.0
- onnxruntime=1.12.1
- matplotlib=3.7.1
- pip=23.0.1
- pip:
- mpmath==1.2.1
- opencv-python==4.7.0.72
135 changes: 77 additions & 58 deletions sam_service/sam_queue.py
@@ -83,6 +83,14 @@ def worker_loop(worker_id:int, work_queue:janus.SyncQueue[WorkItem]):
counter_lock = Lock()
counter = 0

async def get_request_id():
global counter
request_id = None
async with counter_lock:
request_id = counter
counter += 1
return request_id

# Start a worker thread for each GPU
worker_ids = gpus or [0]
logger.info(f"Creating {len(worker_ids)} workers backed by {gpu_count} GPUs")
@@ -106,8 +114,56 @@ async def docs_redirect():

@app.get("/new_session_id", response_class=PlainTextResponse)
async def new_session_id():
# uuid1 creates time-based identifiers which guarantee uniqueness
return uuid.uuid1()
""" Create a new session identifier that can be used to identify your client
in other endpoints.
"""
# uuid4 creates random identifiers which are harder to spoof
# than a counter, and easier to distinguish in the logs
return uuid.uuid4()


async def cancel_pending_work_items(request_id, session_id):
""" Cancel all pending work items for the given session.
"""
for item in session_dict[session_id]:
logger.debug(f"R{request_id} - Marking {item} as cancelled")
item.cancelled = True
# Send notification to waiting request
await item.result_queue.async_q.put(None)


async def get_embedding(img, request_id, session_id):
""" Asynchronously returns the embedding for the given image.
This function creates a work item that is executed on a GPU by a worker thread.
It also handles registering the work item in a global dictionary so that a
session's queued items can be invalidated at any time.
"""
# This result queue is used to communicate the result from the worker thread
# back to this function. We only expect one item to be put on this queue.
result_queue = janus.Queue()

work_item = WorkItem(request_id=request_id,
session_id=session_id,
work_function=lambda sam: sam.get_box_model(img),
result_queue=result_queue)

if session_id:
logger.trace(f"R{request_id} - Adding {work_item} to session dict")
session_dict[session_id].append(work_item)

logger.trace(f"R{request_id} - Putting work function on the work queue")
await work_queue.async_q.put(work_item)

try:
logger.debug(f"R{request_id} - Waiting for embedding to be completed by worker ...")
return await result_queue.async_q.get()
finally:
if session_id:
logger.trace(f"R{request_id} - Removing {work_item} from session dict")
session_dict[session_id].remove(work_item)
# Remove the session if there are no items left, to avoid memory leaks
if len(session_dict[session_id])==0:
del session_dict[session_id]
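
Editor's note: a minimal sketch of the worker side of this producer/consumer pattern, added for context and not part of the diff. The repository's actual worker_loop (see the @@ -83,6 +83,14 @@ hunk above) may differ; the sam model object and the cancelled flag come from the surrounding code, everything else here is an assumption.

```python
# Sketch only: how a worker thread could drain the queue that get_embedding() feeds.
def worker_loop_sketch(worker_id: int, work_queue, sam):
    while True:
        work_item = work_queue.get()           # blocking get on the janus sync side
        if work_item.cancelled:
            # get_embedding() already received None from cancel_pending_work_items,
            # so the item can be dropped without running the model.
            work_queue.task_done()
            continue
        result = work_item.work_function(sam)  # e.g. sam.get_box_model(img)
        # Wake the coroutine awaiting result_queue.async_q.get() in get_embedding().
        work_item.result_queue.sync_q.put(result)
        work_queue.task_done()
```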


@app.post("/embedded_model", response_class=PlainTextResponse)
Expand All @@ -117,65 +173,31 @@ async def embedded_model(
cancel_pending: Optional[bool] = Form(False, description="Cancel any pending requests for this session before processing this one"),
encoding: str = Query("none", description="compress: Response compressed with gzip"),
):
"""Accepts an input image and returns a segment_anything box model
""" Accepts an input image and returns a segment_anything box model.
Optionally also cancels any pending requests from the same session.
"""
global counter
request_id = None
async with counter_lock:
request_id = counter
counter += 1
# Python can handle much larger ints but I can't
if counter==2**32: counter = 0

request_id = await get_request_id()
logger.debug(f"R{request_id} - Started embedded_model for {session_id}")

# Cancel previous requests if necessary
if session_id and cancel_pending:
for item in session_dict[session_id]:
logger.debug(f"R{request_id} - Marking {item} as cancelled")
item.cancelled = True
# Send notification to waiting request
await item.result_queue.async_q.put(None)
await cancel_pending_work_items(request_id, session_id)

logger.trace(f"R{request_id} - Reading image")
file_data = await image.read()
img = utils.buffer_to_image(file_data)

def do_work(sam):
return sam.get_box_model(img)

result_queue = janus.Queue()
work_item = WorkItem(request_id=request_id,
session_id=session_id,
work_function=do_work,
result_queue=result_queue)

if session_id:
logger.trace(f"R{request_id} - Adding {work_item} to session dict")
session_dict[session_id].append(work_item)

logger.trace(f"R{request_id} - Putting work function on the work queue")
await work_queue.async_q.put(work_item)

logger.debug(f"R{request_id} - Waiting for embedding to be completed by worker ...")
box_model = await result_queue.async_q.get()

if session_id:
logger.trace(f"R{request_id} - Removing {work_item} from session dict")
session_dict[session_id].remove(work_item)

headers = {}
if work_item.cancelled:
headers["Cancelled-By-Client"] = "1"
embedding = await get_embedding(img, request_id, session_id)

if box_model is None:
if embedding is None:
logger.debug(f"R{request_id} - Returning code 499 Client Closed Request")
return Response(status_code=499, headers=headers)
return Response(status_code=499, headers={
"Cancelled-By-Client": "1"
})

logger.trace(f"R{request_id} - Computed embedding")

# Serialize the model as base64 string
arr_bytes = box_model.tobytes()
# Serialize the embedding as base64 string
arr_bytes = embedding.tobytes()
b64_bytes = base64.b64encode(arr_bytes)
b64_string = b64_bytes.decode('utf-8')

Expand All @@ -187,21 +209,18 @@ def do_work(sam):
logger.trace('Compressing embedding ...')
compressed_data = gzip.compress(b64_bytes)
logger.debug(f"R{request_id} - Returning compressed embedding")
headers["Content-Type"] = "application/gzip"
headers["Content-Encoding"] = "gzip"
return Response(content=compressed_data, headers=headers)
return Response(content=compressed_data, headers={
"Content-Type": "application/gzip",
"Content-Encoding": "gzip"
})
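
Editor's note: a hedged client-side sketch of decoding this response, not part of the diff. The float32 dtype is an assumption; the endpoint itself only promises a base64 string, gzip-compressed when encoding=compress.

```python
# Sketch only: turn an /embedded_model response body back into a numpy array.
# The dtype (and any reshape) is an assumption; adjust to whatever the model emits.
import base64
import gzip

import numpy as np

def decode_embedding(body: bytes) -> np.ndarray:
    if body[:2] == b"\x1f\x8b":        # gzip magic bytes: payload is still compressed
        body = gzip.decompress(body)
    arr_bytes = base64.b64decode(body)
    return np.frombuffer(arr_bytes, dtype=np.float32)
```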


@app.post("/cancel_pending", response_class=PlainTextResponse)
async def cancel_pending(
session_id: Optional[str] = Form(None, description="UUID identifying a session")
session_id: str = Form(..., description="UUID identifying a session")
):
"""Cancel any pending requests for the given session.
"""
# Cancel previous requests if necessary
if session_id:
for item in session_dict[session_id]:
logger.debug(f"Marking {item} as cancelled")
item.cancelled = True
# Send notification to waiting request
await item.result_queue.async_q.put(None)
request_id = await get_request_id()
await cancel_pending_work_items(request_id, session_id)
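
Editor's note: a sketch of how a client might exercise the session workflow these endpoints describe, not part of the diff. The base URL and the use of the requests library are assumptions; the 499 status and Cancelled-By-Client header match the handler above.

```python
# Sketch only: typical client flow against the SAM service (assumed base URL).
import base64

import requests

BASE_URL = "http://localhost:8000"  # assumption: wherever the FastAPI app is served

# 1. Get a session id so pending requests can later be cancelled as a group.
session_id = requests.get(f"{BASE_URL}/new_session_id").text

# 2. Ask for an embedding, cancelling anything still queued for this session.
with open("image.png", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/embedded_model",
        files={"image": f},
        data={"session_id": session_id, "cancel_pending": "true"},
    )

if resp.status_code == 499:
    # A later request in the same session cancelled this one.
    print("cancelled:", resp.headers.get("Cancelled-By-Client"))
else:
    arr_bytes = base64.b64decode(resp.text)  # default encoding=none: plain base64
    print("embedding bytes:", len(arr_bytes))

# 3. Or drop everything still queued for the session explicitly.
requests.post(f"{BASE_URL}/cancel_pending", data={"session_id": session_id})
```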