@@ -575,9 +575,8 @@ def __init__(self, output_lay: int, dim: int, indim: int, outdim: int, kernel_s
575575 conv_drop : float = 0.1 ,
576576 ffn_latent_drop : float = 0.1 ,
577577 ffn_out_drop : float = 0.1 , attention_drop : float = 0.1 , attention_heads : int = 4 ,
578- attention_heads_dim : int = 64 ,sig : bool = True , unet_type = 'cf_unet_full' ,unet_down = [2 , 2 , 2 ], unet_dim = [512 , 768 , 1024 ], unet_latentdim = 1024 ,):
578+ attention_heads_dim : int = 64 , unet_type = 'cf_unet_full' ,unet_down = [2 , 2 , 2 ], unet_dim = [512 , 768 , 1024 ], unet_latentdim = 1024 ,):
579579 super ().__init__ ()
580- self .sig = sig
581580
582581 self .unet = unet_adp (unet_type = unet_type , unet_down = unet_down , unet_dim = unet_dim , unet_latentdim = unet_latentdim ,
583582 unet_indim = indim , unet_outdim = dim ,
@@ -620,12 +619,9 @@ def forward(self, x, pitch=None, mask=None):
620619 midiout = self .outln (xo )
621620 cutprp = torch .sigmoid (cutprp )
622621 cutprp = torch .squeeze (cutprp , - 1 )
623- # if self.sig:
624- # midiout = torch.sigmoid(midiout)
625622 return midiout , cutprp
626623
627624
628-
if __name__ == '__main__':
    # Smoke test: build a small unet_base_cf and push one random batch
    # through it (batch=2, seq_len=255, feature dim=128 matching indim).
    model = unet_base_cf(
        dim=512,
        indim=128,
        outdim=256,
        output_lay=1,
        unet_down=[2, 2, 4],
        unet_dim=[128, 128, 128],
        unet_latentdim=128,
    )
    outputs = model(torch.randn(2, 255, 128))
0 commit comments