Skip to content

Commit 397e5f4

Browse files
authored
feat: Add train and predict with yolo (#20)
#17
1 parent 04c2714 commit 397e5f4

File tree

11 files changed

+338
-0
lines changed

11 files changed

+338
-0
lines changed

utils/coco2yolo.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import os
2+
import json
3+
4+
def main(json_path='/data/ephemeral/home/dataset/train.json',
         save_path='/data/ephemeral/home/dataset/yolo/labels'):
    """Convert COCO-style box annotations into per-image YOLO label files.

    For every annotation in ``json_path`` the ``bbox`` value
    ``[xmin, ymin, width, height]`` is converted to a box centre plus size,
    each divided by the width/height of the image it belongs to, and written
    as one ``"<category_id> <cx> <cy> <w> <h>"`` line.  All lines belonging
    to the same image end up in ``<save_path>/<image basename>.txt``.
    Images without any annotation get no label file.

    Args:
        json_path: Path of the annotation JSON to read.  It must contain the
            ``images`` and ``annotations`` lists.
        save_path: Directory the ``.txt`` label files are written to.  It is
            created (including missing parents) when it does not exist.
    """
    # Directory to save the labels to; makedirs also creates missing parents,
    # so the conversion no longer fails when intermediate folders are absent.
    os.makedirs(save_path, exist_ok=True)

    # Annotation file to read.
    with open(json_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    print('start')

    # Index image records by id so each annotation can find its image size.
    image_info_dict = {
        image_info['id']: image_info for image_info in json_data['images']
    }

    # Label lines grouped by image file stem, in annotation order.
    yolo_info_dict = dict()
    for anno_info in json_data['annotations']:
        image_info = image_info_dict[anno_info['image_id']]
        image_ww, image_hh = image_info['width'], image_info['height']
        file_name = os.path.splitext(os.path.basename(image_info['file_name']))[0]

        cate_id = anno_info['category_id']
        xmin, ymin, ww, hh = anno_info['bbox']
        # Top-left corner + size -> centre + size, then scale by image size.
        cx, cy = xmin + (ww / 2), ymin + (hh / 2)
        cx, cy, ww, hh = cx / image_ww, cy / image_hh, ww / image_ww, hh / image_hh

        yolo_info_dict.setdefault(file_name, []).append(
            f'{cate_id} {cx} {cy} {ww} {hh}\n'
        )

    for k, v in yolo_info_dict.items():
        label_path = os.path.join(save_path, k + '.txt')
        with open(label_path, 'w') as f:
            f.writelines(v)

    print('end')
47+
48+
if __name__ == '__main__':
49+
main()

yolo/custom.yaml

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Ultralytics YOLO 🚀, AGPL-3.0 license
2+
# Default training settings and hyperparameters for medium-augmentation COCO training
3+
4+
task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
5+
mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
6+
7+
# Train settings -------------------------------------------------------------------------------------------------------
8+
model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
9+
data: # (str, optional) path to data file, i.e. coco128.yaml
10+
epochs: 100 # (int) number of epochs to train for
11+
time: # (float, optional) number of hours to train for, overrides epochs if supplied
12+
patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
13+
batch: 16 # (int) number of images per batch (-1 for AutoBatch)
14+
imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
15+
save: True # (bool) save train checkpoints and predict results
16+
save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
17+
cache: False # (bool) True/ram, disk or False. Use cache for data loading
18+
device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
19+
workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)
20+
project: # (str, optional) project name
21+
name: # (str, optional) experiment name, results saved to 'project/name' directory
22+
exist_ok: False # (bool) whether to overwrite existing experiment
23+
pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
24+
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
25+
verbose: True # (bool) whether to print verbose output
26+
seed: 0 # (int) random seed for reproducibility
27+
deterministic: True # (bool) whether to enable deterministic mode
28+
single_cls: False # (bool) train multi-class data as single-class
29+
rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
30+
cos_lr: False # (bool) use cosine learning rate scheduler
31+
close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable)
32+
resume: False # (bool) resume training from last checkpoint
33+
amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
34+
fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
35+
profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
36+
freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training
37+
multi_scale: False # (bool) Whether to use multi-scale during training
38+
# Segmentation
39+
overlap_mask: True # (bool) masks should overlap during training (segment train only)
40+
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
41+
# Classification
42+
dropout: 0.0 # (float) use dropout regularization (classify train only)
43+
44+
# Val/Test settings ----------------------------------------------------------------------------------------------------
45+
val: True # (bool) validate/test during training
46+
split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
47+
save_json: False # (bool) save results to JSON file
48+
save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
49+
conf: 0.05 # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
50+
iou: 0.5 # (float) intersection over union (IoU) threshold for NMS
51+
max_det: 300 # (int) maximum number of detections per image
52+
half: False # (bool) use half precision (FP16)
53+
dnn: False # (bool) use OpenCV DNN for ONNX inference
54+
plots: True # (bool) save plots and images during train/val
55+
56+
# Predict settings -----------------------------------------------------------------------------------------------------
57+
source: # (str, optional) source directory for images or videos
58+
vid_stride: 1 # (int) video frame-rate stride
59+
stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
60+
visualize: False # (bool) visualize model features
61+
augment: False # (bool) apply image augmentation to prediction sources
62+
agnostic_nms: False # (bool) class-agnostic NMS
63+
classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
64+
retina_masks: False # (bool) use high-resolution segmentation masks
65+
embed: # (list[int], optional) return feature vectors/embeddings from given layers
66+
67+
# Visualize settings ---------------------------------------------------------------------------------------------------
68+
show: False # (bool) show predicted images and videos if environment allows
69+
save_frames: False # (bool) save predicted individual video frames
70+
save_txt: False # (bool) save results as .txt file
71+
save_conf: False # (bool) save results with confidence scores
72+
save_crop: False # (bool) save cropped images with results
73+
show_labels: True # (bool) show prediction labels, i.e. 'person'
74+
show_conf: True # (bool) show prediction confidence, i.e. '0.99'
75+
show_boxes: True # (bool) show prediction boxes
76+
line_width: # (int, optional) line width of the bounding boxes. Scaled to image size if None.
77+
78+
# Export settings ------------------------------------------------------------------------------------------------------
79+
format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
80+
keras: False # (bool) use Keras
81+
optimize: False # (bool) TorchScript: optimize for mobile
82+
int8: False # (bool) CoreML/TF INT8 quantization
83+
dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
84+
simplify: False # (bool) ONNX: simplify model
85+
opset: # (int, optional) ONNX: opset version
86+
workspace: 4 # (int) TensorRT: workspace size (GB)
87+
nms: False # (bool) CoreML: add NMS
88+
89+
# Hyperparameters ------------------------------------------------------------------------------------------------------
90+
lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
91+
lrf: 0.01 # (float) final learning rate (lr0 * lrf)
92+
momentum: 0.937 # (float) SGD momentum/Adam beta1
93+
weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
94+
warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
95+
warmup_momentum: 0.8 # (float) warmup initial momentum
96+
warmup_bias_lr: 0.1 # (float) warmup initial bias lr
97+
box: 7.5 # (float) box loss gain
98+
cls: 0.5 # (float) cls loss gain (scale with pixels)
99+
dfl: 1.5 # (float) dfl loss gain
100+
pose: 12.0 # (float) pose loss gain
101+
kobj: 1.0 # (float) keypoint obj loss gain
102+
label_smoothing: 0.0 # (float) label smoothing (fraction)
103+
nbs: 64 # (int) nominal batch size
104+
hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
105+
hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
106+
hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
107+
degrees: 0.0 # (float) image rotation (+/- deg)
108+
translate: 0.1 # (float) image translation (+/- fraction)
109+
scale: 0.5 # (float) image scale (+/- gain)
110+
shear: 0.0 # (float) image shear (+/- deg)
111+
perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
112+
flipud: 0.0 # (float) image flip up-down (probability)
113+
fliplr: 0.5 # (float) image flip left-right (probability)
114+
mosaic: 1.0 # (float) image mosaic (probability)
115+
mixup: 0.0 # (float) image mixup (probability)
116+
copy_paste: 0.0 # (float) segment copy-paste (probability)
117+
auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix)
118+
erasing: 0.4 # (float) probability of random erasing during classification training (0-1)
119+
crop_fraction: 1.0 # (float) image crop fraction for classification evaluation/inference (0-1)
120+
121+
# Custom config.yaml ---------------------------------------------------------------------------------------------------
122+
cfg: # (str, optional) for overriding defaults.yaml
123+
124+
# Tracker settings ------------------------------------------------------------------------------------------------------
125+
tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
train: /data/ephemeral/home/dataset/yolo/split/train_42_fold_1.txt
2+
val: /data/ephemeral/home/dataset/yolo/split/val_42_fold_1.txt
3+
4+
nc: 10
5+
names:
6+
0: General trash
7+
1: Paper
8+
2: Paper pack
9+
3: Metal
10+
4: Glass
11+
5: Plastic
12+
6: Styrofoam
13+
7: Plastic bag
14+
8: Battery
15+
9: Clothing
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
train: /data/ephemeral/home/dataset/yolo/split/train_42_fold_2.txt
2+
val: /data/ephemeral/home/dataset/yolo/split/val_42_fold_2.txt
3+
4+
nc: 10
5+
names:
6+
0: General trash
7+
1: Paper
8+
2: Paper pack
9+
3: Metal
10+
4: Glass
11+
5: Plastic
12+
6: Styrofoam
13+
7: Plastic bag
14+
8: Battery
15+
9: Clothing
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
train: /data/ephemeral/home/dataset/yolo/split/train_42_fold_3.txt
2+
val: /data/ephemeral/home/dataset/yolo/split/val_42_fold_3.txt
3+
4+
nc: 10
5+
names:
6+
0: General trash
7+
1: Paper
8+
2: Paper pack
9+
3: Metal
10+
4: Glass
11+
5: Plastic
12+
6: Styrofoam
13+
7: Plastic bag
14+
8: Battery
15+
9: Clothing
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
train: /data/ephemeral/home/dataset/yolo/split/train_42_fold_4.txt
2+
val: /data/ephemeral/home/dataset/yolo/split/val_42_fold_4.txt
3+
4+
nc: 10
5+
names:
6+
0: General trash
7+
1: Paper
8+
2: Paper pack
9+
3: Metal
10+
4: Glass
11+
5: Plastic
12+
6: Styrofoam
13+
7: Plastic bag
14+
8: Battery
15+
9: Clothing
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
train: /data/ephemeral/home/dataset/yolo/split/train_42_fold_5.txt
2+
val: /data/ephemeral/home/dataset/yolo/split/val_42_fold_5.txt
3+
4+
nc: 10
5+
names:
6+
0: General trash
7+
1: Paper
8+
2: Paper pack
9+
3: Metal
10+
4: Glass
11+
5: Plastic
12+
6: Styrofoam
13+
7: Plastic bag
14+
8: Battery
15+
9: Clothing

