Skip to content

Commit 125fd4d

Browse files
committed
module names may not be given to Lumiere in the right order of execution. reorder based on the mock forward call on init
1 parent 1b26c9f commit 125fd4d

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

lumiere_pytorch/lumiere.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,17 @@ def compact_values(d: dict):
5353
# extract dimensions using hooks
5454

5555
@beartype
56-
def extract_output_shapes(
56+
def extract_forward_hook_outputs(
5757
modules: List[Module],
5858
model: Module,
5959
model_input,
6060
model_kwargs: dict = dict()
6161
):
62-
shapes = []
6362
hooks = []
63+
all_hook_args = []
6464

65-
def hook_fn(_, input, output):
66-
return shapes.append(output.shape)
65+
def hook_fn(*hook_args):
66+
return all_hook_args.append(hook_args)
6767

6868
for module in modules:
6969
hook = module.register_forward_hook(hook_fn)
@@ -75,7 +75,7 @@ def hook_fn(_, input, output):
7575
for hook in hooks:
7676
hook.remove()
7777

78-
return shapes
78+
return all_hook_args
7979

8080
# freezing text-to-image, and only learning temporal parameters
8181

@@ -482,12 +482,24 @@ def __init__(
482482
mock_time = torch.ones((1,))
483483
unet_kwarg = {unet_time_kwarg: mock_time}
484484

485-
# get all dimensions
485+
# extract all hook outputs
486486

487-
conv_shapes = extract_output_shapes(conv_modules, self.model, mock_images, unet_kwarg)
488-
attn_shapes = extract_output_shapes(attn_modules, self.model, mock_images, unet_kwarg)
489-
downsample_shapes = extract_output_shapes(downsample_modules, self.model, mock_images, unet_kwarg)
490-
upsample_shapes = extract_output_shapes(upsample_modules, self.model, mock_images, unet_kwarg)
487+
conv_hook_args = extract_forward_hook_outputs(conv_modules, self.model, mock_images, unet_kwarg)
488+
attn_hook_args = extract_forward_hook_outputs(attn_modules, self.model, mock_images, unet_kwarg)
489+
downsample_hook_args = extract_forward_hook_outputs(downsample_modules, self.model, mock_images, unet_kwarg)
490+
upsample_hook_args = extract_forward_hook_outputs(upsample_modules, self.model, mock_images, unet_kwarg)
491+
492+
# reorder all modules by execution order and also extract output shape
493+
494+
conv_modules, _, conv_outputs = zip(*conv_hook_args)
495+
attn_modules, _, attn_outputs = zip(*attn_hook_args)
496+
downsample_modules, _, downsample_outputs = zip(*downsample_hook_args)
497+
upsample_modules, _, upsample_outputs = zip(*upsample_hook_args)
498+
499+
conv_shapes = [t.shape for t in conv_outputs]
500+
attn_shapes = [t.shape for t in attn_outputs]
501+
downsample_shapes = [t.shape for t in downsample_outputs]
502+
upsample_shapes = [t.shape for t in upsample_outputs]
491503

492504
# temporal klasses - for setting temporal dimension on forward
493505

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lumiere-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.21',
6+
version = '0.0.22',
77
license='MIT',
88
description = 'Lumiere',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)