@@ -110,7 +110,17 @@ def boolean_string(s):
110110 choices = ['cosine' , 'karras' , 'edm' ], help = 'Noise schedule' )
111111
112112parser .add_argument ('--architecture' , type = str ,
113- choices = ["unet" , "uvit" , "diffusers_unet_simple" , "simple_dit" , "simple_mmdit" , "hierarchical_mmdit" ],
113+ choices = [
114+ "unet" ,
115+ "uvit" ,
116+ "diffusers_unet_simple" ,
117+ "simple_dit" ,
118+ "simple_mmdit" ,
119+ "hierarchical_mmdit" ,
120+ "simple_dit-hilbert" ,
121+ "simple_mmdit-hilbert" ,
122+ "hierarchical_mmdit-hilbert" ,
123+ ],
114124 default = "unet" , help = 'Architecture to use' )
115125parser .add_argument ('--emb_features' , type = int , default = 256 , help = 'Embedding features' )
116126parser .add_argument ('--feature_depths' , type = int , nargs = '+' , default = [64 , 128 , 256 , 512 ], help = 'Feature depths' )
@@ -304,7 +314,14 @@ def main(args):
304314 INPUT_CHANNELS = 4
305315 DIFFUSION_INPUT_SIZE = DIFFUSION_INPUT_SIZE // 8
306316
307- if 'diffusers' in args .architecture :
317+ use_hilbert = args .use_hilbert
318+ architecture_name = args .architecture
319+ if 'hilbert' in architecture_name :
320+ architecture_name = architecture_name .split ('-' )[0 ]
321+ print ("Will use Hilbert Patch Reordering" )
322+ use_hilbert = True
323+
324+ if 'diffusers' in architecture_name :
308325 model_config = {}
309326 else :
310327 model_config = {
@@ -338,7 +355,7 @@ def main(args):
338355 "add_residualblock_output" : args .add_residualblock_output ,
339356 "use_flash_attention" : args .flash_attention ,
340357 "use_self_and_cross" : args .use_self_and_cross ,
341- "use_hilbert" : args . use_hilbert ,
358+ "use_hilbert" : use_hilbert ,
342359 },
343360 },
344361 "simple_dit" : {
@@ -350,7 +367,7 @@ def main(args):
350367 "dropout_rate" : 0.1 ,
351368 "use_flash_attention" : args .flash_attention ,
352369 "mlp_ratio" : args .mlp_ratio ,
353- "use_hilbert" : args . use_hilbert ,
370+ "use_hilbert" : use_hilbert ,
354371 },
355372 },
356373 "simple_mmdit" : {
@@ -362,7 +379,7 @@ def main(args):
362379 "dropout_rate" : 0.1 ,
363380 "use_flash_attention" : args .flash_attention ,
364381 "mlp_ratio" : args .mlp_ratio ,
365- "use_hilbert" : args . use_hilbert ,
382+ "use_hilbert" : use_hilbert ,
366383 },
367384 },
368385 "hierarchical_mmdit" : {
@@ -375,7 +392,7 @@ def main(args):
375392 "dropout_rate" : 0.1 ,
376393 "use_flash_attention" : args .flash_attention ,
377394 "mlp_ratio" : args .mlp_ratio ,
378- "use_hilbert" : args . use_hilbert ,
395+ "use_hilbert" : use_hilbert ,
379396 },
380397 },
381398 "diffusers_unet_simple" : {
@@ -395,10 +412,10 @@ def main(args):
395412 }
396413 }
397414
398- model_architecture = MODEL_ARCHITECUTRES [args . architecture ]['class' ]
399- model_config .update (MODEL_ARCHITECUTRES [args . architecture ]['kwargs' ])
415+ model_architecture = MODEL_ARCHITECUTRES [architecture_name ]['class' ]
416+ model_config .update (MODEL_ARCHITECUTRES [architecture_name ]['kwargs' ])
400417
401- if args . architecture == 'uvit' :
418+ if architecture_name == 'uvit' :
402419 model_config ['emb_features' ] = 768
403420
404421 sorted_args_json = json .dumps (vars (args ), sort_keys = True )
@@ -430,7 +447,7 @@ def main(args):
430447
431448 CONFIG = {
432449 "model" : model_config ,
433- "architecture" : args . architecture ,
450+ "architecture" : architecture_name ,
434451 "dataset" : {
435452 "name" : dataset_name ,
436453 "length" : datalen ,
@@ -486,7 +503,7 @@ def main(args):
486503 model = model_architecture (** model_config )
487504
488505 # If using the Diffusers UNet, we need to wrap it
489- if 'diffusers' in args . architecture :
506+ if 'diffusers' in architecture_name :
490507 from flaxdiff .models .general import BCHWModelWrapper
491508 model = BCHWModelWrapper (model )
492509
0 commit comments