yolo/predict.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from ultralytics import YOLO
2+
import argparse
3+
import pandas as pd
4+
import os
5+
6+
def main(opt):
    """Run a trained YOLO model on a source and save the detections as CSV.

    One CSV row is produced per result returned by the model.  The
    ``PredictionString`` column is a single space-separated string that holds,
    for every detected box, the class index, the confidence and the values of
    ``box.xyxy[0]`` (presumably xmin, ymin, xmax, ymax -- confirm against the
    Ultralytics ``Boxes`` API).  The ``image_id`` column is the last two
    ``/``-separated components of ``result.path``.

    Args:
        opt: Parsed options providing ``weight`` (model weights passed to
            ``YOLO``), ``source`` (what the model is run on), ``save_dir``
            (output directory) and ``name`` (CSV file stem).
    """
    model = YOLO(opt.weight)
    results = model(opt.source)

    # Collect plain dict rows and build the frame once at the end:
    # DataFrame.append was removed in pandas 2.0 and copied the whole frame
    # on every call.
    rows = []
    for result in results:
        prediction_arr = []
        for box in result.boxes:
            prediction_arr.append(str(int(box.cls)))
            prediction_arr.append(str(float(box.conf)))
            # box.xyxy holds a single row of coordinates for this box.
            prediction_arr.extend(str(float(coord)) for coord in box.xyxy[0])

        # Keep "<parent dir>/<file name>" as the image identifier.
        image_id = '/'.join(result.path.split('/')[-2:])
        rows.append({
            'PredictionString': ' '.join(prediction_arr),
            'image_id': image_id,
        })

    df = pd.DataFrame(rows, columns=['PredictionString', 'image_id'])

    save_dir = opt.save_dir
    name = opt.name
    # makedirs also creates missing parent directories and tolerates an
    # already existing output directory.
    os.makedirs(save_dir, exist_ok=True)

    df.to_csv(os.path.join(save_dir, f'{name}.csv'), index=False)
