Skip to content

Commit 4de2b60

Browse files
committed
specify hilbert by arch name
1 parent cc4ce91 commit 4de2b60

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

training.py

Lines changed: 28 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,17 @@ def boolean_string(s):
110110
choices=['cosine', 'karras', 'edm'], help='Noise schedule')
111111

112112
parser.add_argument('--architecture', type=str,
113-
choices=["unet", "uvit", "diffusers_unet_simple", "simple_dit", "simple_mmdit", "hierarchical_mmdit"],
113+
choices=[
114+
"unet",
115+
"uvit",
116+
"diffusers_unet_simple",
117+
"simple_dit",
118+
"simple_mmdit",
119+
"hierarchical_mmdit",
120+
"simple_dit-hilbert",
121+
"simple_mmdit-hilbert",
122+
"hierarchical_mmdit-hilbert",
123+
],
114124
default="unet", help='Architecture to use')
115125
parser.add_argument('--emb_features', type=int, default=256, help='Embedding features')
116126
parser.add_argument('--feature_depths', type=int, nargs='+', default=[64, 128, 256, 512], help='Feature depths')
@@ -304,7 +314,14 @@ def main(args):
304314
INPUT_CHANNELS = 4
305315
DIFFUSION_INPUT_SIZE = DIFFUSION_INPUT_SIZE // 8
306316

307-
if 'diffusers' in args.architecture:
317+
use_hilbert = args.use_hilbert
318+
architecture_name = args.architecture
319+
if 'hilbert' in architecture_name:
320+
architecture_name = architecture_name.split('-')[0]
321+
print("Will use Hilbert Patch Reordering")
322+
use_hilbert = True
323+
324+
if 'diffusers' in architecture_name:
308325
model_config = {}
309326
else:
310327
model_config = {
@@ -338,7 +355,7 @@ def main(args):
338355
"add_residualblock_output": args.add_residualblock_output,
339356
"use_flash_attention": args.flash_attention,
340357
"use_self_and_cross": args.use_self_and_cross,
341-
"use_hilbert": args.use_hilbert,
358+
"use_hilbert": use_hilbert,
342359
},
343360
},
344361
"simple_dit": {
@@ -350,7 +367,7 @@ def main(args):
350367
"dropout_rate": 0.1,
351368
"use_flash_attention": args.flash_attention,
352369
"mlp_ratio": args.mlp_ratio,
353-
"use_hilbert": args.use_hilbert,
370+
"use_hilbert": use_hilbert,
354371
},
355372
},
356373
"simple_mmdit": {
@@ -362,7 +379,7 @@ def main(args):
362379
"dropout_rate": 0.1,
363380
"use_flash_attention": args.flash_attention,
364381
"mlp_ratio": args.mlp_ratio,
365-
"use_hilbert": args.use_hilbert,
382+
"use_hilbert": use_hilbert,
366383
},
367384
},
368385
"hierarchical_mmdit": {
@@ -375,7 +392,7 @@ def main(args):
375392
"dropout_rate": 0.1,
376393
"use_flash_attention": args.flash_attention,
377394
"mlp_ratio": args.mlp_ratio,
378-
"use_hilbert": args.use_hilbert,
395+
"use_hilbert": use_hilbert,
379396
},
380397
},
381398
"diffusers_unet_simple": {
@@ -395,10 +412,10 @@ def main(args):
395412
}
396413
}
397414

398-
model_architecture = MODEL_ARCHITECUTRES[args.architecture]['class']
399-
model_config.update(MODEL_ARCHITECUTRES[args.architecture]['kwargs'])
415+
model_architecture = MODEL_ARCHITECUTRES[architecture_name]['class']
416+
model_config.update(MODEL_ARCHITECUTRES[architecture_name]['kwargs'])
400417

401-
if args.architecture == 'uvit':
418+
if architecture_name == 'uvit':
402419
model_config['emb_features'] = 768
403420

404421
sorted_args_json = json.dumps(vars(args), sort_keys=True)
@@ -430,7 +447,7 @@ def main(args):
430447

431448
CONFIG = {
432449
"model": model_config,
433-
"architecture": args.architecture,
450+
"architecture": architecture_name,
434451
"dataset": {
435452
"name": dataset_name,
436453
"length": datalen,
@@ -486,7 +503,7 @@ def main(args):
486503
model = model_architecture(**model_config)
487504

488505
# If using the Diffusers UNet, we need to wrap it
489-
if 'diffusers' in args.architecture:
506+
if 'diffusers' in architecture_name:
490507
from flaxdiff.models.general import BCHWModelWrapper
491508
model = BCHWModelWrapper(model)
492509

0 commit comments

Comments
 (0)