Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 156 additions & 44 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
__version__ = "0.3.0"

import os
from collections import OrderedDict
from typing import Any

from aiohttp.web_request import Request

Expand Down Expand Up @@ -247,7 +249,7 @@ def classname_to_wiki(s: str):
with contextlib.suppress(ImportError):
from cachetools import TTLCache

img_cache = TTLCache(maxsize=100, ttl=5) # 1 min TTL
# img_cache = TTLCache(maxsize=100, ttl=5) # 1 min TTL
prompt_cache = TTLCache(maxsize=100, ttl=5) # 1 min TTL

node_dependency_mapping = get_node_dependencies()
Expand Down Expand Up @@ -362,29 +364,132 @@ async def get_home(request: Request):

import asyncio
import os
import time
from asyncio import Semaphore
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from io import BytesIO

from aiohttp import web
from PIL import Image

image_thread_pool = ThreadPoolExecutor(
max_workers=4, thread_name_prefix="img_worker"
)

@asynccontextmanager
async def get_image_with_timeout(
file_path, preview_params=None, channel=None, timeout=10
):
try:
result = await asyncio.wait_for(
asyncio.get_event_loop().run_in_executor(
image_thread_pool,
get_cached_image,
file_path,
preview_params,
channel,
),
timeout=timeout,
)
yield result
except asyncio.TimeoutError:
print(f"Image processing timed out for {file_path}")
raise
except Exception as e:
print(f"Error processing image {file_path}: {str(e)}")
raise

async def get_image_response(
file, filename: str, preview_info=None, channel=None
):
try:
async with get_image_with_timeout(
file, preview_info, channel
) as img:
return web.Response(
body=img,
content_type="image/webp" if preview_info else "image/png",
headers={"Content-Disposition": f'filename="{filename}"'},
)
except asyncio.TimeoutError:
return web.Response(status=504, text="Image processing timed out")
except Exception as e:
return web.Response(status=500, text=str(e))

class LRUCache:
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.capacity = capacity

def get(self, key) -> Any:
if key not in self.cache:
return None
self.cache.move_to_end(key)
return self.cache[key]

def put(self, key, value: Any) -> None:
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = value
if len(self.cache) > self.capacity:
self.cache.popitem(last=False)

img_cache = LRUCache(capacity=100)

def get_cached_image(file_path: str, preview_params=None, channel=None):
cache_key = (file_path, preview_params, channel)
if img_cache and (cache_key in img_cache):
return img_cache[cache_key]

with Image.open(file_path) as img:
info = img.info
if preview_params:
img = process_preview(img, preview_params)
if channel:
img = process_channel(img, channel)
if prompt_cache:
prompt_cache[cache_key] = info
try:
if img_cache:
img_cache[cache_key] = img.getvalue()
return img_cache[cache_key]
cached_value = img_cache.get(cache_key)
if cached_value is not None:
return cached_value

with Image.open(file_path) as img:
info = img.info
if preview_params:
img = process_preview(img, preview_params)
if channel:
img = process_channel(img, channel)

result = img.getvalue()

try:
if prompt_cache:
prompt_cache[cache_key] = info
if img_cache:
img_cache.put(cache_key, result)
except Exception as e:
print(
f"Warning: Failed to cache image {file_path}: {str(e)}"
)

return result
except Exception as e:
print(f"Error processing image {file_path}: {str(e)}")
raise

class RateLimiter:
def __init__(self, requests_per_second):
self.requests_per_second = requests_per_second
self.semaphore = Semaphore(requests_per_second)
self.timestamps = []

async def acquire(self):
await self.semaphore.acquire()
now = time.time()
self.timestamps.append(now)

return img.getvalue()
# Remove old timestamps
self.timestamps = [t for t in self.timestamps if now - t < 1.0]

if len(self.timestamps) >= self.requests_per_second:
await asyncio.sleep(1.0)

def release(self):
self.semaphore.release()

rate_limiter = RateLimiter(requests_per_second=10)

def process_preview(img: Image.Image, preview_params):
image_format, quality, width = preview_params
Expand Down Expand Up @@ -437,41 +542,44 @@ async def get_image_response(
# to load workflows in the sidebar
@PromptServer.instance.routes.get("/mtb/view")
async def view_image(request: Request):
import folder_paths
try:
import folder_paths

filename = request.rel_url.query.get("filename")
if not filename:
return web.Response(status=404)
await rate_limiter.acquire()

filename, output_dir = folder_paths.annotated_filepath(filename)
if filename[0] == "/" or ".." in filename:
return web.Response(status=400)
filename = request.rel_url.query.get("filename")
if not filename:
return web.Response(status=404)

if output_dir is None:
rtype = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(rtype)
filename, output_dir = folder_paths.annotated_filepath(filename)
if filename[0] == "/" or ".." in filename:
return web.Response(status=400)

if output_dir is None:
return web.Response(status=400)
if output_dir is None:
rtype = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(rtype)

if "subfolder" in request.rel_url.query:
full_output_dir = os.path.join(
output_dir, request.rel_url.query["subfolder"]
)
if (
os.path.commonpath(
(os.path.abspath(full_output_dir), output_dir)
if output_dir is None:
return web.Response(status=400)

if "subfolder" in request.rel_url.query:
full_output_dir = os.path.join(
output_dir, request.rel_url.query["subfolder"]
)
!= output_dir
):
return web.Response(status=403)
output_dir = full_output_dir
if (
os.path.commonpath(
(os.path.abspath(full_output_dir), output_dir)
)
!= output_dir
):
return web.Response(status=403)
output_dir = full_output_dir

filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)
filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)