37+
38+
39+
def parse_opt(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings.  When ``None`` (the default)
            the arguments are taken from ``sys.argv``, as before.

    Returns:
        argparse.Namespace with ``weight``, ``source``, ``save_dir`` and
        ``name`` attributes.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--weight', type=str, required=True,
                        help='path of the model weights to load')
    parser.add_argument('--source', type=str, default='/data/ephemeral/home/dataset/test',
                        help='source the model is run on')
    parser.add_argument('--save-dir', type=str, default='/data/ephemeral/home/level2-objectdetection-cv-03/yolo/results',
                        help='directory the result CSV is written to')
    parser.add_argument('--name', type=str, required=True,
                        help='file stem of the result CSV')

    return parser.parse_args(argv)
48+
49+
if __name__ == '__main__':
50+
opt = parse_opt()
51+
main(opt)

yolo/predict.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Run prediction with the best checkpoint of each of the five folds.
# The commands are independent: a failing fold does not stop the next ones.
RUNS_DIR='/data/ephemeral/home/level2-objectdetection-cv-03/yolo/runs/detect'

for fold in 1 2 3 4 5; do
    name="yolov8s_42_${fold}_e50_SGD"
    python ./predict.py --weight "${RUNS_DIR}/${name}/weights/best.pt" --name "${name}"
