Skip to content

Commit 4af7157

Browse files
feat: option to use sequential patching
1 parent 537764c commit 4af7157

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,7 @@ def __init__(
842842
context_channels: Optional[Sequence[int]] = None,
843843
context_embedding_features: Optional[int] = None,
844844
use_post_out_block: bool = False,
845+
use_sequential_patching: bool = False,
845846
):
846847
super().__init__()
847848

@@ -873,8 +874,10 @@ def __init__(
873874
and len(num_blocks) == num_layers
874875
)
875876

877+
patching = "b c (p l)" if use_sequential_patching else "b c (l p)"
878+
876879
self.to_in = nn.Sequential(
877-
Rearrange("b c (l p) -> b (c p) l", p=patch_size),
880+
Rearrange(f"{patching} -> b (c p) l", p=patch_size),
878881
CrossEmbed1d(
879882
in_channels=(in_channels + context_channels[0]) * patch_size,
880883
out_channels=channels,
@@ -981,7 +984,7 @@ def __init__(
981984
out_channels=out_channels * patch_size,
982985
kernel_size=1,
983986
),
984-
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
987+
Rearrange(f"b (c p) l -> {patching}", p=patch_size),
985988
)
986989

987990
if self.use_post_out_block:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.38",
6+
version="0.0.39",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)