File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -842,6 +842,7 @@ def __init__(
842842 context_channels : Optional [Sequence [int ]] = None ,
843843 context_embedding_features : Optional [int ] = None ,
844844 use_post_out_block : bool = False ,
845+ use_sequential_patching : bool = False ,
845846 ):
846847 super ().__init__ ()
847848
@@ -873,8 +874,10 @@ def __init__(
873874 and len (num_blocks ) == num_layers
874875 )
875876
877+ patching = "b c (p l)" if use_sequential_patching else "b c (l p)"
878+
876879 self .to_in = nn .Sequential (
877- Rearrange ("b c (l p) -> b (c p) l" , p = patch_size ),
880+ Rearrange (f" { patching } -> b (c p) l" , p = patch_size ),
878881 CrossEmbed1d (
879882 in_channels = (in_channels + context_channels [0 ]) * patch_size ,
880883 out_channels = channels ,
@@ -981,7 +984,7 @@ def __init__(
981984 out_channels = out_channels * patch_size ,
982985 kernel_size = 1 ,
983986 ),
984- Rearrange ("b (c p) l -> b c (l p) " , p = patch_size ),
987+ Rearrange (f "b (c p) l -> { patching } " , p = patch_size ),
985988 )
986989
987990 if self .use_post_out_block :
Original file line number Diff line number Diff line change 33setup (
44 name = "audio-diffusion-pytorch" ,
55 packages = find_packages (exclude = []),
6- version = "0.0.38 " ,
6+ version = "0.0.39 " ,
77 license = "MIT" ,
88 description = "Audio Diffusion - PyTorch" ,
99 long_description_content_type = "text/markdown" ,
You can’t perform that action at this time.
0 commit comments