Commit 344f7d7

Merge pull request #22 from freds-dev/feat/four-channels
Feat/four channels
2 parents: 8b88710 + 00060eb

7 files changed: +89 -39 lines


code/start_scene_cross_validation.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@ def start_scene_cross_validation(dataset_name,video_event_name,video_rgb_name,co
     # Run the Bash script with arguments
     if id is None:
         id = -1
-    id = subprocess.check_output(['bash', "create_scene_split.sh", dataset_name, scene,video_event_name,video_rgb_name,config_name,str(id)]).strip().split("\n")[-1]
+    id = subprocess.check_output(['bash', "create_scene_split.sh", dataset_name, scene,video_event_name,video_rgb_name,config_name,str(id)]).decode('utf-8').strip().split("\n")[-1]


 if __name__ == "__main__":

@@ -30,4 +30,4 @@ def start_scene_cross_validation(dataset_name,video_event_name,video_rgb_name,co


     args = parser.parse_args()
-    start_scene_cross_validation(args.dataset,args.video_event_name,args.video_rgb_name,args.config_nameargs.exception_scenes)
+    start_scene_cross_validation(args.dataset,args.video_event_name,args.video_rgb_name,args.config_name,args.exception_scenes)
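Note on the .decode('utf-8') fix: subprocess.check_output returns bytes in Python 3, and bytes.split() rejects a str separator, so the old line raised a TypeError before it could pick out the last output line; the second hunk additionally restores the comma that was missing between args.config_name and args.exception_scenes. A minimal sketch of the bytes-vs-str issue, using echo to stand in for create_scene_split.sh:

import subprocess

# check_output returns bytes; b'...'.split("\n") raises
# "TypeError: a bytes-like object is required, not 'str'".
raw = subprocess.check_output(['echo', 'job-42'])      # b'job-42\n'
job_id = raw.decode('utf-8').strip().split('\n')[-1]   # 'job-42'
print(job_id)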

code/utils/build_sbatch.py

Lines changed: 1 addition & 1 deletion

@@ -81,4 +81,4 @@ def main():
     write_file(args.script_location,gpu_script_content)

 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()

code/utils/build_split_script.py

Lines changed: 4 additions & 4 deletions

@@ -45,10 +45,10 @@ def main():
     parser.add_argument("--video_event_name", type=str, required=True,help="Directory where the event videos are located")
     parser.add_argument("--video_rgb_name", type=str, required=True,help="Directory where the rgb videos are located")
     parser.add_argument("--config_name",type=str,required=True,help="Name of the configuration file")
-    parser.add_argument("--cpus", type=int, default=18, help="Number of CPUs")
-    parser.add_argument("--memory", type=int, default=48, help="Memory in GB")
-    parser.add_argument("--hours", type=int, default=12, help="Wallclock time in hours")
-    parser.add_argument("--partition", default="normal", help="Partition for the job")
+    parser.add_argument("--cpus", type=int, default=36, help="Number of CPUs")
+    parser.add_argument("--memory", type=int, default=25, help="Memory in GB")
+    parser.add_argument("--hours", type=int, default=8, help="Wallclock time in hours")
+    parser.add_argument("--partition", default="normal,long", help="Partition for the job")

     args = parser.parse_args()
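The new defaults trade memory for CPUs, shorten the wallclock limit, and allow two partitions. A hypothetical sketch of the Slurm header such defaults imply; the real template lives in build_sbatch.py and is not shown in this diff:

def sbatch_header(cpus=36, memory=25, hours=8, partition="normal,long"):
    # Render a Slurm preamble from the defaults above; a comma-separated
    # partition list lets Slurm start the job on whichever is free first.
    return "\n".join([
        "#!/bin/bash",
        f"#SBATCH --cpus-per-task={cpus}",
        f"#SBATCH --mem={memory}G",
        f"#SBATCH --time={hours}:00:00",
        f"#SBATCH --partition={partition}",
    ])

print(sbatch_header())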

code/yolov7_custom/train.py

Lines changed: 75 additions & 25 deletions

@@ -38,9 +38,48 @@
 from utils.plots import plot_lr_scheduler, plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
 from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
-from ..utils.paths import get_data_yaml, get_result_dir
 logger = logging.getLogger(__name__)

+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+IS_LOCAL = os.getenv('IS_LOCAL') == "TRUE"
+LOCAL_PATH_DATA = os.getenv('LOCAL_PATH_DATA')
+PALMA_PATH_DATA = os.getenv('PALMA_PATH_DATA')
+
+def get_base_dir()-> str:
+    return LOCAL_PATH_DATA if IS_LOCAL else PALMA_PATH_DATA
+
+def get_homography() -> str:
+    return os.path.join(get_base_dir(), "homography_calib.yaml")
+
+def get_calib_rgb() -> str:
+    return os.path.join(get_base_dir(), "calib_dng.yaml")
+
+def get_calib_event() -> str:
+    return os.path.join(get_base_dir(), "calib_raw.yaml")
+
+def get_config_path(config_name:str) -> str:
+    return os.path.join(get_base_dir(),"configs",config_name)
+
+def get_annotations_path()->str:
+    return os.path.join(get_base_dir(), "annotations.ndjson")
+
+def get_dataset_dir(dataset_name:str)-> str:
+    return os.path.join(get_base_dir(),"datasets",dataset_name)
+
+def get_video_dir(video_dir_name: str) -> str:
+    return os.path.join(get_base_dir(),"videos",video_dir_name)
+
+def get_data_yaml(dataset_name: str) -> str:
+    return os.path.join(get_base_dir(),"datasets",dataset_name,"data.yaml")
+
+def get_result_dir(dataset_name: str) -> str:
+    dir = os.path.join(get_base_dir(),"results",dataset_name)
+    os.makedirs(dir,exist_ok=True)
+    return dir

 def train(hyp, opt, device, tb_writer=None):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
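These helpers appear to inline what train.py previously imported through the now-removed from ..utils.paths import, with a .env file switching between a local machine and the cluster. An illustration of the switch with made-up values (a real setup would put them in .env for load_dotenv() to read):

import os

os.environ['IS_LOCAL'] = 'TRUE'                  # made-up values,
os.environ['LOCAL_PATH_DATA'] = '/home/me/data'  # for illustration only

IS_LOCAL = os.getenv('IS_LOCAL') == "TRUE"
base = os.getenv('LOCAL_PATH_DATA') if IS_LOCAL else os.getenv('PALMA_PATH_DATA')
print(os.path.join(base, "datasets", "my_dataset", "data.yaml"))
# -> /home/me/data/datasets/my_dataset/data.yaml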
@@ -88,18 +127,18 @@ def train(hyp, opt, device, tb_writer=None):

     # Model
     pretrained = weights.endswith('.pt')
-    if pretrained:
-        with torch_distributed_zero_first(rank):
-            attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location=device)  # load checkpoint
-        model = Model(opt.cfg or ckpt['model'].yaml, ch=(4 if opt.four_channels else 3)*opt.multi_frame, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
-        state_dict = ckpt['model'].float().state_dict()  # to FP32
-        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
-        model.load_state_dict(state_dict, strict=False)  # load
-        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
-    else:
-        model = Model(opt.cfg, ch=(4 if opt.four_channels else 3)*opt.multi_frame, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+    #if pretrained:
+    #    with torch_distributed_zero_first(rank):
+    #        attempt_download(weights)  # download if not found locally
+    #    ckpt = torch.load(weights, map_location=device)  # load checkpoint
+    #    model = Model(opt.cfg or ckpt['model'].yaml, ch=(4 if opt.four_channels else 3)*opt.multi_frame, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+    #    exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
+    #    state_dict = ckpt['model'].float().state_dict()  # to FP32
+    #    state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
+    #    model.load_state_dict(state_dict, strict=False)  # load
+    #    logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
+    #else:
+    model = Model(opt.cfg, ch=(4 if opt.four_channels else 3)*opt.multi_frame, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
     with torch_distributed_zero_first(rank):
         check_dataset(data_dict)  # check
     train_path = data_dict['train']

@@ -208,6 +247,7 @@ def train(hyp, opt, device, tb_writer=None):

     # Resume
     start_epoch, best_fitness = 0, 0.0
+    """
     if pretrained:
         # Optimizer
         if ckpt['optimizer'] is not None:

@@ -233,7 +273,7 @@ def train(hyp, opt, device, tb_writer=None):
             epochs += ckpt['epoch']  # finetune additional epochs

         del ckpt, state_dict
-
+    """
     # if opt.multi_frame > 1:
     #     multi_train_path = stack_images(train_path, opt.multi_frame)
     #     train_path = multi_train_path
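Commenting out the pretrained branch (and the matching resume block above) is plausibly a consequence of the four-channel input: with ch=(4 if opt.four_channels else 3)*opt.multi_frame, the first convolution no longer lines up with a stock 3-channel checkpoint, so its weights cannot be transferred. A toy illustration of that shape mismatch, not the project's code:

import torch.nn as nn

pretrained_conv = nn.Conv2d(3, 32, 3)    # first conv as saved in a 3-channel checkpoint
four_channel_conv = nn.Conv2d(4, 32, 3)  # first conv of a model built with ch=4

try:
    four_channel_conv.load_state_dict(pretrained_conv.state_dict())
except RuntimeError as err:
    print(err)  # size mismatch for weight: ... [32, 3, 3, 3] vs. [32, 4, 3, 3]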
@@ -477,8 +517,22 @@ def train(hyp, opt, device, tb_writer=None):
                     tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                 if wandb_logger.wandb:
                     wandb_logger.log({tag: x})  # W&B
+
+
+
+            def flatten_data(data):
+                flattened_data = []
+                for item in data:
+                    if isinstance(item, np.ndarray):
+                        flattened_data.extend(item.flatten())
+                    else:
+                        flattened_data.append(item)
+                return flattened_data

             # Update best mAP
+            print(results)
+            results = flatten_data(results)
+            print(results)
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
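For context on flatten_data: fitness() needs np.array(results).reshape(1, -1) to see a flat numeric sequence, and the surrounding prints suggest results was arriving as a mix of scalars and numpy arrays. A quick sketch with made-up numbers:

import numpy as np

results = (0.71, 0.64, np.array([0.55, 0.38]), 0.42)  # scalars plus an ndarray
flat = []
for item in results:
    if isinstance(item, np.ndarray):
        flat.extend(item.flatten())  # splice array entries in
    else:
        flat.append(item)            # scalars pass through
print(np.array(flat).reshape(1, -1))  # [[0.71 0.64 0.55 0.38 0.42]]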
@@ -571,13 +625,13 @@ def train(hyp, opt, device, tb_writer=None):

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='yolo7.pt', help='initial weights path')
-    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
-    parser.add_argument('--dataset', type=str, default='data/coco.yaml', help='data.yaml path')
-    parser.add_argument('--hyp', type=str, default='data/hyp.scratch.p5.yaml', help='hyperparameters path')
+    parser.add_argument('--weights', type=str, default='/scratch/tmp/jdanel/data/best.pt', help='initial weights path')
+    parser.add_argument('--cfg', type=str, default='/home/j/jdanel/codespace/ML4IM/code/yolov7_custom/cfg/training/yolov7.yaml', help='model.yaml path')
+    parser.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path')
+    parser.add_argument('--hyp', type=str, default='/home/j/jdanel/codespace/ML4IM/code/yolov7_custom/data/hyp.scratch.p5.yaml', help='hyperparameters path')
     parser.add_argument('--epochs', type=int, default=100)
-    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
-    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
+    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs') #TODO: Change default to 256
+    parser.add_argument('--img-size', nargs='+', type=int, default=[1280, 1280], help='[train, test] image sizes') # TODO: Change default to 1280,1280
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
@@ -613,14 +667,10 @@ def main():
     parser.add_argument('--multi-frame', type=int, default=1, choices=range(1,101), help='how many frames to load at once')
     parser.add_argument('--center-point', action='store_true', help='use center point metric instead of iou')
     opt = parser.parse_args()
-
+    print(opt.data)
     # Set DDP variables
     opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
     opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
-    dataset = opt.dataset
-    opt.dataset = get_data_yaml(dataset)
-    opt.project = get_result_dir(dataset)
-
     set_logging(opt.global_rank)
     #if opt.global_rank in [-1, 0]:
     #    check_git_status()

code/yolov7_custom/utils/datasets.py

Lines changed: 2 additions & 2 deletions

@@ -451,7 +451,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
                 x[:, 0] = 0

         n = len(shapes)  # number of images
-        bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
+        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
         nb = bi[-1] + 1  # number of batches
         self.batch = bi  # batch index of image
         self.n = n

@@ -479,7 +479,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
                 elif mini > 1:
                     shapes[i] = [1, 1 / mini]

-            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
+            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

         # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
         self.imgs = [None] * n

code/yolov7_custom/utils/general.py

Lines changed: 1 addition & 1 deletion

@@ -219,7 +219,7 @@ def labels_to_class_weights(labels, nc=80):
         return torch.Tensor()

     labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
-    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
+    classes = labels[:, 0].astype(int)  # labels = [class xywh]
     weights = np.bincount(classes, minlength=nc)  # occurrences per class

     # Prepend gridpoint count (for uCE training)
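This change and the datasets.py one above are the same compatibility fix: np.int was deprecated in NumPy 1.20 and removed in 1.24, and since it was only an alias for the builtin int, .astype(int) is the drop-in replacement:

import numpy as np

# On NumPy >= 1.24, .astype(np.int) raises AttributeError; the builtin
# int (or an explicit dtype like np.int64) behaves identically here.
bi = np.floor(np.arange(10) / 4).astype(int)
print(bi)  # [0 0 0 0 1 1 1 1 2 2]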

code/yolov7_custom/utils/loss.py

Lines changed: 4 additions & 4 deletions

@@ -642,7 +642,7 @@ def build_targets(self, p, targets, imgs):
        #indices, anch = self.find_4_positive(p, targets)
        #indices, anch = self.find_5_positive(p, targets)
        #indices, anch = self.find_9_positive(p, targets)
-
+        device = torch.device(targets.device)
        matching_bs = [[] for pp in p]
        matching_as = [[] for pp in p]
        matching_gjs = [[] for pp in p]

@@ -682,7 +682,7 @@ def build_targets(self, p, targets, imgs):
                all_gj.append(gj)
                all_gi.append(gi)
                all_anch.append(anch[i][idx])
-                from_which_layer.append(torch.ones(size=(len(b),)) * i)
+                from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device))

                fg_pred = pi[b, a, gj, gi]
                p_obj.append(fg_pred[:, 4:5])

@@ -739,7 +739,7 @@ def build_targets(self, p, targets, imgs):
                + 3.0 * pair_wise_iou_loss
            )

-            matching_matrix = torch.zeros_like(cost)
+            matching_matrix = torch.zeros_like(cost, device=device)

            for gt_idx in range(num_gt):
                _, pos_idx = torch.topk(

@@ -753,7 +753,7 @@ def build_targets(self, p, targets, imgs):
                _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
                matching_matrix[:, anchor_matching_gt > 1] *= 0.0
                matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
-            fg_mask_inboxes = matching_matrix.sum(0) > 0.0
+            fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
            matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)

            from_which_layer = from_which_layer[fg_mask_inboxes]
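All four loss.py edits pin intermediate tensors to the device of targets: torch.ones() allocates on the CPU by default, and indexing or arithmetic that mixes CPU and CUDA tensors raises a RuntimeError during GPU training. A minimal sketch of the failure mode (only meaningful on a CUDA machine):

import torch

if torch.cuda.is_available():
    layer_ids = torch.ones(5)  # CPU tensor, like the old from_which_layer entries
    mask = torch.tensor([True, False, True, False, True], device='cuda')
    try:
        layer_ids[mask]  # CPU tensor indexed with a CUDA mask
    except RuntimeError as err:
        print(err)  # indices should be on the same device as the indexed tensor
    print(layer_ids.to(mask.device)[mask])  # tensor([1., 1.], device='cuda:0')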
