diff --git a/environment.yml b/environment.yml index 262deb9..9d0c6f1 100644 --- a/environment.yml +++ b/environment.yml @@ -3,177 +3,23 @@ channels: - pytorch - nvidia - conda-forge - - defaults + - nodefaults dependencies: - - _libgcc_mutex=0.1 - - _openmp_mutex=5.1 - - anyio=3.6.2 - - blas=1.0 - - brotli=1.0.9 - - brotli-bin=1.0.9 - - brotlipy=0.7.0 - - bzip2=1.0.8 - - c-ares=1.19.0 - - ca-certificates=2023.08.22 - - certifi=2023.7.22 - - cffi=1.15.1 - - charset-normalizer=2.0.4 - - click=8.0.4 - - coloredlogs=15.0.1 - - contourpy=1.0.5 - - cryptography=39.0.1 - - cuda-cudart=11.8.89 - - cuda-cupti=11.8.87 - - cuda-libraries=11.8.0 - - cuda-nvrtc=11.8.89 - - cuda-nvtx=11.8.86 - - cuda-runtime=11.8.0 - - cudatoolkit=11.8.0 - - cudnn=8.9.2.26 - - cycler=0.11.0 - - dbus=1.13.18 - - expat=2.4.9 + - loguru=0.5.3 + - uvicorn=0.20.0 - fastapi=0.95.1 - - ffmpeg=4.3 - - filelock=3.9.0 - - flask=2.2.2 - - fontconfig=2.14.1 - - fonttools=4.25.0 - - freetype=2.12.1 - - giflib=5.2.1 - - glib=2.69.1 - - gmp=6.2.1 - - gmpy2=2.1.2 - - gnutls=3.6.15 - - gst-plugins-base=1.14.1 - - gstreamer=1.14.1 - - h11=0.12.0 - - humanfriendly=10.0 - - icu=58.2 - - idna=3.4 - - intel-openmp=2021.4.0 - - itsdangerous=2.0.1 - - jinja2=3.1.2 + - python-multipart=0.0.6 - janus=1.0.0 - - jpeg=9e - - kiwisolver=1.4.4 - - krb5=1.19.4 - - lame=3.100 - - lcms2=2.12 - - ld_impl_linux-64=2.38 - - lerc=3.0 - - libbrotlicommon=1.0.9 - - libbrotlidec=1.0.9 - - libbrotlienc=1.0.9 - - libclang=14.0.6 - - libclang13=14.0.6 - - libcublas=11.11.3.6 - - libcufft=10.9.0.58 - - libcufile=1.6.1.9 - - libcurand=10.3.2.106 - - libcurl=7.88.1 - - libcusolver=11.4.1.48 - - libcusparse=11.7.5.86 - - libdeflate=1.17 - - libedit=3.1.20221030 - - libev=4.33 - - libevent=2.1.12 - - libffi=3.4.2 - - libgcc-ng=11.2.0 - - libgomp=11.2.0 - - libiconv=1.16 - - libidn2=2.3.2 - - libllvm14=14.0.6 - - libnghttp2=1.46.0 - - libnpp=11.8.0.86 - - libnvjpeg=11.9.0.86 - - libpng=1.6.39 - - libpq=12.9 - - libprotobuf=3.20.3 - - libssh2=1.10.0 - - libstdcxx-ng=11.2.0 - - libtasn1=4.19.0 - - libtiff=4.5.0 - - libunistring=0.9.10 - - libuuid=1.41.5 - - libwebp=1.2.4 - - libwebp-base=1.2.4 - - libxcb=1.15 - - libxkbcommon=1.0.1 - - libxml2=2.10.3 - - libxslt=1.1.37 - - loguru=0.5.3 - - lz4-c=1.9.4 - - markupsafe=2.1.1 - - matplotlib=3.7.1 - - matplotlib-base=3.7.1 - - mkl=2021.4.0 - - mkl-service=2.4.0 - - mkl_fft=1.3.1 - - mkl_random=1.2.2 - - mpc=1.1.0 - - mpfr=4.0.2 - - munkres=1.1.4 - - ncurses=6.4 - - nettle=3.7.3 - - networkx=2.8.4 - - nspr=4.33 - - nss=3.74 - numpy=1.23.5 - - numpy-base=1.23.5 - - onnx=1.13.0 - - onnxruntime=1.12.1 - - openh264=2.1.1 - - openssl=1.1.1w - pandas=2.1.4 - - packaging=23.0 - - pcre=8.45 - - pillow=9.4.0 - - pip=23.0.1 - - ply=3.11 - - protobuf=3.20.3 - - pycparser=2.21 - - pydantic=1.10.2 - - pyopenssl=23.0.0 - - pyparsing=3.0.9 - - pyqt=5.15.7 - - pysocks=1.7.1 - - python=3.10.11 - - python-dateutil=2.8.2 - - python-flatbuffers=2.0 - - python-multipart=0.0.6 - pytorch=2.0.0 - pytorch-cuda=11.8 - - pytorch-mutex=1.0 - - qt-main=5.15.2 - - qt-webengine=5.15.9 - - qtwebkit=5.212 - - re2=2022.04.01 - - readline=8.2 - - requests=2.28.1 - - setuptools=66.0.0 - - sip=6.6.2 - - six=1.16.0 - - sniffio=1.3.0 - - sqlite=3.41.2 - - starlette=0.26.1 - - sympy=1.11.1 - - tk=8.6.12 - - toml=0.10.2 - - torchaudio=2.0.0 - - torchtriton=2.0.0 + - mkl=2021.4.0 - torchvision=0.15.0 - - tornado=6.2 - - typing-extensions=4.5.0 - - typing_extensions=4.5.0 - - tzdata=2023c - - urllib3=1.26.15 - - uvicorn=0.20.0 - - werkzeug=2.2.3 - - wheel=0.38.4 - - xz=5.2.10 - - zlib=1.2.13 - - zstd=1.5.5 + - onnx=1.13.0 + - onnxruntime=1.12.1 + - matplotlib=3.7.1 + - pip=23.0.1 - pip: - mpmath==1.2.1 - opencv-python==4.7.0.72 diff --git a/sam_service/sam_queue.py b/sam_service/sam_queue.py index 5a8e653..0425376 100644 --- a/sam_service/sam_queue.py +++ b/sam_service/sam_queue.py @@ -83,6 +83,14 @@ def worker_loop(worker_id:int, work_queue:janus.SyncQueue[WorkItem]): counter_lock = Lock() counter = 0 +async def get_request_id(): + global counter + request_id = None + async with counter_lock: + request_id = counter + counter += 1 + return request_id + # Start a worker thread for each GPU worker_ids = gpus or [0] logger.info(f"Creating {len(worker_ids)} workers backed by {gpu_count} GPUs") @@ -106,8 +114,56 @@ async def docs_redirect(): @app.get("/new_session_id", response_class=PlainTextResponse) async def new_session_id(): - # uuid1 creates time-based identifiers which guarantee uniqueness - return uuid.uuid1() + """ Create a new session identifier that can be used to identify your client + in other endpoints. + """ + # uuid4 creates random identifiers which are harder to spoof + # than a counter, and easier to distinguish in the logs + return uuid.uuid4() + + +async def cancel_pending_work_items(request_id, session_id): + """ Cancel all pending work items for the given session. + """ + for item in session_dict[session_id]: + logger.debug(f"R{request_id} - Marking {item} as cancelled") + item.cancelled = True + # Send notification to waiting request + await item.result_queue.async_q.put(None) + + +async def get_embedding(img, request_id, session_id): + """ Asynchronously returns the embedding for the given image. + This function creates a work item that is executed on a GPU by a worker thread. + It also handles registering the work item in a global dictionary so that a + session's queued items can be invalidated at any time. + """ + # This result queue is used to communicate the result from the worker thread + # back to this function. We only expect one item to be put on this queue. + result_queue = janus.Queue() + + work_item = WorkItem(request_id=request_id, + session_id=session_id, + work_function=lambda sam: sam.get_box_model(img), + result_queue=result_queue) + + if session_id: + logger.trace(f"R{request_id} - Adding {work_item} to session dict") + session_dict[session_id].append(work_item) + + logger.trace(f"R{request_id} - Putting work function on the work queue") + await work_queue.async_q.put(work_item) + + try: + logger.debug(f"R{request_id} - Waiting for embedding to be completed by worker ...") + return await result_queue.async_q.get() + finally: + if session_id: + logger.trace(f"R{request_id} - Removing {work_item} from session dict") + session_dict[session_id].remove(work_item) + # Remove the session if there are no items left, to avoid memory leaks + if len(session_dict[session_id])==0: + del session_dict[session_id] @app.post("/embedded_model", response_class=PlainTextResponse) @@ -117,65 +173,31 @@ async def embedded_model( cancel_pending: Optional[bool] = Form(False, description="Cancel any pending requests for this session before processing this one"), encoding: str = Query("none", description="compress: Response compressed with gzip"), ): - """Accepts an input image and returns a segment_anything box model + """ Accepts an input image and returns a segment_anything box model. + Optionally also cancel any pending requests from the same session. """ - global counter - request_id = None - async with counter_lock: - request_id = counter - counter += 1 - # Python can handle much larger ints but I can't - if counter==2**32: counter = 0 - + request_id = await get_request_id() logger.debug(f"R{request_id} - Started embedded_model for {session_id}") - # Cancel previous requests if necessary if session_id and cancel_pending: - for item in session_dict[session_id]: - logger.debug(f"R{request_id} - Marking {item} as cancelled") - item.cancelled = True - # Send notification to waiting request - await item.result_queue.async_q.put(None) + await cancel_pending_work_items(request_id, session_id) logger.trace(f"R{request_id} - Reading image") file_data = await image.read() img = utils.buffer_to_image(file_data) - def do_work(sam): - return sam.get_box_model(img) - - result_queue = janus.Queue() - work_item = WorkItem(request_id=request_id, - session_id=session_id, - work_function=do_work, - result_queue=result_queue) - - if session_id: - logger.trace(f"R{request_id} - Adding {work_item} to session dict") - session_dict[session_id].append(work_item) - - logger.trace(f"R{request_id} - Putting work function on the work queue") - await work_queue.async_q.put(work_item) - - logger.debug(f"R{request_id} - Waiting for embedding to be completed by worker ...") - box_model = await result_queue.async_q.get() - - if session_id: - logger.trace(f"R{request_id} - Removing {work_item} from session dict") - session_dict[session_id].remove(work_item) - - headers = {} - if work_item.cancelled: - headers["Cancelled-By-Client"] = "1" + embedding = await get_embedding(img, request_id, session_id) - if box_model is None: + if embedding is None: logger.debug(f"R{request_id} - Returning code 499 Client Closed Request") - return Response(status_code=499, headers=headers) + return Response(status_code=499, headers={ + "Cancelled-By-Client": "1" + }) logger.trace(f"R{request_id} - Computed embedding") - # Serialize the model as base64 string - arr_bytes = box_model.tobytes() + # Serialize the embedding as base64 string + arr_bytes = embedding.tobytes() b64_bytes = base64.b64encode(arr_bytes) b64_string = b64_bytes.decode('utf-8') @@ -187,21 +209,18 @@ def do_work(sam): logger.trace('Compressing embedding ...') compressed_data = gzip.compress(b64_bytes) logger.debug(f"R{request_id} - Returning compressed embedding") - headers["Content-Type"] = "application/gzip" - headers["Content-Encoding"] = "gzip" - return Response(content=compressed_data, headers=headers) + return Response(content=compressed_data, headers={ + "Content-Type": "application/gzip", + "Content-Encoding": "gzip" + }) @app.post("/cancel_pending", response_class=PlainTextResponse) async def cancel_pending( - session_id: Optional[str] = Form(None, description="UUID identifying a session") + session_id: str = Form(..., description="UUID identifying a session") ): """Cancel any pending requests for the given session. """ - # Cancel previous requests if necessary - if session_id: - for item in session_dict[session_id]: - logger.debug(f"Marking {item} as cancelled") - item.cancelled = True - # Send notification to waiting request - await item.result_queue.async_q.put(None) + request_id = await get_request_id() + await cancel_pending_work_items(request_id, session_id) +