
Commit b56c74e

Support sending/receiving image tiles (crop regions) in custom workflows #2230
* requires a KritaSelection node which defines the crop bbox
* applies to the selection mask and KritaCanvas
* does not affect KritaImageLayer or KritaMaskLayer
* KritaOutput can have arbitrary offsets

1 parent 9b01f6a, commit b56c74e

File tree: 11 files changed (+179, -31 lines)


ai_diffusion/api.py

Lines changed: 7 additions & 0 deletions
@@ -123,6 +123,13 @@ class InpaintMode(Enum):
     custom = 6


+class InpaintContext(Enum):
+    automatic = 0
+    mask_bounds = 1
+    entire_image = 2
+    layer_bounds = 3
+
+
 class FillMode(Enum):
     none = 0
     neutral = 1
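
Note: the enum names mirror the context setting of the ETN_KritaSelection node, whose dropdown values apparently use spaces rather than underscores; prepare_mask (further down in this commit) normalizes the string before the lookup. A minimal standalone sketch of that mapping (the space-separated dropdown string is an assumption inferred from the .replace(" ", "_") call):

from enum import Enum

class InpaintContext(Enum):
    automatic = 0
    mask_bounds = 1
    entire_image = 2
    layer_bounds = 3

# Node dropdown values presumably use spaces ("mask bounds");
# normalize them before indexing the enum, as prepare_mask() does.
ctx = "mask bounds".replace(" ", "_")
assert InpaintContext[ctx] is InpaintContext.mask_bounds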

ai_diffusion/client.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,7 @@

 from .api import WorkflowInput
 from .comfy_workflow import ComfyObjectInfo
-from .image import ImageCollection
+from .image import ImageCollection, Point
 from .properties import Property, ObservableProperties
 from .files import FileLibrary, FileFormat
 from .style import Style
@@ -49,6 +49,7 @@ class OutputBatchMode(Enum):

 class JobInfoOutput(NamedTuple):
     name: str = ""
+    offset: Point = Point(0, 0)
     batch_mode: OutputBatchMode = OutputBatchMode.default
     resize_canvas: bool = False

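The new offset field records where a returned tile lands on the canvas. handle_output in custom_workflow.py later splats it into Bounds, so Point must behave as an iterable pair; a standalone sketch with stand-in NamedTuple definitions (the real Point and Bounds live in ai_diffusion/image.py and may differ):

from typing import NamedTuple

class Point(NamedTuple):
    x: int = 0
    y: int = 0

class Bounds(NamedTuple):
    x: int
    y: int
    width: int
    height: int

# Mirrors handle_output in custom_workflow.py:
# job.params.bounds = Bounds(*output.offset, *job.params.bounds.extent)
offset = Point(32, 64)
assert Bounds(*offset, 512, 512) == Bounds(32, 64, 512, 512)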

ai_diffusion/comfy_client.py

Lines changed: 4 additions & 3 deletions
@@ -16,15 +16,15 @@
 from .client import filter_supported_styles, loras_to_upload
 from .comfy_workflow import ComfyObjectInfo
 from .files import FileFormat
-from .image import Image, ImageCollection
+from .image import Image, ImageCollection, Point
 from .network import RequestManager, NetworkError
 from .websockets.src import websockets
 from .style import Styles
 from .resources import ControlMode, ResourceId, ResourceKind, Arch
 from .resources import CustomNode, UpscalerName, resource_id
 from .settings import PerformanceSettings, settings
 from .localization import translate as _
-from .util import client_logger as log
+from .util import client_logger as log, parse_enum
 from .workflow import create as create_workflow
 from . import platform_tools, resources, util

@@ -906,7 +906,8 @@ def _extract_job_info_output(job_id: str, msg: dict):
     if isinstance(info, dict):
         result = JobInfoOutput(
             name=info.get("name", ""),
-            batch_mode=OutputBatchMode[info.get("batch_mode", "default")],
+            offset=Point(info.get("offset_x", 0), info.get("offset_y", 0)),
+            batch_mode=parse_enum(OutputBatchMode, info.get("batch_mode", "default")),
             resize_canvas=info.get("resize_canvas", False),
         )
     return ClientMessage(ClientEvent.output, job_id, result=result)
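
For context, the job-info payload carries the tile offset as flat offset_x/offset_y keys, and the .get defaults keep workflows that send no offset working unchanged. A tiny sketch with a made-up payload (only the key names are grounded in the diff):

info = {"name": "tile", "offset_x": 32, "offset_y": 64, "batch_mode": "default"}
offset = (info.get("offset_x", 0), info.get("offset_y", 0))  # defaults to (0, 0)
assert offset == (32, 64)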

ai_diffusion/comfy_workflow.py

Lines changed: 6 additions & 0 deletions
@@ -1352,6 +1352,12 @@ def inputs(self, node_name: str, category="") -> dict[str, list] | None:
         result.update(inputs.get("optional", {}))
         return result

+    def outputs(self, node_name: str) -> list[str]:
+        node = self.nodes.get(node_name)
+        if node is None:
+            return []
+        return node.get("output_name", [])
+

 def _convert_ui_workflow(w: dict, node_inputs: ComfyObjectInfo):
     version = w.get("version")
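
outputs() reads the output_name list that ComfyUI's /object_info endpoint reports for each node, and returns an empty list for unknown nodes. A standalone sketch of the behavior (the example node entry and its output names are hypothetical):

nodes = {"ETN_KritaSelection": {"output_name": ["mask", "is_active", "offset_x", "offset_y"]}}

def outputs(node_name: str) -> list[str]:
    node = nodes.get(node_name)
    if node is None:
        return []
    return node.get("output_name", [])

assert outputs("ETN_KritaSelection") == ["mask", "is_active", "offset_x", "offset_y"]
assert outputs("NoSuchNode") == []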

ai_diffusion/custom_workflow.py

Lines changed: 28 additions & 3 deletions
@@ -10,15 +10,15 @@
 from PyQt5.QtCore import Qt, QObject, QUuid, QAbstractListModel, QSortFilterProxyModel, QModelIndex
 from PyQt5.QtCore import QMetaObject, QTimer, pyqtSignal

-from .api import WorkflowInput
+from .api import WorkflowInput, InpaintContext
 from .client import OutputBatchMode, TextOutput, ClientOutput, JobInfoOutput
 from .comfy_workflow import ComfyWorkflow, ComfyNode
 from .connection import Connection, ConnectionState
-from .image import Bounds, Image
+from .image import Bounds, Image, Mask
 from .jobs import Job, JobParams, JobQueue, JobKind
 from .properties import Property, ObservableProperties
 from .style import Styles
-from .util import base_type_match, user_data_dir, client_logger as log
+from .util import base_type_match, parse_enum, user_data_dir, client_logger as log
 from .ui import theme
 from . import eventloop

@@ -531,6 +531,30 @@ def collect_parameters(self, layers: "LayerManager", bounds: Bounds, animation=F

         return params

+    def prepare_mask(
+        self,
+        selection_node: ComfyNode,
+        mask: Mask | None,
+        mask_bounds: Bounds | None,
+        canvas_bounds: Bounds,
+    ):
+        ctx = selection_node.input("context", "entire_image").replace(" ", "_")
+        pad = selection_node.input("padding", 0)
+        if mask and mask_bounds:
+            match parse_enum(InpaintContext, ctx):
+                case InpaintContext.entire_image:
+                    bounds = canvas_bounds
+                case InpaintContext.automatic:
+                    bounds = Bounds.pad(mask.bounds, pad)
+                case InpaintContext.mask_bounds:
+                    bounds = Bounds.pad(mask_bounds, pad)
+                case _:
+                    raise ValueError(f"Invalid inpaint context: {ctx}")
+            bounds = Bounds.clamp(bounds, canvas_bounds.extent)
+            mask.bounds = mask.bounds.relative_to(bounds)
+            return mask, bounds
+        return None, canvas_bounds
+
     def switch_to_web_workflow(self):
         self._switch_workflow_bind = self._workflows.rowsInserted.connect(self._set_workflow_index)
         self._switch_workflow_timer = QTimer()
@@ -559,6 +583,7 @@ def handle_output(self, job: Job, output: ClientOutput | None):
             self.outputs_changed.emit(self.outputs)
         elif isinstance(output, JobInfoOutput):
             job.params.resize_canvas = output.resize_canvas
+            job.params.bounds = Bounds(*output.offset, *job.params.bounds.extent)
            if output.name:
                job.params.name = output.name
            match output.batch_mode:
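
prepare_mask is the core of the tiling logic: it picks the crop bounds from the node's context setting, then rebases the mask so its coordinates are relative to that crop. test_prepare_mask (below) expects Bounds(6, 6, 48, 48) for a 40x40 mask at (10, 10) with padding 3, which implies Bounds.pad both pads and snaps sizes to a multiple of 8. A standalone sketch of that arithmetic (the snapping behavior is an assumption inferred from the test values, not the real Bounds.pad):

def pad_and_snap(x, y, w, h, pad, multiple=8):
    # Pad on all sides, then grow each axis until the size is a
    # multiple of `multiple`, splitting the growth across both edges.
    x, y, w, h = x - pad, y - pad, w + 2 * pad, h + 2 * pad
    def snap(pos, size):
        grow = -size % multiple
        return pos - grow // 2, size + grow
    (x, w), (y, h) = snap(x, w), snap(y, h)
    return x, y, w, h

# 40x40 mask at (10, 10), padding 3: (7, 7, 46, 46) snaps to (6, 6, 48, 48),
# matching test_prepare_mask; the mask rebased to that crop sits at (4, 4).
assert pad_and_snap(10, 10, 40, 40, 3) == (6, 6, 48, 48)
assert (10 - 6, 10 - 6) == (4, 4)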

ai_diffusion/image.py

Lines changed: 4 additions & 0 deletions
@@ -73,6 +73,10 @@ def from_qsize(qsize: QSize):
     def largest(a: "Extent", b: "Extent"):
         return a if a.width * a.height > b.width * b.height else b

+    @staticmethod
+    def min(a: "Extent", b: "Extent"):
+        return Extent(min(a.width, b.width), min(a.height, b.height))
+
     @staticmethod
     def ratio(a: "Extent", b: "Extent"):
         return sqrt(a.pixel_count / b.pixel_count)
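
Unlike Extent.largest, which picks one whole extent by area, Extent.min is componentwise; show_preview in model.py uses it so a crop never reaches past the returned image. A quick check with a stand-in type:

from typing import NamedTuple

class Extent(NamedTuple):
    width: int
    height: int

def extent_min(a: Extent, b: Extent) -> Extent:
    return Extent(min(a.width, b.width), min(a.height, b.height))

# A 512x512 job whose tile came back 512x384 is cropped to 512x384, not 512x512.
assert extent_min(Extent(512, 512), Extent(512, 384)) == Extent(512, 384)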

ai_diffusion/model.py

Lines changed: 16 additions & 17 deletions
@@ -16,7 +16,8 @@

 from . import eventloop, workflow, util
 from .api import ConditioningInput, ControlInput, WorkflowKind, WorkflowInput, SamplingInput
-from .api import InpaintMode, InpaintParams, FillMode, ImageInput, CustomWorkflowInput, UpscaleInput
+from .api import FillMode, ImageInput, CustomWorkflowInput, UpscaleInput
+from .api import InpaintMode, InpaintContext, InpaintParams
 from .localization import translate as _
 from .util import clamp, ensure, trim_text, client_logger as log
 from .settings import ApplyBehavior, ApplyRegionBehavior, GenerationFinishedAction, ImageFileFormat
@@ -472,23 +473,28 @@ async def _generate_custom(self, previous_input: WorkflowInput | None):

         try:
             wf = ensure(self.custom.graph)
-            bounds = Bounds(0, 0, *self._doc.extent)
-            img_input = ImageInput.from_extent(bounds.extent)
-            img_input.initial_image = self._get_current_image(bounds)
             is_live = self.custom.mode is CustomGenerationMode.live
             is_anim = self.custom.mode is CustomGenerationMode.animation
             seed = self.seed if is_live or self.fixed_seed else workflow.generate_seed()
+            canvas_bounds = Bounds(0, 0, *self._doc.extent)
+            bounds = canvas_bounds
+            mask = None

-            if next(wf.find(type="ETN_KritaSelection"), None):
-                mask, _ = self._doc.create_mask_from_selection()
-                if mask:
-                    img_input.hires_mask = mask.to_image(bounds.extent)
+            if selection_node := next(wf.find(type="ETN_KritaSelection"), None):
+                mods = get_selection_modifiers(InpaintMode.fill, self.strength, is_live)
+                mask, select_bounds = self._doc.create_mask_from_selection(mods.padding, 8, 256)
+                mask, bounds = self.custom.prepare_mask(selection_node, mask, select_bounds, bounds)

-            params = self.custom.collect_parameters(self.layers, bounds, is_anim)
+            img_input = ImageInput.from_extent(bounds.extent)
+            img_input.initial_image = self._get_current_image(bounds)
+            img_input.hires_mask = mask.to_image(bounds.extent) if mask else None
+
+            params = self.custom.collect_parameters(self.layers, canvas_bounds, is_anim)
             input = WorkflowInput(
                 WorkflowKind.custom,
                 img_input,
                 sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed),
+                inpaint=InpaintParams(InpaintMode.fill, bounds),
                 custom_workflow=CustomWorkflowInput(wf.root, params),
             )
             job_params = JobParams(bounds, self.custom.job_name, metadata=self.custom.params)
@@ -646,7 +652,7 @@ def show_preview(self, job_id: str, index: int, name_prefix="Preview"):
         image = job.results[index]
         bounds = job.params.bounds
         if image.extent != bounds.extent:
-            image = Image.crop(image, Bounds(0, 0, *bounds.extent))
+            image = Image.crop(image, Bounds(0, 0, *Extent.min(bounds.extent, image.extent)))
         if self._layer and self._layer.was_removed:
             self._layer = None  # layer was removed by user
         if self._layer is not None:
@@ -950,13 +956,6 @@ def edit_style(self) -> Style | None:
         return None


-class InpaintContext(Enum):
-    automatic = 0
-    mask_bounds = 1
-    entire_image = 2
-    layer_bounds = 3
-
-
 class CustomInpaint(QObject, ObservableProperties):
     mode = Property(InpaintMode.automatic, persist=True)
     fill = Property(FillMode.neutral, persist=True)
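
The reordering in _generate_custom matters: ImageInput.from_extent is now created after prepare_mask, so the initial image and hires mask are cropped to the tile bounds, while collect_parameters still receives the full canvas_bounds. A condensed sketch of the new control flow (stand-in callables, not the plugin's real objects):

def generate_custom(doc_extent, find_selection_node, make_mask, prepare_mask,
                    get_image, collect_parameters):
    canvas_bounds = (0, 0, *doc_extent)
    bounds, mask = canvas_bounds, None
    if (node := find_selection_node()) is not None:   # ETN_KritaSelection present
        mask, select_bounds = make_mask()             # from the user's selection
        mask, bounds = prepare_mask(node, mask, select_bounds, canvas_bounds)
    image = get_image(bounds)                         # cropped to the tile
    params = collect_parameters(canvas_bounds)        # layers still see the full canvas
    return image, mask, bounds, params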

ai_diffusion/util.py

Lines changed: 10 additions & 0 deletions
@@ -16,6 +16,7 @@

 T = TypeVar("T")
 R = TypeVar("R")
+E = TypeVar("E", bound=Enum)
 QOBJECT = TypeVar("QOBJECT", bound=QObject)

 plugin_dir = dir = Path(__file__).parent
@@ -92,6 +93,15 @@ def ensure(value: Optional[T], msg="") -> T:
     return value


+def parse_enum(enum_class: type[E], value: str, default: E | None = None) -> E:
+    try:
+        return enum_class[value]
+    except KeyError:
+        if default is not None:
+            return default
+        raise ValueError(f"Invalid value '{value}' for enum {enum_class.__name__}")
+
+
 def maybe(func: Callable[[T], R], value: Optional[T]) -> Optional[R]:
     if value is not None:
         return func(value)
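
parse_enum turns the bare KeyError of Enum indexing into either an explicit fallback or a descriptive ValueError; comfy_client.py uses the raising form for batch_mode, and prepare_mask uses it for the selection context. A usage sketch combining it with the enum from api.py:

from enum import Enum
from typing import TypeVar

E = TypeVar("E", bound=Enum)

def parse_enum(enum_class: type[E], value: str, default: E | None = None) -> E:
    try:
        return enum_class[value]
    except KeyError:
        if default is not None:
            return default
        raise ValueError(f"Invalid value '{value}' for enum {enum_class.__name__}")

class InpaintContext(Enum):
    automatic = 0
    mask_bounds = 1
    entire_image = 2
    layer_bounds = 3

assert parse_enum(InpaintContext, "automatic") is InpaintContext.automatic
# Unknown names fall back when a default is given instead of raising KeyError.
assert parse_enum(InpaintContext, "typo", InpaintContext.automatic) is InpaintContext.automatic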

ai_diffusion/workflow.py

Lines changed: 11 additions & 1 deletion
@@ -1328,6 +1328,7 @@ def expand_custom(
     w: ComfyWorkflow,
     input: CustomWorkflowInput,
     images: ImageInput,
+    bounds: Bounds,
     seed: int,
     models: ClientModels,
 ):
@@ -1373,6 +1374,8 @@ def get_param(node: ComfyNode, expected_type: type | tuple[type, type] | None =
             image = ensure(images.initial_image)
             outputs[node.output(0)] = w.solid_mask(image.extent, 1.0)
             outputs[node.output(1)] = images.hires_mask is not None
+            outputs[node.output(2)] = bounds.x
+            outputs[node.output(3)] = bounds.y
         case "ETN_Parameter":
             outputs[node.output(0)] = get_param(node)
         case "ETN_KritaImageLayer":
@@ -1689,7 +1692,14 @@ def create(i: WorkflowInput, models: ClientModels, comfy_mode=ComfyRunMode.serve
         )
     elif i.kind is WorkflowKind.custom:
         seed = ensure(i.sampling).seed
-        return expand_custom(workflow, ensure(i.custom_workflow), ensure(i.images), seed, models)
+        return expand_custom(
+            workflow,
+            ensure(i.custom_workflow),
+            ensure(i.images),
+            ensure(i.inpaint).target_bounds,
+            seed,
+            models,
+        )
     else:
         raise ValueError(f"Unsupported workflow kind: {i.kind}")
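
During graph expansion, the ETN_KritaSelection node's third and fourth outputs now resolve to the crop origin, letting a workflow wire the tile position straight into an output node. A sketch of that resolution using plain dicts in place of the plugin's Output handles (only the index-to-value mapping is grounded in the diff):

def resolve_selection(node_id, bounds, has_mask):
    outputs = {}
    outputs[(node_id, 0)] = "mask"      # solid mask, or the hires selection mask
    outputs[(node_id, 1)] = has_mask    # True when a selection exists
    outputs[(node_id, 2)] = bounds[0]   # crop origin x
    outputs[(node_id, 3)] = bounds[1]   # crop origin y
    return outputs

out = resolve_selection(5, (2, 3, 8, 16), True)
assert (out[(5, 2)], out[(5, 3)]) == (2, 3)  # matches test_expand_selection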

tests/test_custom_workflow.py

Lines changed: 89 additions & 5 deletions
@@ -349,6 +349,43 @@ def test_parameter_order():
     ]


+def test_prepare_mask():
+    connection_workflows = {"connection1": make_dummy_graph(42)}
+    connection = create_mock_connection(connection_workflows)
+    workflows = WorkflowCollection(connection)
+
+    jobs = JobQueue()
+    workspace = CustomWorkspace(workflows, dummy_generate, jobs)
+
+    mask = Mask.rectangle(Bounds(10, 10, 40, 40), 0)
+    canvas_bounds = Bounds(0, 0, 100, 100)
+    selection_bounds = Bounds(12, 12, 34, 34)
+    selection_node = ComfyNode(0, "ETN_Selection", {"context": "automatic", "padding": 3})
+
+    prepared_mask, bounds = workspace.prepare_mask(
+        selection_node, copy(mask), selection_bounds, canvas_bounds
+    )
+    assert bounds == Bounds(6, 6, 48, 48)  # mask.bounds + padding // multiple of 8
+    assert prepared_mask is not None
+    assert prepared_mask.bounds == Bounds(4, 4, 40, 40)
+
+    selection_node.inputs["context"] = "mask_bounds"
+    prepared_mask, bounds = workspace.prepare_mask(
+        selection_node, copy(mask), selection_bounds, canvas_bounds
+    )
+    assert bounds == Bounds(9, 9, 40, 40)  # selection_bounds + padding // multiple of 8
+    assert prepared_mask is not None
+    assert prepared_mask.bounds == Bounds(1, 1, 40, 40)
+
+    selection_node.inputs["context"] = "entire_image"
+    prepared_mask, bounds = workspace.prepare_mask(
+        selection_node, copy(mask), selection_bounds, canvas_bounds
+    )
+    assert bounds == canvas_bounds
+    assert prepared_mask is not None
+    assert prepared_mask.bounds == mask.bounds
+
+
 def test_text_output():
     connection_workflows = {"connection1": make_dummy_graph(42)}
     connection = create_mock_connection(connection_workflows, ComfyObjectInfo({}))
@@ -490,10 +527,7 @@ def test_expand():
     }

     w = ComfyWorkflow()
-    w = workflow.expand_custom(w, input, images, 123, models)
-
-    def find_img_id(image: Image):
-        return next((id for id, img in w.images.items() if img == image), "not-found")
+    w = workflow.expand_custom(w, input, images, Bounds(0, 0, 4, 4), 123, models)

     expected = [
         ComfyNode(1, "ETN_LoadImageCache", {"id": img_id(images.initial_image)}),
@@ -554,7 +588,7 @@ def test_expand_animation():
     models = ClientModels()

     w = ComfyWorkflow()
-    w = workflow.expand_custom(w, input, images, 123, models)
+    w = workflow.expand_custom(w, input, images, Bounds(0, 0, 4, 4), 123, models)

     expected = [
         ComfyNode(1, "ETN_LoadImageCache", {"id": img_id(in_images[0])}),
@@ -582,3 +616,53 @@ def test_expand_animation():
     ]
     for node in expected:
         assert node in w, f"Node {node} not found in\n{json.dumps(w.root, indent=2)}"
+
+
+def test_expand_selection():
+    ext = ComfyWorkflow()
+    select, select_active, off_x, off_y = ext.add(
+        "ETN_KritaSelection", 4, context="automatic", padding=2
+    )
+    canvas, width, height, seed = ext.add("ETN_KritaCanvas", 4)
+    ext.add(
+        "Sink",
+        1,
+        image=canvas,
+        width=width,
+        height=height,
+        mask=select,
+        has_selection=select_active,
+        offset_x=off_x,
+        offset_y=off_y,
+    )
+
+    params = {}
+    input = CustomWorkflowInput(workflow=ext.root, params=params)
+    images = ImageInput.from_extent(Extent(8, 16))
+    images.initial_image = Image.create(Extent(8, 16), Qt.GlobalColor.red)
+    images.hires_mask = Image.create(Extent(8, 16), Qt.GlobalColor.green)
+    bounds = Bounds(2, 3, 8, 16)  # selection from (2,2) to (6,6)
+    models = ClientModels()
+
+    w = ComfyWorkflow()
+    w = workflow.expand_custom(w, input, images, bounds, 123, models)
+
+    expected = [
+        ComfyNode(1, "ETN_LoadImageCache", {"id": img_id(images.hires_mask)}),
+        ComfyNode(2, "ETN_LoadImageCache", {"id": img_id(images.initial_image)}),
+        ComfyNode(
+            3,
+            "Sink",
+            {
+                "image": Output(2, 0),
+                "width": 8,
+                "height": 16,
+                "mask": Output(1, 1),
+                "has_selection": True,
+                "offset_x": 2,
+                "offset_y": 3,
+            },
+        ),
+    ]
+    for node in expected:
+        assert node in w, f"Node {node} not found in\n{json.dumps(w.root, indent=2)}"
