 from utils.plots import plot_lr_scheduler, plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
 from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
-from ..utils.paths import get_data_yaml, get_result_dir
 logger = logging.getLogger(__name__)
 
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+IS_LOCAL = os.getenv('IS_LOCAL') == "TRUE"
+LOCAL_PATH_DATA = os.getenv('LOCAL_PATH_DATA')
+PALMA_PATH_DATA = os.getenv('PALMA_PATH_DATA')
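+# Expected .env entries (example values are hypothetical, not from this repo):
+#   IS_LOCAL=TRUE
+#   LOCAL_PATH_DATA=/path/to/local/data
+#   PALMA_PATH_DATA=/path/to/palma/cluster/data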
+
+def get_base_dir() -> str:
+    return LOCAL_PATH_DATA if IS_LOCAL else PALMA_PATH_DATA
+
+def get_homography() -> str:
+    return os.path.join(get_base_dir(), "homography_calib.yaml")
+
+def get_calib_rgb() -> str:
+    return os.path.join(get_base_dir(), "calib_dng.yaml")
+
+def get_calib_event() -> str:
+    return os.path.join(get_base_dir(), "calib_raw.yaml")
+
+def get_config_path(config_name: str) -> str:
+    return os.path.join(get_base_dir(), "configs", config_name)
+
+def get_annotations_path() -> str:
+    return os.path.join(get_base_dir(), "annotations.ndjson")
+
+def get_dataset_dir(dataset_name: str) -> str:
+    return os.path.join(get_base_dir(), "datasets", dataset_name)
+
+def get_video_dir(video_dir_name: str) -> str:
+    return os.path.join(get_base_dir(), "videos", video_dir_name)
+
+def get_data_yaml(dataset_name: str) -> str:
+    return os.path.join(get_base_dir(), "datasets", dataset_name, "data.yaml")
+
+def get_result_dir(dataset_name: str) -> str:
+    result_dir = os.path.join(get_base_dir(), "results", dataset_name)
+    os.makedirs(result_dir, exist_ok=True)
+    return result_dir
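+
+# Usage sketch (dataset name "my_dataset" is a hypothetical example):
+#   get_data_yaml("my_dataset")  -> <base_dir>/datasets/my_dataset/data.yaml
+#   get_result_dir("my_dataset") -> <base_dir>/results/my_dataset (created if missing)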
 
 def train(hyp, opt, device, tb_writer=None):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
@@ -88,18 +127,18 @@ def train(hyp, opt, device, tb_writer=None):
 
     # Model
     pretrained = weights.endswith('.pt')
-    if pretrained:
-        with torch_distributed_zero_first(rank):
-            attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location=device)  # load checkpoint
-        model = Model(opt.cfg or ckpt['model'].yaml, ch=(4 if opt.four_channels else 3)*opt.multi_frame, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
-        state_dict = ckpt['model'].float().state_dict()  # to FP32
-        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
-        model.load_state_dict(state_dict, strict=False)  # load
-        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
-    else:
+    # if pretrained:
+    #     with torch_distributed_zero_first(rank):
+    #         attempt_download(weights)  # download if not found locally
+    #     ckpt = torch.load(weights, map_location=device)  # load checkpoint
+    #     model = Model(opt.cfg or ckpt['model'].yaml, ch=(4 if opt.four_channels else 3)*opt.multi_frame, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+    #     exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
+    #     state_dict = ckpt['model'].float().state_dict()  # to FP32
+    #     state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
+    #     model.load_state_dict(state_dict, strict=False)  # load
+    #     logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
+    # else:
+    model = Model(opt.cfg, ch=(4 if opt.four_channels else 3)*opt.multi_frame, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
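+    # NOTE: with the loading block above commented out, no checkpoint weights are
+    # transferred; the model is always built from scratch using opt.cfg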
     with torch_distributed_zero_first(rank):
         check_dataset(data_dict)  # check
     train_path = data_dict['train']
@@ -208,6 +247,7 @@ def train(hyp, opt, device, tb_writer=None):
 
     # Resume
     start_epoch, best_fitness = 0, 0.0
250+ """
211251 if pretrained:
212252 # Optimizer
213253 if ckpt['optimizer'] is not None:
@@ -233,7 +273,7 @@ def train(hyp, opt, device, tb_writer=None):
             epochs += ckpt['epoch']  # finetune additional epochs
 
         del ckpt, state_dict
-
+    """
     # if opt.multi_frame > 1:
     #     multi_train_path = stack_images(train_path, opt.multi_frame)
     #     train_path = multi_train_path
@@ -477,8 +517,22 @@ def train(hyp, opt, device, tb_writer=None):
                     tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                 if wandb_logger.wandb:
                     wandb_logger.log({tag: x})  # W&B
+
+            def flatten_data(data):
+                # Flatten a mix of scalars and numpy arrays into one flat list of scalars
+                flattened_data = []
+                for item in data:
+                    if isinstance(item, np.ndarray):
+                        flattened_data.extend(item.flatten())
+                    else:
+                        flattened_data.append(item)
+                return flattened_data
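+            # e.g. flatten_data([0.5, np.array([0.1, 0.2])]) -> [0.5, 0.1, 0.2]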
 
             # Update best mAP
+            print(results)
+            results = flatten_data(results)  # results may contain per-class arrays; fitness() needs a flat vector
+            print(results)
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
@@ -571,13 +625,13 @@ def train(hyp, opt, device, tb_writer=None):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='yolo7.pt', help='initial weights path')
-    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
-    parser.add_argument('--dataset', type=str, default='data/coco.yaml', help='data.yaml path')
-    parser.add_argument('--hyp', type=str, default='data/hyp.scratch.p5.yaml', help='hyperparameters path')
+    parser.add_argument('--weights', type=str, default='/scratch/tmp/jdanel/data/best.pt', help='initial weights path')
+    parser.add_argument('--cfg', type=str, default='/home/j/jdanel/codespace/ML4IM/code/yolov7_custom/cfg/training/yolov7.yaml', help='model.yaml path')
+    parser.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path')
+    parser.add_argument('--hyp', type=str, default='/home/j/jdanel/codespace/ML4IM/code/yolov7_custom/data/hyp.scratch.p5.yaml', help='hyperparameters path')
     parser.add_argument('--epochs', type=int, default=100)
-    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
-    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
+    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')  # TODO: raise default to 256
+    parser.add_argument('--img-size', nargs='+', type=int, default=[1280, 1280], help='[train, test] image sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
@@ -613,14 +667,10 @@ def main():
     parser.add_argument('--multi-frame', type=int, default=1, choices=range(1, 101), help='how many frames to load at once')
     parser.add_argument('--center-point', action='store_true', help='use center point metric instead of iou')
     opt = parser.parse_args()
-
+    print(opt.data)
     # Set DDP variables
     opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
     opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
-    dataset = opt.dataset
-    opt.dataset = get_data_yaml(dataset)
-    opt.project = get_result_dir(dataset)
-
     set_logging(opt.global_rank)
     #if opt.global_rank in [-1, 0]:
     #    check_git_status()