@@ -53,17 +53,17 @@ def compact_values(d: dict):
5353# extract dimensions using hooks
5454
5555@beartype
56- def extract_output_shapes (
56+ def extract_forward_hook_outputs (
5757 modules : List [Module ],
5858 model : Module ,
5959 model_input ,
6060 model_kwargs : dict = dict ()
6161):
62- shapes = []
6362 hooks = []
63+ all_hook_args = []
6464
65- def hook_fn (_ , input , output ):
66- return shapes .append (output . shape )
65+ def hook_fn (* hook_args ):
66+ return all_hook_args .append (hook_args )
6767
6868 for module in modules :
6969 hook = module .register_forward_hook (hook_fn )
@@ -75,7 +75,7 @@ def hook_fn(_, input, output):
7575 for hook in hooks :
7676 hook .remove ()
7777
78- return shapes
78+ return all_hook_args
7979
8080# freezing text-to-image, and only learning temporal parameters
8181
@@ -482,12 +482,24 @@ def __init__(
482482 mock_time = torch .ones ((1 ,))
483483 unet_kwarg = {unet_time_kwarg : mock_time }
484484
485- # get all dimensions
485+ # extract all hook outputs
486486
487- conv_shapes = extract_output_shapes (conv_modules , self .model , mock_images , unet_kwarg )
488- attn_shapes = extract_output_shapes (attn_modules , self .model , mock_images , unet_kwarg )
489- downsample_shapes = extract_output_shapes (downsample_modules , self .model , mock_images , unet_kwarg )
490- upsample_shapes = extract_output_shapes (upsample_modules , self .model , mock_images , unet_kwarg )
487+ conv_hook_args = extract_forward_hook_outputs (conv_modules , self .model , mock_images , unet_kwarg )
488+ attn_hook_args = extract_forward_hook_outputs (attn_modules , self .model , mock_images , unet_kwarg )
489+ downsample_hook_args = extract_forward_hook_outputs (downsample_modules , self .model , mock_images , unet_kwarg )
490+ upsample_hook_args = extract_forward_hook_outputs (upsample_modules , self .model , mock_images , unet_kwarg )
491+
492+ # reorder all modules by execution order and also extract output shape
493+
494+ conv_modules , _ , conv_outputs = zip (* conv_hook_args )
495+ attn_modules , _ , attn_outputs = zip (* attn_hook_args )
496+ downsample_modules , _ , downsample_outputs = zip (* downsample_hook_args )
497+ upsample_modules , _ , upsample_outputs = zip (* upsample_hook_args )
498+
499+ conv_shapes = [t .shape for t in conv_outputs ]
500+ attn_shapes = [t .shape for t in attn_outputs ]
501+ downsample_shapes = [t .shape for t in downsample_outputs ]
502+ upsample_shapes = [t .shape for t in upsample_outputs ]
491503
492504 # temporal klasses - for setting temporal dimension on forward
493505
0 commit comments