Skip to content

Commit 649d79a

Browse files
authored
Merge branch 'master' into hn-activation
2 parents 877d94f + 757264c commit 649d79a

File tree

11 files changed

+876
-32
lines changed

11 files changed

+876
-32
lines changed

localizations/fr-FR.json

Lines changed: 415 additions & 0 deletions
Large diffs are not rendered by default.

modules/api/api.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from fastapi import Body, APIRouter, HTTPException
88
from fastapi.responses import JSONResponse
99
from pydantic import BaseModel, Field, Json
10+
from typing import List
1011
import json
1112
import io
1213
import base64
@@ -15,12 +16,12 @@
1516
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
1617

1718
class TextToImageResponse(BaseModel):
18-
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
19+
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
1920
parameters: Json
2021
info: Json
2122

2223
class ImageToImageResponse(BaseModel):
23-
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
24+
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
2425
parameters: Json
2526
info: Json
2627

@@ -65,7 +66,7 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
6566
i.save(buffer, format="png")
6667
b64images.append(base64.b64encode(buffer.getvalue()))
6768

68-
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
69+
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
6970

7071

7172

@@ -111,7 +112,11 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
111112
i.save(buffer, format="png")
112113
b64images.append(base64.b64encode(buffer.getvalue()))
113114

114-
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
115+
if (not img2imgreq.include_init_images):
116+
img2imgreq.init_images = None
117+
img2imgreq.mask = None
118+
119+
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
115120

116121
def extrasapi(self):
117122
raise NotImplementedError

modules/api/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class ModelDef(BaseModel):
3131
field_alias: str
3232
field_type: Any
3333
field_value: Any
34+
field_exclude: bool = False
3435

3536

3637
class PydanticModelGenerator:
@@ -78,15 +79,16 @@ def merge_class_params(class_):
7879
field=underscore(fields["key"]),
7980
field_alias=fields["key"],
8081
field_type=fields["type"],
81-
field_value=fields["default"]))
82+
field_value=fields["default"],
83+
field_exclude=fields["exclude"] if "exclude" in fields else False))
8284

8385
def generate_model(self):
8486
"""
8587
Creates a pydantic BaseModel
8688
from the json and overrides provided at initialization
8789
"""
8890
fields = {
89-
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
91+
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
9092
}
9193
DynamicModel = create_model(self._model_name, **fields)
9294
DynamicModel.__config__.allow_population_by_field_name = True
@@ -102,5 +104,5 @@ def generate_model(self):
102104
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
103105
"StableDiffusionProcessingImg2Img",
104106
StableDiffusionProcessingImg2Img,
105-
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
107+
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
106108
).generate_model()

modules/hypernetworks/hypernetwork.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import sys
77
import traceback
8+
import inspect
89

910
import modules.textual_inversion.dataset
1011
import torch
@@ -15,20 +16,25 @@
1516
from modules.textual_inversion import textual_inversion
1617
from modules.textual_inversion.learn_schedule import LearnRateScheduler
1718
from torch import einsum
19+
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
1820

1921
from collections import defaultdict, deque
2022
from statistics import stdev, mean
2123

24+
2225
class HypernetworkModule(torch.nn.Module):
2326
multiplier = 1.0
2427
activation_dict = {
2528
"relu": torch.nn.ReLU,
2629
"leakyrelu": torch.nn.LeakyReLU,
2730
"elu": torch.nn.ELU,
2831
"swish": torch.nn.Hardswish,
32+
"tanh": torch.nn.Tanh,
33+
"sigmoid": torch.nn.Sigmoid,
2934
}
35+
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
3036

31-
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
37+
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
3238
super().__init__()
3339

3440
assert layer_structure is not None, "layer_structure must not be None"
@@ -65,9 +71,24 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
6571
else:
6672
for layer in self.linear:
6773
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
68-
layer.weight.data.normal_(mean=0.0, std=0.01)
69-
layer.bias.data.zero_()
70-
74+
w, b = layer.weight.data, layer.bias.data
75+
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
76+
normal_(w, mean=0.0, std=0.01)
77+
normal_(b, mean=0.0, std=0.005)
78+
elif weight_init == 'XavierUniform':
79+
xavier_uniform_(w)
80+
zeros_(b)
81+
elif weight_init == 'XavierNormal':
82+
xavier_normal_(w)
83+
zeros_(b)
84+
elif weight_init == 'KaimingUniform':
85+
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
86+
zeros_(b)
87+
elif weight_init == 'KaimingNormal':
88+
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
89+
zeros_(b)
90+
else:
91+
raise KeyError(f"Key {weight_init} is not defined as initialization!")
7192
self.to(devices.device)
7293

