Skip to content

Commit ce3c559

Browse files
authored
Allow custom workflows to resize the canvas (#2217)
1 parent 2fb70dc commit ce3c559

File tree

6 files changed

+64
-5
lines changed

6 files changed

+64
-5
lines changed

ai_diffusion/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
" https://github.com/Acly/krita-ai-diffusion/releases"
1313
)
1414

15-
1615
# The following imports depend on the code running inside Krita, so the cannot be imported in tests.
1716
if importlib.util.find_spec("krita"):
1817
from .extension import AIToolsExtension as AIToolsExtension

ai_diffusion/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,16 @@ class TextOutput(NamedTuple):
4040
mime: str
4141

4242

43+
class ResizeCommand(NamedTuple):
44+
resize_canvas: bool = False
45+
46+
4347
class SharedWorkflow(NamedTuple):
4448
publisher: str
4549
workflow: dict
4650

4751

48-
ClientOutput = dict | SharedWorkflow | TextOutput
52+
ClientOutput = dict | SharedWorkflow | TextOutput | ResizeCommand
4953

5054

5155
class ClientMessage(NamedTuple):

ai_diffusion/comfy_client.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111

1212
from .api import WorkflowInput
1313
from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels
14-
from .client import SharedWorkflow, TranslationPackage, ClientFeatures, ClientJobQueue, TextOutput
14+
from .client import (
15+
SharedWorkflow,
16+
TranslationPackage,
17+
ClientFeatures,
18+
ClientJobQueue,
19+
TextOutput,
20+
ResizeCommand,
21+
)
1522
from .client import Quantization, MissingResources, filter_supported_styles, loras_to_upload
1623
from .comfy_workflow import ComfyObjectInfo
1724
from .files import FileFormat
@@ -385,6 +392,9 @@ async def _listen_websocket(self, websocket: websockets.ClientConnection):
385392
text_output = _extract_text_output(job.id, msg)
386393
if text_output is not None:
387394
await self._messages.put(text_output)
395+
resize_cmd = _extract_resize_output(job.id, msg)
396+
if resize_cmd is not None:
397+
await self._messages.put(resize_cmd)
388398
pose_json = _extract_pose_json(msg)
389399
if pose_json is not None:
390400
result = pose_json
@@ -890,3 +900,25 @@ def _extract_text_output(job_id: str, msg: dict):
890900
except Exception as e:
891901
log.warning(f"Error processing message, error={str(e)}, msg={msg}")
892902
return None
903+
904+
905+
def _extract_resize_output(job_id: str, msg: dict):
906+
"""Extract a Krita canvas resize toggle encoded directly in the UI output."""
907+
try:
908+
output = msg["data"]["output"]
909+
if output is None:
910+
return None
911+
912+
resize = output.get("resize_canvas")
913+
if isinstance(resize, list):
914+
active = any(bool(item) for item in resize)
915+
else:
916+
active = bool(resize)
917+
918+
if not active:
919+
return None
920+
921+
return ClientMessage(ClientEvent.output, job_id, result=ResizeCommand(True))
922+
except Exception as e:
923+
log.warning(f"Error processing Krita resize output: {e}, msg={msg}")
924+
return None

ai_diffusion/document.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def get_image(
5050
def resize(self, extent: Extent):
5151
raise NotImplementedError
5252

53+
def resize_canvas(self, width: int, height: int):
54+
"""Resize the underlying canvas if supported by the implementation."""
55+
pass
56+
5357
def annotate(self, key: str, value: QByteArray):
5458
pass
5559

@@ -230,6 +234,9 @@ def resize(self, extent: Extent):
230234
res = self._doc.resolution()
231235
self._doc.scaleImage(extent.width, extent.height, res, res, "Bilinear")
232236

237+
def resize_canvas(self, width: int, height: int):
238+
self._doc.resizeImage(0, 0, width, height)
239+
233240
def annotate(self, key: str, value: QByteArray):
234241
self._doc.setAnnotation(f"ai_diffusion/{key}", f"AI Diffusion Plugin: {key}", value)
235242

ai_diffusion/jobs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class JobParams:
5656
is_layered: bool = False
5757
frame: tuple[int, int, int] = (0, 0, 0)
5858
animation_id: str = ""
59+
resize_canvas: bool = False
5960

6061
@staticmethod
6162
def from_dict(data: dict[str, Any]):

ai_diffusion/model.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@
2323
from .settings import settings
2424
from .network import NetworkError
2525
from .image import Extent, Image, Mask, Bounds, DummyImage
26-
from .client import Client, ClientMessage, ClientEvent, ClientOutput, is_style_supported
26+
from .client import (
27+
Client,
28+
ClientMessage,
29+
ClientEvent,
30+
ClientOutput,
31+
is_style_supported,
32+
ResizeCommand,
33+
)
2734
from .client import filter_supported_styles, resolve_arch
2835
from .custom_workflow import CustomWorkspace, WorkflowCollection, CustomGenerationMode
2936
from .document import Document, KritaDocument
@@ -590,7 +597,10 @@ def handle_message(self, message: ClientMessage):
590597
self.progress_kind = ProgressKind.upload
591598
self.progress = message.progress
592599
elif message.event is ClientEvent.output:
593-
self.custom.show_output(message.result)
600+
if isinstance(message.result, ResizeCommand):
601+
self._apply_resize_command(message.result, job)
602+
else:
603+
self.custom.show_output(message.result)
594604
elif message.event is ClientEvent.finished:
595605
if message.error: # successful jobs may have encountered some warnings
596606
self.report_error(Error.from_string(message.error, ErrorKind.warning))
@@ -630,6 +640,9 @@ def _finish_job(self, job: Job, event: ClientEvent):
630640
self.jobs.notify_cancelled(job)
631641
self.progress = 0
632642

643+
def _apply_resize_command(self, cmd: ResizeCommand, job: Job):
644+
job.params.resize_canvas = cmd.resize_canvas
645+
633646
def update_preview(self):
634647
if selection := self.jobs.selection:
635648
self.show_preview(selection[0].job, selection[0].image)
@@ -673,6 +686,9 @@ def apply_result(
673686
region_behavior=ApplyRegionBehavior.layer_group,
674687
prefix="",
675688
):
689+
if params.resize_canvas and self.document.extent != image.extent:
690+
self.document.resize_canvas(*image.extent)
691+
676692
bounds = Bounds(*params.bounds.offset, *image.extent)
677693
if len(params.regions) == 0 or region_behavior is ApplyRegionBehavior.none:
678694
if behavior is ApplyBehavior.replace:

0 commit comments

Comments
 (0)