if not os.path.isfile(file):
return web.Response(status=404)
if not os.path.isfile(file):
return web.Response(status=404)

ret_workflow = request.rel_url.query.get("workflow")

Expand Down Expand Up @@ -509,9 +617,13 @@ async def view_image(request: Request):
width = request.rel_url.query.get("width")
preview_info = (image_format, quality, width)

channel = request.rel_url.query.get("channel")
channel = request.rel_url.query.get("channel")

return await get_image_response(file, filename, preview_info, channel)
return await get_image_response(
file, filename, preview_info, channel
)
finally:
rate_limiter.release()

@PromptServer.instance.routes.get("/mtb/server-info")
async def get_debug(request: Request):
Expand Down
7 changes: 6 additions & 1 deletion endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ def ACTIONS_getUserImageFolders():
input_subdirs = [x.name for x in input_dir.iterdir() if x.is_dir()]
output_subdirs = [x.name for x in output_dir.iterdir() if x.is_dir()]

return {"input": input_subdirs, "output": output_subdirs}
return {
"input_root": input_dir.as_posix(),
"input": input_subdirs,
"output": output_subdirs,
"output_root": output_dir.as_posix(),
}


def ACTIONS_getUserVideos(
Expand Down
69 changes: 69 additions & 0 deletions web/comfy_shared.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import { app } from '../../scripts/app.js'
import { api } from '../../scripts/api.js'

if (!window.MTB) {
window.MTB = {}
}
// #region base utils

// - crude uuid
Expand Down Expand Up @@ -276,6 +279,10 @@ export const getNamedWidget = (node, ...names) => {
* @returns {{to:LGraphNode, from:LGraphNode, type:'error' | 'incoming' | 'outgoing'}}
*/
export const nodesFromLink = (node, link) => {
if (typeof link === 'number') {
console.log('Resolving link from id', link)
link = app.graph.links[link]
}
const fromNode = app.graph.getNodeById(link.origin_id)
const toNode = app.graph.getNodeById(link.target_id)

Expand Down Expand Up @@ -1068,6 +1075,66 @@ export const addDocumentation = (

// #endregion

// #region canvas / drawing

// calculate convex hull (Graham)
export function getConvexHull(points) {
if (points.length < 3) return points

// find the bottommost point (and leftmost if tied)
let bottom = 0
for (let i = 1; i < points.length; i++) {
if (
points[i][1] < points[bottom][1] ||
(points[i][1] === points[bottom][1] && points[i][0] < points[bottom][0])
) {
bottom = i
}
}
// swap bottom point to first position
;[points[0], points[bottom]] = [points[bottom], points[0]]

// sort points by polar angle with respect to base point
const basePoint = points[0]
points.sort((a, b) => {
if (a === basePoint) return -1
if (b === basePoint) return 1

const angleA = Math.atan2(a[1] - basePoint[1], a[0] - basePoint[0])
const angleB = Math.atan2(b[1] - basePoint[1], b[0] - basePoint[0])

if (angleA < angleB) return -1
if (angleA > angleB) return 1

// if angles are equal, sort by distance
const distA = (a[0] - basePoint[0]) ** 2 + (a[1] - basePoint[1]) ** 2
const distB = (b[0] - basePoint[0]) ** 2 + (b[1] - basePoint[1]) ** 2
return distA - distB
})

// build convex hull
const stack = [points[0], points[1]]
for (let i = 2; i < points.length; i++) {
while (
stack.length > 1 &&
!isLeftTurn(stack[stack.length - 2], stack[stack.length - 1], points[i])
) {
stack.pop()
}
stack.push(points[i])
}

return stack
}

function isLeftTurn(p1, p2, p3) {
return (
(p2[0] - p1[0]) * (p3[1] - p1[1]) - (p2[1] - p1[1]) * (p3[0] - p1[0]) > 0
)
}

// #endregion

// #region node extensions

/**
Expand Down Expand Up @@ -1142,6 +1209,8 @@ export const runAction = async (name, ...args) => {
const res = await req.json()
return res.result
}

window.MTB.run = runAction
export const getServerInfo = async () => {
const res = await api.fetchApi('/mtb/server-info')
return await res.json()
Expand Down
Loading