|
8 | 8 |
|
9 | 9 | from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput |
10 | 10 | from invokeai.app.util.misc import SEED_MAX, get_random_seed |
| 11 | +from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint |
11 | 12 | from invokeai.backend.image_util.lama import LaMA |
12 | 13 | from invokeai.backend.image_util.patchmatch import PatchMatch |
13 | 14 |
|
14 | 15 | from ..models.image import ImageCategory, ResourceOrigin |
15 | 16 | from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation |
| 17 | +from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES |
16 | 18 |
|
17 | 19 |
|
18 | 20 | def infill_methods() -> list[str]: |
19 | | - methods = [ |
20 | | - "tile", |
21 | | - "solid", |
22 | | - "lama", |
23 | | - ] |
| 21 | + methods = ["tile", "solid", "lama", "cv2"] |
24 | 22 | if PatchMatch.patchmatch_available(): |
25 | 23 | methods.insert(0, "patchmatch") |
26 | 24 | return methods |
@@ -49,6 +47,10 @@ def infill_patchmatch(im: Image.Image) -> Image.Image: |
49 | 47 | return im_patched |
50 | 48 |
|
51 | 49 |
|
| 50 | +def infill_cv2(im: Image.Image) -> Image.Image: |
| 51 | + return cv2_inpaint(im) |
| 52 | + |
| 53 | + |
52 | 54 | def get_tile_images(image: np.ndarray, width=8, height=8): |
53 | 55 | _nrows, _ncols, depth = image.shape |
54 | 56 | _strides = image.strides |
@@ -194,15 +196,35 @@ class InfillPatchMatchInvocation(BaseInvocation): |
194 | 196 | """Infills transparent areas of an image using the PatchMatch algorithm""" |
195 | 197 |
|
196 | 198 | image: ImageField = InputField(description="The image to infill") |
| 199 | + downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") |
| 200 | + resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") |
197 | 201 |
|
198 | 202 | def invoke(self, context: InvocationContext) -> ImageOutput: |
199 | | - image = context.services.images.get_pil_image(self.image.image_name) |
| 203 | + image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") |
| 204 | + |
| 205 | + resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] |
| 206 | + |
| 207 | + infill_image = image.copy() |
| 208 | + width = int(image.width / self.downscale) |
| 209 | + height = int(image.height / self.downscale) |
| 210 | + infill_image = infill_image.resize( |
| 211 | + (width, height), |
| 212 | + resample=resample_mode, |
| 213 | + ) |
200 | 214 |
|
201 | 215 | if PatchMatch.patchmatch_available(): |
202 | | - infilled = infill_patchmatch(image.copy()) |
| 216 | + infilled = infill_patchmatch(infill_image) |
203 | 217 | else: |
204 | 218 | raise ValueError("PatchMatch is not available on this system") |
205 | 219 |
|
| 220 | + infilled = infilled.resize( |
| 221 | + (image.width, image.height), |
| 222 | + resample=resample_mode, |
| 223 | + ) |
| 224 | + |
| 225 | + infilled.paste(image, (0, 0), mask=image.split()[-1]) |
| 226 | + # image.paste(infilled, (0, 0), mask=image.split()[-1]) |
| 227 | + |
206 | 228 | image_dto = context.services.images.create( |
207 | 229 | image=infilled, |
208 | 230 | image_origin=ResourceOrigin.INTERNAL, |
@@ -245,3 +267,30 @@ def invoke(self, context: InvocationContext) -> ImageOutput: |
245 | 267 | width=image_dto.width, |
246 | 268 | height=image_dto.height, |
247 | 269 | ) |
| 270 | + |
| 271 | + |
| 272 | +@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint") |
| 273 | +class CV2InfillInvocation(BaseInvocation): |
| 274 | + """Infills transparent areas of an image using OpenCV Inpainting""" |
| 275 | + |
| 276 | + image: ImageField = InputField(description="The image to infill") |
| 277 | + |
| 278 | + def invoke(self, context: InvocationContext) -> ImageOutput: |
| 279 | + image = context.services.images.get_pil_image(self.image.image_name) |
| 280 | + |
| 281 | + infilled = infill_cv2(image.copy()) |
| 282 | + |
| 283 | + image_dto = context.services.images.create( |
| 284 | + image=infilled, |
| 285 | + image_origin=ResourceOrigin.INTERNAL, |
| 286 | + image_category=ImageCategory.GENERAL, |
| 287 | + node_id=self.id, |
| 288 | + session_id=context.graph_execution_state_id, |
| 289 | + is_intermediate=self.is_intermediate, |
| 290 | + ) |
| 291 | + |
| 292 | + return ImageOutput( |
| 293 | + image=ImageField(image_name=image_dto.image_name), |
| 294 | + width=image_dto.width, |
| 295 | + height=image_dto.height, |
| 296 | + ) |
0 commit comments