Skip to content

Commit 44fe868

Browse files
authored
Merge branch 'comfyanonymous:master' into master
2 parents 73175cf + 61e7767 commit 44fe868

File tree

3 files changed

+301
-23
lines changed

3 files changed

+301
-23
lines changed

comfy_extras/nodes_mask.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
import torch
2+
3+
from nodes import MAX_RESOLUTION
4+
5+
class LatentCompositeMasked:
6+
@classmethod
7+
def INPUT_TYPES(s):
8+
return {
9+
"required": {
10+
"destination": ("LATENT",),
11+
"source": ("LATENT",),
12+
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
13+
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
14+
},
15+
"optional": {
16+
"mask": ("MASK",),
17+
}
18+
}
19+
RETURN_TYPES = ("LATENT",)
20+
FUNCTION = "composite"
21+
22+
CATEGORY = "latent"
23+
24+
def composite(self, destination, source, x, y, mask = None):
25+
output = destination.copy()
26+
destination = destination["samples"].clone()
27+
source = source["samples"]
28+
29+
x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8))
30+
y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8))
31+
32+
left, top = (x // 8, y // 8)
33+
right, bottom = (left + source.shape[3], top + source.shape[2],)
34+
35+
36+
if mask is None:
37+
mask = torch.ones_like(source)
38+
else:
39+
mask = mask.clone()
40+
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
41+
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
42+
43+
# calculate the bounds of the source that will be overlapping the destination
44+
# this prevents the source trying to overwrite latent pixels that are out of bounds
45+
# of the destination
46+
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
47+
48+
mask = mask[:, :, :visible_height, :visible_width]
49+
inverse_mask = torch.ones_like(mask) - mask
50+
51+
source_portion = mask * source[:, :, :visible_height, :visible_width]
52+
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
53+
54+
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
55+
56+
output["samples"] = destination
57+
58+
return (output,)
59+
60+
class MaskToImage:
61+
@classmethod
62+
def INPUT_TYPES(s):
63+
return {
64+
"required": {
65+
"mask": ("MASK",),
66+
}
67+
}
68+
69+
CATEGORY = "mask"
70+
71+
RETURN_TYPES = ("IMAGE",)
72+
FUNCTION = "mask_to_image"
73+
74+
def mask_to_image(self, mask):
75+
result = mask[None, :, :, None].expand(-1, -1, -1, 3)
76+
return (result,)
77+
78+
class ImageToMask:
79+
@classmethod
80+
def INPUT_TYPES(s):
81+
return {
82+
"required": {
83+
"image": ("IMAGE",),
84+
"channel": (["red", "green", "blue"],),
85+
}
86+
}
87+
88+
CATEGORY = "mask"
89+
90+
RETURN_TYPES = ("MASK",)
91+
FUNCTION = "image_to_mask"
92+
93+
def image_to_mask(self, image, channel):
94+
channels = ["red", "green", "blue"]
95+
mask = image[0, :, :, channels.index(channel)]
96+
return (mask,)
97+
98+
class SolidMask:
99+
@classmethod
100+
def INPUT_TYPES(cls):
101+
return {
102+
"required": {
103+
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
104+
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
105+
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
106+
}
107+
}
108+
109+
CATEGORY = "mask"
110+
111+
RETURN_TYPES = ("MASK",)
112+
113+
FUNCTION = "solid"
114+
115+
def solid(self, value, width, height):
116+
out = torch.full((height, width), value, dtype=torch.float32, device="cpu")
117+
return (out,)
118+
119+
class InvertMask:
120+
@classmethod
121+
def INPUT_TYPES(cls):
122+
return {
123+
"required": {
124+
"mask": ("MASK",),
125+
}
126+
}
127+
128+
CATEGORY = "mask"
129+
130+
RETURN_TYPES = ("MASK",)
131+
132+
FUNCTION = "invert"
133+
134+
def invert(self, mask):
135+
out = 1.0 - mask
136+
return (out,)
137+
138+
class CropMask:
139+
@classmethod
140+
def INPUT_TYPES(cls):
141+
return {
142+
"required": {
143+
"mask": ("MASK",),
144+
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
145+
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
146+
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
147+
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
148+
}
149+
}
150+
151+
CATEGORY = "mask"
152+
153+
RETURN_TYPES = ("MASK",)
154+
155+
FUNCTION = "crop"
156+
157+
def crop(self, mask, x, y, width, height):
158+
out = mask[y:y + height, x:x + width]
159+
return (out,)
160+
161+
class MaskComposite:
162+
@classmethod
163+
def INPUT_TYPES(cls):
164+
return {
165+
"required": {
166+
"destination": ("MASK",),
167+
"source": ("MASK",),
168+
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
169+
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
170+
"operation": (["multiply", "add", "subtract"],),
171+
}
172+
}
173+
174+
CATEGORY = "mask"
175+
176+
RETURN_TYPES = ("MASK",)
177+
178+
FUNCTION = "combine"
179+
180+
def combine(self, destination, source, x, y, operation):
181+
output = destination.clone()
182+
183+
left, top = (x, y,)
184+
right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0]))
185+
visible_width, visible_height = (right - left, bottom - top,)
186+
187+
source_portion = source[:visible_height, :visible_width]
188+
destination_portion = destination[top:bottom, left:right]
189+
190+
match operation:
191+
case "multiply":
192+
output[top:bottom, left:right] = destination_portion * source_portion
193+
case "add":
194+
output[top:bottom, left:right] = destination_portion + source_portion
195+
case "subtract":
196+
output[top:bottom, left:right] = destination_portion - source_portion
197+
198+
output = torch.clamp(output, 0.0, 1.0)
199+
200+
return (output,)
201+
202+
class FeatherMask:
203+
@classmethod
204+
def INPUT_TYPES(cls):
205+
return {
206+
"required": {
207+
"mask": ("MASK",),
208+
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
209+
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
210+
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
211+
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
212+
}
213+
}
214+
215+
CATEGORY = "mask"
216+
217+
RETURN_TYPES = ("MASK",)
218+
219+
FUNCTION = "feather"
220+
221+
def feather(self, mask, left, top, right, bottom):
222+
output = mask.clone()
223+
224+
left = min(left, output.shape[1])
225+
right = min(right, output.shape[1])
226+
top = min(top, output.shape[0])
227+
bottom = min(bottom, output.shape[0])
228+
229+
for x in range(left):
230+
feather_rate = (x + 1.0) / left
231+
output[:, x] *= feather_rate
232+
233+
for x in range(right):
234+
feather_rate = (x + 1) / right
235+
output[:, -x] *= feather_rate
236+
237+
for y in range(top):
238+
feather_rate = (y + 1) / top
239+
output[y, :] *= feather_rate
240+
241+
for y in range(bottom):
242+
feather_rate = (y + 1) / bottom
243+
output[-y, :] *= feather_rate
244+
245+
return (output,)
246+
247+
248+
249+
NODE_CLASS_MAPPINGS = {
250+
"LatentCompositeMasked": LatentCompositeMasked,
251+
"MaskToImage": MaskToImage,
252+
"ImageToMask": ImageToMask,
253+
"SolidMask": SolidMask,
254+
"InvertMask": InvertMask,
255+
"CropMask": CropMask,
256+
"MaskComposite": MaskComposite,
257+
"FeatherMask": FeatherMask,
258+
}
259+
260+
NODE_DISPLAY_NAME_MAPPINGS = {
261+
"ImageToMask": "Convert Image to Mask",
262+
"MaskToImage": "Convert Mask to Image",
263+
}

