Skip to content

Commit 1377ba2

Browse files
committed
Clean configs and arguments
1 parent 88badfa commit 1377ba2

File tree

7 files changed

+3
-16
lines changed

7 files changed

+3
-16
lines changed

configs/base.yaml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@ units_dim: 80 # 768
2828
midi_num_bins: 128
2929
model_cls: null
3030
midi_extractor_args: {}
31-
use_BCEWithLogitsLoss: false
3231

3332
# training
3433
use_midi_loss: true

configs/discrete.yaml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,6 @@ midi_extractor_args:
3838
attention_drop: 0.1
3939
attention_heads: 8
4040
attention_heads_dim: 64
41-
sig: false
4241

4342
# training
4443
task_cls: training.QuantizedMIDIExtractionTask

configs/midi_conformer.yaml

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@ midi_shift_range: [-12, 12]
1919
rest_threshold: 0.1
2020

2121

22-
use_BCEWithLogitsLoss: true
2322
midi_extractor_args:
2423
lay: 8
2524
dim: 512
@@ -32,6 +31,5 @@ midi_extractor_args:
3231
attention_drop: 0.1
3332
attention_heads: 8
3433
attention_heads_dim: 64
35-
sig: false
3634

3735
pl_trainer_precision: 'bf16'
Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@ midi_extractor_args:
1818
attention_drop: 0.1
1919
attention_heads: 8
2020
attention_heads_dim: 64
21-
sig: false
2221

2322
# training
2423
task_cls: training.QuantizedMIDIExtractionTask

configs/unet.yaml

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@ midi_shift_proportion: 0.0
1919
midi_shift_range: [-12, 12]
2020
rest_threshold: 0.1
2121

22-
use_BCEWithLogitsLoss: true
2322
midi_extractor_args:
2423
output_lay: 3
2524
dim: 512
@@ -34,6 +33,5 @@ midi_extractor_args:
3433
unet_down: [2, 2, 2,2,2]
3534
unet_dim: [512, 512, 768,768,1024]
3635
unet_latentdim: 1024
37-
sig: false
3836

3937
pl_trainer_precision: 'bf16'

modules/conform/Gconform.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -94,9 +94,8 @@ def __init__(self, lay: int, dim: int, indim: int, outdim: int, use_lay_skip: bo
9494
conv_drop: float = 0.1,
9595
ffn_latent_drop: float = 0.1,
9696
ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4,
97-
attention_heads_dim: int = 64,sig:bool=True):
97+
attention_heads_dim: int = 64):
9898
super().__init__()
99-
self.sig=sig
10099

101100
self.inln = nn.Linear(indim, dim)
102101
self.inln1 = nn.Linear(indim, dim)
@@ -137,6 +136,5 @@ def forward(self, x, pitch, mask=None):
137136
midiout = self.outln(x)
138137
cutprp = torch.sigmoid(cutprp)
139138
cutprp = torch.squeeze(cutprp, -1)
140-
# if self.sig:
141-
# midiout = torch.sigmoid(midiout)
139+
142140
return midiout, cutprp

modules/conform/unet_with_conform.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -575,9 +575,8 @@ def __init__(self, output_lay: int, dim: int, indim: int, outdim: int, kernel_s
575575
conv_drop: float = 0.1,
576576
ffn_latent_drop: float = 0.1,
577577
ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4,
578-
attention_heads_dim: int = 64,sig:bool=True,unet_type='cf_unet_full',unet_down=[2, 2, 2], unet_dim=[512, 768, 1024], unet_latentdim=1024,):
578+
attention_heads_dim: int = 64, unet_type='cf_unet_full',unet_down=[2, 2, 2], unet_dim=[512, 768, 1024], unet_latentdim=1024,):
579579
super().__init__()
580-
self.sig=sig
581580

582581
self.unet=unet_adp(unet_type=unet_type, unet_down=unet_down, unet_dim=unet_dim, unet_latentdim=unet_latentdim,
583582
unet_indim=indim, unet_outdim=dim,
@@ -620,12 +619,9 @@ def forward(self, x, pitch=None, mask=None):
620619
midiout = self.outln(xo)
621620
cutprp = torch.sigmoid(cutprp)
622621
cutprp = torch.squeeze(cutprp, -1)
623-
# if self.sig:
624-
# midiout = torch.sigmoid(midiout)
625622
return midiout, cutprp
626623

627624

628-
629625
if __name__ == '__main__':
630626
fff = unet_base_cf( dim=512,indim=128,outdim=256,output_lay=1,unet_down=[2, 2, 4], unet_dim=[128, 128, 128], unet_latentdim=128)
631627
aaa = fff(torch.randn(2, 255, 128))

0 commit comments

Comments (0)