done

yolo/train.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from ultralytics import YOLO
2+
import argparse
3+
4+
def main(opt):
    """Train a YOLO model with the settings in ``./custom.yaml`` plus CLI overrides.

    Args:
        opt: Parsed options; every attribute is forwarded to ``model.train``
            as a keyword override on top of ``./custom.yaml``.
    """
    # Build the model from the requested weights/config so that ``--model``
    # selects what is trained; a bare ``YOLO()`` starts from the library's
    # default model instead.
    model = YOLO(opt.model)
    # NOTE(review): ``model`` is still forwarded among the overrides, exactly
    # as before, so the recorded training arguments keep the same keys.
    model.train(cfg='./custom.yaml', **vars(opt))
7+
8+
def parse_opt(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings.  When ``None`` (the default)
            the arguments are taken from ``sys.argv``, as before.

    Returns:
        argparse.Namespace whose attributes (``model``, ``data``, ``epochs``,
        ``patience``, ``batch``, ``imgsz``, ``workers``, ``project``,
        ``name``, ``optimizer``, ``conf``, ``iou``) are passed on to training.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--model', type=str, required=True,
                        help='model weights or model yaml to train')
    parser.add_argument('--data', type=str, required=True,
                        help='dataset yaml file')
    parser.add_argument('--epochs', type=int, default=50,
                        help='number of training epochs')
    parser.add_argument('--patience', type=int, default=50,
                        help='early-stopping patience in epochs')
    parser.add_argument('--batch', '--batch-size', type=int, default=16,
                        help='images per batch')
    parser.add_argument('--imgsz', '--img-size', type=int, default=640,
                        help='input image size')
    parser.add_argument('--workers', type=int, default=8,
                        help='number of data-loading workers')
    parser.add_argument('--project', type=str, default='',
                        help='project name')
    parser.add_argument('--name', type=str, default='',
                        help='experiment name')
    parser.add_argument('--optimizer', type=str, default='SGD',
                        help='optimizer to use')
    parser.add_argument('--conf', type=float, default=0.05,
                        help='object confidence threshold')
    parser.add_argument('--iou', type=float, default=0.5,
                        help='IoU threshold for NMS')

    return parser.parse_args(argv)
25+
26+
if __name__ == '__main__':
27+
opt = parse_opt()
28+
main(opt)

0 commit comments

Comments
 (0)