Skip to content

Commit 532ff39

Browse files
authored
Update train.py
1 parent 1e7b9b7 commit 532ff39

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

train.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,12 @@
177177
lr = Freeze_lr
178178
start_epoch = Init_Epoch
179179
end_epoch = Freeze_Epoch
180+
181+
epoch_step = num_train // batch_size
182+
epoch_step_val = num_val // batch_size
183+
184+
if epoch_step == 0 or epoch_step_val == 0:
185+
raise ValueError("数据集过小,无法进行训练,请扩充数据集。")
180186

181187
optimizer = optim.Adam(model_train.parameters(), lr, weight_decay = 5e-4)
182188
if Cosine_lr:
@@ -190,12 +196,6 @@
190196
drop_last=True, collate_fn=yolo_dataset_collate)
191197
gen_val = DataLoader(val_dataset , shuffle = True, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
192198
drop_last=True, collate_fn=yolo_dataset_collate)
193-
194-
epoch_step = num_train // batch_size
195-
epoch_step_val = num_val // batch_size
196-
197-
if epoch_step == 0 or epoch_step_val == 0:
198-
raise ValueError("数据集过小,无法进行训练,请扩充数据集。")
199199

200200
#------------------------------------#
201201
# 冻结一定部分训练
@@ -214,6 +214,12 @@
214214
lr = Unfreeze_lr
215215
start_epoch = Freeze_Epoch
216216
end_epoch = UnFreeze_Epoch
217+
218+
epoch_step = num_train // batch_size
219+
epoch_step_val = num_val // batch_size
220+
221+
if epoch_step == 0 or epoch_step_val == 0:
222+
raise ValueError("数据集过小,无法进行训练,请扩充数据集。")
217223

218224
optimizer = optim.Adam(model_train.parameters(), lr, weight_decay = 5e-4)
219225
if Cosine_lr:
@@ -227,12 +233,6 @@
227233
drop_last=True, collate_fn=yolo_dataset_collate)
228234
gen_val = DataLoader(val_dataset , shuffle = True, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
229235
drop_last=True, collate_fn=yolo_dataset_collate)
230-
231-
epoch_step = num_train // batch_size
232-
epoch_step_val = num_val // batch_size
233-
234-
if epoch_step == 0 or epoch_step_val == 0:
235-
raise ValueError("数据集过小,无法进行训练,请扩充数据集。")
236236

237237
#------------------------------------#
238238
# 冻结一定部分训练

0 commit comments

Comments
 (0)