Skip to content

Commit 10eb02a

Browse files
committed
feat: training of hierarchical mmdit
1 parent 8d1578e commit 10eb02a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

training.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -369,9 +369,9 @@ def main(args):
369369
"class": HierarchicalMMDiT,
370370
"kwargs": {
371371
"base_patch_size": args.patch_size // 2, # Use half the patch size for base
372-
"emb_features": (512, 768, 1024), # Default dims per stage
373-
"num_layers": (4, 6, 12), # Default layers per stage
374-
"num_heads": (8, 12, 16), # Default heads per stage
372+
"emb_features": (args.emb_features - 256, args.emb_features, args.emb_features + 256), # Default dims per stage
373+
"num_layers": (args.num_layers // 3, args.num_layers // 2, args.num_layers), # Default layers per stage
374+
"num_heads": (args.num_heads - 2, args.num_heads, args.num_heads + 2), # Default heads per stage
375375
"dropout_rate": 0.1,
376376
"use_flash_attention": args.flash_attention,
377377
"mlp_ratio": args.mlp_ratio,

0 commit comments

Comments
 (0)