7394
def fix_old_state_dict(self, state_dict):
@@ -105,7 +126,7 @@ class Hypernetwork:
105126
filename = None
106127
name = None
107128

108-
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
129+
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
109130
self.filename = None
110131
self.name = name
111132
self.layers = {}
@@ -114,14 +135,15 @@ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activatio
114135
self.sd_checkpoint_name = None
115136
self.layer_structure = layer_structure
116137
self.activation_func = activation_func
138+
self.weight_init = weight_init
117139
self.add_layer_norm = add_layer_norm
118140
self.use_dropout = use_dropout
119141
self.activate_output = activate_output
120142

121143
for size in enable_sizes or []:
122144
self.layers[size] = (
123-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
124-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
145+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
146+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
125147
)
126148

127149
def weights(self):
@@ -145,6 +167,7 @@ def save(self, filename):
145167
state_dict['layer_structure'] = self.layer_structure
146168
state_dict['activation_func'] = self.activation_func
147169
state_dict['is_layer_norm'] = self.add_layer_norm
170+
state_dict['weight_initialization'] = self.weight_init
148171
state_dict['use_dropout'] = self.use_dropout
149172
state_dict['sd_checkpoint'] = self.sd_checkpoint
150173
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -160,16 +183,22 @@ def load(self, filename):
160183
state_dict = torch.load(filename, map_location='cpu')
161184

162185
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
186+
print(self.layer_structure)
163187
self.activation_func = state_dict.get('activation_func', None)
188+
print(f"Activation function is {self.activation_func}")
189+
self.weight_init = state_dict.get('weight_initialization', 'Normal')
190+
print(f"Weight initialization is {self.weight_init}")
164191
self.add_layer_norm = state_dict.get('is_layer_norm', False)
165-
self.use_dropout = state_dict.get('use_dropout', False)
192+
print(f"Layer norm is set to {self.add_layer_norm}")
193+
self.use_dropout = state_dict.get('use_dropout', False)
194+
print(f"Dropout usage is set to {self.use_dropout}" )
166195
self.activate_output = state_dict.get('activate_output', True)
167196

168197
for size, sd in state_dict.items():
169198
if type(size) == int:
170199
self.layers[size] = (
171-
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
172-
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
200+
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
201+
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
173202
)
174203

175204
self.name = state_dict.get('name', self.name)

modules/hypernetworks/ui.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from modules import devices, sd_hijack, shared
99
from modules.hypernetworks import hypernetwork
1010

11+
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
1112

12-
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
13+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
1314
# Remove illegal characters from name.
1415
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
1516

@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
2526
enable_sizes=[int(x) for x in enable_sizes],
2627
layer_structure=layer_structure,
2728
activation_func=activation_func,
29+
weight_init=weight_init,
2830
add_layer_norm=add_layer_norm,
2931
use_dropout=use_dropout,
3032
)

modules/images.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def resize(im, w, h):
277277
invalid_filename_prefix = ' '
278278
invalid_filename_postfix = ' .'
279279
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
280-
re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
280+
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
281281
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
282282
max_filename_part_length = 128
283283

@@ -343,7 +343,7 @@ def prompt_words(self):
343343
def datetime(self, *args):
344344
time_datetime = datetime.datetime.now()
345345

346-
time_format = args[0] if len(args) > 0 else self.default_time_format
346+
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
347347
try:
348348
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
349349
except pytz.exceptions.UnknownTimeZoneError as _:
@@ -362,9 +362,9 @@ def apply(self, x):
362362

363363
for m in re_pattern.finditer(x):
364364
text, pattern = m.groups()
365+
res += text
365366

366367
if pattern is None:
367-
res += text
368368
continue
369369

370370
pattern_args = []
@@ -385,12 +385,9 @@ def apply(self, x):
385385
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
386386
print(traceback.format_exc(), file=sys.stderr)
387387

388-
if replacement is None:
389-
res += f'[{pattern}]'
390-
else:
388+
if replacement is not None:
391389
res += str(replacement)
392-
393-
continue
390+
continue
394391

395392
res += f'[{pattern}]'
396393

0 commit comments

Comments
 (0)