Skip to content

Commit f4a3f1f

Browse files
committed
Merge branch 'gempoll' into gempoll-docker
2 parents b0cb2d8 + a98b440 commit f4a3f1f

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

45 files changed

+1898
-414
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ venv/
1414
/web/extensions/*
1515
!/web/extensions/logging.js.example
1616
!/web/extensions/core/
17-
/tests-ui/data/object_info.json
17+
/tests-ui/data/object_info.json
18+
/user/

.vscode/settings.json

Lines changed: 0 additions & 9 deletions
This file was deleted.

app/app_settings.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import os
2+
import json
3+
from aiohttp import web
4+
5+
6+
class AppSettings():
7+
def __init__(self, user_manager):
8+
self.user_manager = user_manager
9+
10+
def get_settings(self, request):
11+
file = self.user_manager.get_request_user_filepath(
12+
request, "comfy.settings.json")
13+
if os.path.isfile(file):
14+
with open(file) as f:
15+
return json.load(f)
16+
else:
17+
return {}
18+
19+
def save_settings(self, request, settings):
20+
file = self.user_manager.get_request_user_filepath(
21+
request, "comfy.settings.json")
22+
with open(file, "w") as f:
23+
f.write(json.dumps(settings, indent=4))
24+
25+
def add_routes(self, routes):
26+
@routes.get("/settings")
27+
async def get_settings(request):
28+
return web.json_response(self.get_settings(request))
29+
30+
@routes.get("/settings/{id}")
31+
async def get_setting(request):
32+
value = None
33+
settings = self.get_settings(request)
34+
setting_id = request.match_info.get("id", None)
35+
if setting_id and setting_id in settings:
36+
value = settings[setting_id]
37+
return web.json_response(value)
38+
39+
@routes.post("/settings")
40+
async def post_settings(request):
41+
settings = self.get_settings(request)
42+
new_settings = await request.json()
43+
self.save_settings(request, {**settings, **new_settings})
44+
return web.Response(status=200)
45+
46+
@routes.post("/settings/{id}")
47+
async def post_setting(request):
48+
setting_id = request.match_info.get("id", None)
49+
if not setting_id:
50+
return web.Response(status=400)
51+
settings = self.get_settings(request)
52+
settings[setting_id] = await request.json()
53+
self.save_settings(request, settings)
54+
return web.Response(status=200)

app/user_manager.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import json
2+
import os
3+
import re
4+
import uuid
5+
from aiohttp import web
6+
from comfy.cli_args import args
7+
from folder_paths import user_directory
8+
from .app_settings import AppSettings
9+
10+
default_user = "default"
11+
users_file = os.path.join(user_directory, "users.json")
12+
13+
14+
class UserManager():
    """Manages user profiles and the /users and /userdata HTTP endpoints.

    In single-user mode there is one implicit "default" user; with the
    --multi-user CLI flag, profiles are persisted to ``users_file`` as a
    ``{user_id: display_name}`` JSON mapping.
    """

    def __init__(self):
        self.settings = AppSettings(self)
        # makedirs(exist_ok=True): os.mkdir failed when the directory was
        # created concurrently by another worker or a parent was missing.
        os.makedirs(user_directory, exist_ok=True)

        if not args.multi_user:
            print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
            print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

        if args.multi_user:
            if os.path.isfile(users_file):
                with open(users_file) as f:
                    self.users = json.load(f)
            else:
                self.users = {}
        else:
            self.users = {"default": "default"}

    def get_request_user_id(self, request):
        """Return the user id for *request* (from the "comfy-user" header in
        multi-user mode, "default" otherwise).

        Raises KeyError for an unknown user id.
        """
        user = "default"
        if args.multi_user and "comfy-user" in request.headers:
            user = request.headers["comfy-user"]

        if user not in self.users:
            raise KeyError("Unknown user: " + user)

        return user

    def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
        """Resolve *file* to an absolute path inside the requesting user's
        directory.

        Returns None when the resolved path would escape the allowed root
        (path-traversal attempt). With file=None, the user's root directory
        itself is returned. Only type="userdata" is supported; other values
        raise KeyError.
        """
        if type == "userdata":
            root_dir = user_directory
        else:
            raise KeyError("Unknown filepath type:" + type)

        user = self.get_request_user_id(request)
        path = user_root = os.path.abspath(os.path.join(root_dir, user))

        # prevent leaving /{type}
        if os.path.commonpath((root_dir, user_root)) != root_dir:
            return None

        parent = user_root

        if file is not None:
            # prevent leaving /{type}/{user}
            path = os.path.abspath(os.path.join(user_root, file))
            if os.path.commonpath((user_root, path)) != user_root:
                return None

        if create_dir and not os.path.exists(parent):
            # makedirs(exist_ok=True): robust against concurrent requests
            # creating the same directory and against a missing parent.
            os.makedirs(parent, exist_ok=True)

        return path

    def add_user(self, name):
        """Create a new user profile for display name *name*, persist the
        users file, and return the generated user id.

        Raises ValueError when *name* is blank.
        """
        name = name.strip()
        if not name:
            raise ValueError("username not provided")
        # Sanitize the display name, then suffix a UUID to guarantee a
        # unique, filesystem-safe id.
        user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
        user_id = user_id + "_" + str(uuid.uuid4())

        self.users[user_id] = name

        with open(users_file, "w") as f:
            json.dump(self.users, f)

        return user_id

    def add_routes(self, routes):
        """Register the user and userdata HTTP endpoints on *routes*."""
        self.settings.add_routes(routes)

        @routes.get("/users")
        async def get_users(request):
            if args.multi_user:
                return web.json_response({"storage": "server", "users": self.users})
            else:
                user_dir = self.get_request_user_filepath(request, None, create_dir=False)
                return web.json_response({
                    "storage": "server" if args.server_storage else "browser",
                    # "migrated": the default user's directory already exists.
                    "migrated": os.path.exists(user_dir)
                })

        @routes.post("/users")
        async def post_users(request):
            body = await request.json()
            username = body["username"]
            if username in self.users.values():
                return web.json_response({"error": "Duplicate username."}, status=400)

            user_id = self.add_user(username)
            return web.json_response(user_id)

        @routes.get("/userdata/{file}")
        async def getuserdata(request):
            file = request.match_info.get("file", None)
            if not file:
                return web.Response(status=400)

            path = self.get_request_user_filepath(request, file)
            if not path:
                # Path resolved outside the user's directory.
                return web.Response(status=403)

            if not os.path.exists(path):
                return web.Response(status=404)

            return web.FileResponse(path)

        @routes.post("/userdata/{file}")
        async def post_userdata(request):
            file = request.match_info.get("file", None)
            if not file:
                return web.Response(status=400)

            path = self.get_request_user_filepath(request, file)
            if not path:
                # Path resolved outside the user's directory.
                return web.Response(status=403)

            body = await request.read()
            with open(path, "wb") as f:
                f.write(body)

            return web.Response(status=200)

comfy/cli_args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def __call__(self, parser, namespace, values, option_string=None):
6666
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
6767
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
6868

69+
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
70+
6971
fpte_group = parser.add_mutually_exclusive_group()
7072
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
7173
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
@@ -110,6 +112,9 @@ class LatentPreviewMethod(enum.Enum):
110112

111113
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
112114

115+
parser.add_argument("--server-storage", action="store_true", help="Saves settings and other user configuration on the server instead of in browser storage.")
116+
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage. If enabled, server-storage will be unconditionally enabled.")
117+
113118
if comfy.options.args_parsing:
114119
args = parser.parse_args()
115120
else:
@@ -120,3 +125,6 @@ class LatentPreviewMethod(enum.Enum):
120125

121126
if args.disable_auto_launch:
122127
args.auto_launch = False
128+
129+
if args.multi_user:
130+
args.server_storage = True

comfy/clip_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate
5757
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
5858

5959
def forward(self, x, mask=None, intermediate_output=None):
60-
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
60+
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
6161

6262
if intermediate_output is not None:
6363
if intermediate_output < 0:
@@ -151,7 +151,7 @@ def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dty
151151

152152
def forward(self, pixel_values):
153153
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
154-
return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight
154+
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
155155

156156

157157
class CLIPVision(torch.nn.Module):

comfy/latent_formats.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,7 @@ def __init__(self):
3333
[-0.3112, -0.2359, -0.2076]
3434
]
3535
self.taesd_decoder_name = "taesdxl_decoder"
36+
37+
class SD_X4(LatentFormat):
    # Latent format for the SD x4 upscaler model. Only the latent scale
    # factor differs from the LatentFormat base; everything else is
    # inherited. NOTE(review): 0.08333 ~ 1/12 — presumably the model's
    # documented latent scaling; confirm against the upstream config.
    def __init__(self):
        self.scale_factor = 0.08333

comfy/ldm/modules/attention.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
177177
kv_chunk_size_min=kv_chunk_size_min,
178178
use_checkpoint=False,
179179
upcast_attention=upcast_attention,
180+
mask=mask,
180181
)
181182

182183
hidden_states = hidden_states.to(dtype)
@@ -239,6 +240,12 @@ def attention_split(q, k, v, heads, mask=None):
239240
else:
240241
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
241242

243+
if mask is not None:
244+
if len(mask.shape) == 2:
245+
s1 += mask[i:end]
246+
else:
247+
s1 += mask[:, i:end]
248+
242249
s2 = s1.softmax(dim=-1).to(v.dtype)
243250
del s1
244251
first_op_done = True
@@ -294,11 +301,14 @@ def attention_xformers(q, k, v, heads, mask=None):
294301
(q, k, v),
295302
)
296303

297-
# actually compute the attention, what we cannot get enough of
298-
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
304+
if mask is not None:
305+
pad = 8 - q.shape[1] % 8
306+
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
307+
mask_out[:, :, :mask.shape[-1]] = mask
308+
mask = mask_out[:, :, :mask.shape[-1]]
309+
310+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
299311

300-
if exists(mask):
301-
raise NotImplementedError
302312
out = (
303313
out.unsqueeze(0)
304314
.reshape(b, heads, -1, dim_head)
@@ -323,7 +333,6 @@ def attention_pytorch(q, k, v, heads, mask=None):
323333

324334

325335
optimized_attention = attention_basic
326-
optimized_attention_masked = attention_basic
327336

328337
if model_management.xformers_enabled():
329338
print("Using xformers cross attention")
@@ -339,15 +348,18 @@ def attention_pytorch(q, k, v, heads, mask=None):
339348
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
340349
optimized_attention = attention_sub_quad
341350

342-
if model_management.pytorch_attention_enabled():
343-
optimized_attention_masked = attention_pytorch
351+
optimized_attention_masked = optimized_attention
344352

345-
def optimized_attention_for_device(device, mask=False):
346-
if device == torch.device("cpu"): #TODO
353+
def optimized_attention_for_device(device, mask=False, small_input=False):
354+
if small_input:
347355
if model_management.pytorch_attention_enabled():
348-
return attention_pytorch
356+
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
349357
else:
350358
return attention_basic
359+
360+
if device == torch.device("cpu"):
361+
return attention_sub_quad
362+
351363
if mask:
352364
return optimized_attention_masked
353365

comfy/ldm/modules/diffusionmodules/openaimodel.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,6 @@ def __init__(
437437
operations=ops,
438438
):
439439
super().__init__()
440-
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
441-
if use_spatial_transformer:
442-
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
443440

444441
if context_dim is not None:
445442
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
@@ -456,7 +453,6 @@ def __init__(
456453
if num_head_channels == -1:
457454
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
458455

459-
self.image_size = image_size
460456
self.in_channels = in_channels
461457
self.model_channels = model_channels
462458
self.out_channels = out_channels
@@ -502,7 +498,7 @@ def __init__(
502498

503499
if self.num_classes is not None:
504500
if isinstance(self.num_classes, int):
505-
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
501+
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
506502
elif self.num_classes == "continuous":
507503
print("setting up linear c_adm embedding layer")
508504
self.label_emb = nn.Linear(1, time_embed_dim)

comfy/ldm/modules/diffusionmodules/upscaling.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,12 @@ def register_schedule(self, beta_schedule="linear", timesteps=1000,
4141
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
4242
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
4343

44-
def q_sample(self, x_start, t, noise=None):
45-
noise = default(noise, lambda: torch.randn_like(x_start))
44+
def q_sample(self, x_start, t, noise=None, seed=None):
45+
if noise is None:
46+
if seed is None:
47+
noise = torch.randn_like(x_start)
48+
else:
49+
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
4650
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
4751
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
4852

@@ -69,12 +73,12 @@ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
6973
super().__init__(noise_schedule_config=noise_schedule_config)
7074
self.max_noise_level = max_noise_level
7175

72-
def forward(self, x, noise_level=None):
76+
def forward(self, x, noise_level=None, seed=None):
7377
if noise_level is None:
7478
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
7579
else:
7680
assert isinstance(noise_level, torch.Tensor)
77-
z = self.q_sample(x, noise_level)
81+
z = self.q_sample(x, noise_level, seed=seed)
7882
return z, noise_level
7983

8084

0 commit comments

Comments
 (0)