-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathnodes.py
More file actions
149 lines (125 loc) · 5.49 KB
/
nodes.py
File metadata and controls
149 lines (125 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# nodes.py for ComfyUI-Kontext-API
"""
Custom ComfyUI nodes for the Fal Kontext API.
"""
import os
from typing import Any, Dict
import torch
import numpy as np
from PIL import Image
from .api import call_kontext_api, FalKontextMaxMultiImageNode
# Debug flag - set to False to disable debug output
DEBUG = False
# Helper for converting API output images into ComfyUI's tensor format.
def pil2tensor(image):
    """Convert a PIL image (or any array-like) to ComfyUI tensor format.

    Args:
        image: A PIL.Image.Image or anything ``np.array`` accepts. Pixel
            values are assumed to be 8-bit (0-255).

    Returns:
        torch.Tensor of dtype float32, shape (H, W, 3), values in [0, 1]
        (ComfyUI's per-image layout, without the batch dimension).
    """
    # Scale 8-bit values into [0, 1] floats.
    img_array = np.array(image).astype(np.float32) / 255.0
    # Normalize every input layout to exactly 3 channels (H, W, 3).
    if img_array.ndim == 2:  # Grayscale -> replicate the single plane
        img_array = np.stack([img_array] * 3, axis=-1)
    elif img_array.ndim == 3:
        if img_array.shape[2] == 1:  # Single channel -> replicate
            img_array = np.repeat(img_array, 3, axis=2)
        elif img_array.shape[2] == 2:  # Grayscale + alpha (PIL "LA") -> drop alpha, replicate
            img_array = np.repeat(img_array[:, :, :1], 3, axis=2)
        elif img_array.shape[2] == 4:  # RGBA -> drop alpha channel
            img_array = img_array[:, :, :3]
    tensor = torch.from_numpy(img_array)
    if DEBUG:
        print(f"[pil2tensor] Output shape: {tensor.shape}, dtype: {tensor.dtype}, min: {tensor.min():.3f}, max: {tensor.max():.3f}")
    return tensor
class KontextAPINode:
    """
    Main node for calling the Fal Kontext API.

    Inputs: prompt, image, seed, aspect_ratio, disable prompt enhancement.
    Outputs: single output image, info, passed_nsfw_filtering.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's sockets/widgets for the ComfyUI UI."""
        # Aspect-ratio dropdown choices, widest to tallest.
        ratio_choices = [
            "Match input image",
            "21:9",
            "16:9",
            "4:3",
            "3:2",
            "1:1",
            "2:3",
            "3:4",
            "9:16",
            "9:21",
        ]
        required = {
            "prompt": ("STRING", {"multiline": True, "default": ""}),
            "image": ("IMAGE", {}),
            "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647}),
            "aspect_ratio": (ratio_choices, {"default": "Match input image"}),
        }
        optional = {
            "disable_prompt_enhancement": ("BOOLEAN", {"default": False}),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("IMAGE", "STRING", "BOOLEAN")
    RETURN_NAMES = ("image", "info", "passed_nsfw_filtering")
    FUNCTION = "execute"
    CATEGORY = "image/generation"

    def execute(self, prompt, image, seed, aspect_ratio, disable_prompt_enhancement=False):
        """Run one Kontext API call; return (image, info, passed_nsfw_filtering).

        On any failure the original input image is passed through together
        with an error string, and passed_nsfw_filtering defaults to True.
        """
        if DEBUG:
            print(f"[KontextAPINode] Input image shape: {image.shape}, dtype: {image.dtype}")
            print(f"[KontextAPINode] Prompt: {prompt}")
            print(f"[KontextAPINode] Seed: {seed}")
            print(f"[KontextAPINode] Aspect Ratio: {aspect_ratio}")

        # Guard: reject an empty prompt before doing any API work.
        if not prompt.strip():
            return (image, "Error: Prompt cannot be empty", True)  # Default to True when error

        # A seed of -1 means "pick one at random".
        if seed == -1:
            import random
            seed = random.randint(0, 2147483647)
            if DEBUG:
                print(f"[KontextAPINode] Generated random seed: {seed}")

        try:
            # Parameters this node does not expose are pinned to fixed values.
            output_images, info, passed_nsfw_filtering = call_kontext_api(
                prompt,
                image,
                aspect_ratio,
                1,       # num_images
                seed,
                3.5,     # guidance_scale (default)
                "jpeg",  # output_format (not used, but required by API)
                disable_prompt_enhancement,  # raw
                0.1,     # image_prompt_strength (default)
                28,      # num_inference_steps (default)
                6,       # safety_tolerance (max safety)
            )

            if not output_images:
                # Nothing came back: pass the input image through unchanged.
                print(f"[KontextAPINode] ERROR: No images returned from API")
                return (image, f"Error: No images returned from API\n{info}", True)

            # Only the first returned image is used.
            output_image = output_images[0]
            if DEBUG:
                print(f"[KontextAPINode] Got {len(output_images)} images from API, using first one")
                print(f"[KontextAPINode] Output PIL image size: {output_image.size}, mode: {output_image.mode}")
                print(f"[KontextAPINode] Passed NSFW filtering: {passed_nsfw_filtering}")

            # ComfyUI expects a batched tensor: (1, H, W, C).
            output_tensor = pil2tensor(output_image).unsqueeze(0)
            if DEBUG:
                print(f"[KontextAPINode] Final output tensor shape: {output_tensor.shape}")
            return (output_tensor, info, passed_nsfw_filtering)
        except Exception as e:
            import traceback
            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
            print(f"[KontextAPINode] ERROR: {error_msg}")
            return (image, error_msg, True)
# Registration tables read by ComfyUI at import time for node discovery.
# Maps internal node type names to their implementing classes.
NODE_CLASS_MAPPINGS = {
    "KontextAPINode": KontextAPINode,
    "FalKontextMaxMultiImageNode": FalKontextMaxMultiImageNode,
}
# Human-readable titles shown in the ComfyUI node picker for each type name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "KontextAPINode": "Fal Kontext API",
    "FalKontextMaxMultiImageNode": "Fal Kontext Max (Multi-Image)",
}