Skip to content

Commit e1df0e3

Browse files
authored
Update yolo.py
1 parent ac2809e commit e1df0e3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

nets/yolo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def yolo_head(filters_list, in_filters):
8484
# yolo_body
8585
#---------------------------------------------------#
8686
class YoloBody(nn.Module):
87-
def __init__(self, anchors_mask, num_classes):
87+
def __init__(self, anchors_mask, num_classes, pretrained = False):
8888
super(YoloBody, self).__init__()
8989
#---------------------------------------------------#
9090
# 生成CSPdarknet53的主干模型
@@ -93,7 +93,7 @@ def __init__(self, anchors_mask, num_classes):
9393
# 26,26,512
9494
# 13,13,1024
9595
#---------------------------------------------------#
96-
self.backbone = darknet53(None)
96+
self.backbone = darknet53(pretrained)
9797

9898
self.conv1 = make_three_conv([512,1024],1024)
9999
self.SPP = SpatialPyramidPooling()

0 commit comments

Comments
 (0)