From dd4c4f1611c97b97b103106623416a53732dfa5b Mon Sep 17 00:00:00 2001 From: amm <1490435889@qq.com> Date: Thu, 12 Oct 2023 22:22:47 +0800 Subject: [PATCH] small convenience for setting fewer channels --- denoising_diffusion_pytorch/denoising_diffusion_pytorch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py index e0b6aee5c..1a41c5c02 100644 --- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py @@ -302,7 +302,9 @@ def __init__( dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - + + assert dim > 3 and dim % 2 == 0, 'in this version, the number of channels must be even and greater than 3' + resnet_block_groups = min(resnet_block_groups, dim) block_klass = partial(ResnetBlock, groups = resnet_block_groups) # time embeddings