Skip to content

Commit fd0b101

Browse files
feat: set default parameters to unet
1 parent 5fb0939 commit fd0b101

File tree

3 files changed, with 8 additions and 28 deletions

audio_diffusion_pytorch/model.py

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -350,7 +350,6 @@ def forward(self, x: Tensor, **kwargs) -> Tensor:
350 350   def get_default_model_kwargs():
351 351   return dict(
352 352   channels=128,
353     - patch_blocks=1,
354 353   patch_factor=16,
355 354   multipliers=[1, 2, 4, 4, 4, 4, 4],
356 355   factors=[4, 4, 4, 2, 2, 2],
@@ -360,11 +359,6 @@ def get_default_model_kwargs():
360 359   attention_features=64,
361 360   attention_multiplier=2,
362 361   attention_use_rel_pos=False,
363     - resnet_groups=8,
364     - kernel_multiplier_downsample=2,
365     - use_nearest_upsample=False,
366     - use_skip_scale=True,
367     - use_context_time=True,
368 362   diffusion_type="v",
369 363   diffusion_sigma_distribution=UniformDistribution(),
370 364   )
@@ -416,13 +410,6 @@ def decode(self, *args, **kwargs):
416 410   class AudioDiffusionMAE(DiffusionMAE1d):
417 411   def __init__(self, *args, **kwargs):
418 412   default_kwargs = dict(
419     - patch_blocks=1,
420     - patch_factor=1,
421     - resnet_groups=8,
422     - kernel_multiplier_downsample=2,
423     - use_nearest_upsample=False,
424     - use_skip_scale=True,
425     - use_context_time=True,
426 413   diffusion_type="v",
427 414   diffusion_sigma_distribution=UniformDistribution(),
428 415   stft_num_fft=1023,
@@ -470,8 +457,6 @@ def __init__(self, in_channels: int, **kwargs):
470 457   stft_num_fft=1023,
471 458   stft_hop_length=256,
472 459   channels=512,
473     - patch_blocks=1,
474     - patch_factor=1,
475 460   multipliers=[3, 2, 1, 1, 1, 1, 1, 1],
476 461   factors=[1, 2, 2, 2, 2, 2, 2],
477 462   num_blocks=[1, 1, 1, 1, 1, 1, 1],
@@ -480,11 +465,6 @@ def __init__(self, in_channels: int, **kwargs):
480 465   attention_features=64,
481 466   attention_multiplier=2,
482 467   attention_use_rel_pos=False,
483     - resnet_groups=8,
484     - kernel_multiplier_downsample=2,
485     - use_nearest_upsample=False,
486     - use_skip_scale=True,
487     - use_context_time=True,
488 468   diffusion_type="v",
489 469   diffusion_sigma_distribution=UniformDistribution(),
490 470   )

audio_diffusion_pytorch/modules.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -923,17 +923,17 @@ def __init__(
923 923   self,
924 924   in_channels: int,
925 925   channels: int,
926     - patch_blocks: int,
927     - patch_factor: int,
928 926   multipliers: Sequence[int],
929 927   factors: Sequence[int],
930 928   num_blocks: Sequence[int],
931 929   attentions: Sequence[int],
932     - resnet_groups: int,
933     - kernel_multiplier_downsample: int,
934     - use_nearest_upsample: bool,
935     - use_skip_scale: bool,
936     - use_context_time: bool,
    930 + patch_blocks: int = 1,
    931 + patch_factor: int = 1,
    932 + resnet_groups: int = 8,
    933 + use_context_time: bool = True,
    934 + kernel_multiplier_downsample: int = 2,
    935 + use_nearest_upsample: bool = False,
    936 + use_skip_scale: bool = True,
937 937   use_stft: bool = False,
938 938   use_stft_context: bool = False,
939 939   out_channels: Optional[int] = None,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
3 3   setup(
4 4   name="audio-diffusion-pytorch",
5 5   packages=find_packages(exclude=[]),
6   - version="0.0.88",
  6 + version="0.0.89",
7 7   license="MIT",
8 8   description="Audio Diffusion - PyTorch",
9 9   long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)