nodes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def compute_vars(input):
872872
"filename": file,
873873
"subfolder": subfolder,
874874
"type": self.type
875-
});
875+
})
876876
counter += 1
877877

878878
return { "ui": { "images": results } }
@@ -933,7 +933,7 @@ def INPUT_TYPES(s):
933933
"channel": (["alpha", "red", "green", "blue"], ),}
934934
}
935935

936-
CATEGORY = "image"
936+
CATEGORY = "mask"
937937

938938
RETURN_TYPES = ("MASK",)
939939
FUNCTION = "load_image"
@@ -1193,3 +1193,4 @@ def init_custom_nodes():
11931193
load_custom_nodes()
11941194
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
11951195
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
1196+
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))

web/extensions/core/widgetInputs.js

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -159,27 +159,31 @@ app.registerExtension({
159159
const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : undefined;
160160

161161
const input = this.inputs[slot];
162-
if (input.widget && !input[ignoreDblClick]) {
163-
const node = LiteGraph.createNode("PrimitiveNode");
164-
app.graph.add(node);
165-
166-
// Calculate a position that wont directly overlap another node
167-
const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]];
168-
while (isNodeAtPos(pos)) {
169-
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
170-
}
162+
if (!input.widget || !input[ignoreDblClick])// Not a widget input or already handled input
163+
{
164+
if (!(input.type in ComfyWidgets)) return r;//also Not a ComfyWidgets input (do nothing)
165+
}
171166

172-
node.pos = pos;
173-
node.connect(0, this, slot);
174-
node.title = input.name;
167+
// Create a primitive node
168+
const node = LiteGraph.createNode("PrimitiveNode");
169+
app.graph.add(node);
175170

176-
// Prevent adding duplicates due to triple clicking
177-
input[ignoreDblClick] = true;
178-
setTimeout(() => {
179-
delete input[ignoreDblClick];
180-
}, 300);
171+
// Calculate a position that wont directly overlap another node
172+
const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]];
173+
while (isNodeAtPos(pos)) {
174+
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
181175
}
182176

177+
node.pos = pos;
178+
node.connect(0, this, slot);
179+
node.title = input.name;
180+
181+
// Prevent adding duplicates due to triple clicking
182+
input[ignoreDblClick] = true;
183+
setTimeout(() => {
184+
delete input[ignoreDblClick];
185+
}, 300);
186+
183187
return r;
184188
};
185189
},
@@ -233,7 +237,9 @@ app.registerExtension({
233237
// Fires before the link is made allowing us to reject it if it isn't valid
234238

235239
// No widget, we cant connect
236-
if (!input.widget) return false;
240+
if (!input.widget) {
241+
if (!(input.type in ComfyWidgets)) return false;
242+
}
237243

238244
if (this.outputs[slot].links?.length) {
239245
return this.#isValidConnection(input);
@@ -252,9 +258,17 @@ app.registerExtension({
252258
const input = theirNode.inputs[link.target_slot];
253259
if (!input) return;
254260

255-
const widget = input.widget;
256-
const { type, linkType } = getWidgetType(widget.config);
257261

262+
var _widget;
263+
if (!input.widget) {
264+
if (!(input.type in ComfyWidgets)) return;
265+
_widget = { "name": input.name, "config": [input.type, {}] }//fake widget
266+
} else {
267+
_widget = input.widget;
268+
}
269+
270+
const widget = _widget;
271+
const { type, linkType } = getWidgetType(widget.config);
258272
// Update our output to restrict to the widget type
259273
this.outputs[0].type = linkType;
260274
this.outputs[0].name = type;
@@ -274,7 +288,7 @@ app.registerExtension({
274288
if (type in ComfyWidgets) {
275289
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
276290
} else {
277-
widget = this.addWidget(type, "value", null, () => {}, {});
291+
widget = this.addWidget(type, "value", null, () => { }, {});
278292
}
279293

280294
if (node?.widgets && widget) {

0 commit comments

Comments
 (0)