@@ -99,6 +99,29 @@ def core_options(func):
9999 return func
100100
101101
102+ def parse_classes (classes_input ):
103+ """
104+ Parse classes input which can be a tuple of ints (from multiple=True),
105+ a string (comma/space separated), or None.
106+ Returns a list of integers or None.
107+ """
108+ if classes_input is None :
109+ return None
110+
111+ if isinstance (classes_input , (list , tuple )):
112+ # If it's already a list/tuple of ints (from multiple=True)
113+ if not classes_input :
114+ return None
115+ return list (classes_input )
116+
117+ if isinstance (classes_input , str ):
118+ # Handle string input: "0,1" or "0 1"
119+ classes_input = classes_input .replace (',' , ' ' )
120+ return [int (x ) for x in classes_input .split ()]
121+
122+ return [int (classes_input )]
123+
124+
102125def singular_model_options (func ):
103126 options = [
104127 click .option ('--yolo-model' , type = Path ,
@@ -107,8 +130,8 @@ def singular_model_options(func):
107130 click .option ('--reid-model' , type = Path ,
108131 default = WEIGHTS / 'osnet_x0_25_msmt17.pt' ,
109132 help = 'path to ReID model weights' ),
110- click .option ('--classes' , type = int , multiple = True ,
111- help = 'filter by class indices' )
133+ click .option ('--classes' , type = str , default = None ,
134+ help = 'filter by class indices, e.g. 0 or "0,1" ' )
112135 ]
113136 for opt in reversed (options ):
114137 func = opt (func )
@@ -123,8 +146,8 @@ def plural_model_options(func):
123146 click .option ('--reid-model' , type = Path , multiple = True ,
124147 default = [WEIGHTS / 'osnet_x0_25_msmt17.pt' ],
125148 help = 'one or more ReID model weights' ),
126- click .option ('--classes' , type = int , multiple = True ,
127- default = [ 0 ], help = 'filter by class indices' )
149+ click .option ('--classes' , type = str , default = None ,
150+ help = 'filter by class indices, e.g. 0 or "0,1" ' )
128151 ]
129152 for opt in reversed (options ):
130153 func = opt (func )
@@ -291,7 +314,7 @@ def track(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwarg
291314 params = {** kwargs ,
292315 'yolo_model' : yolo_model ,
293316 'reid_model' : reid_model ,
294- 'classes' : list (classes ) if classes else None ,
317+ 'classes' : parse_classes (classes ),
295318 'source' : src ,
296319 'benchmark' : bench ,
297320 'split' : split }
@@ -324,7 +347,7 @@ def generate(ctx, detector, reid, yolo_model, reid_model, classes, **kwargs):
324347 params = {** kwargs ,
325348 'yolo_model' : list (yolo_model ),
326349 'reid_model' : list (reid_model ),
327- 'classes' : list (classes ),
350+ 'classes' : parse_classes (classes ),
328351 'source' : src ,
329352 'benchmark' : bench ,
330353 'split' : split }
@@ -360,7 +383,7 @@ def eval(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
360383 params = {** kwargs ,
361384 'yolo_model' : list (yolo_model ),
362385 'reid_model' : list (reid_model ),
363- 'classes' : [ 0 ] ,
386+ 'classes' : parse_classes ( classes ) ,
364387 'source' : src ,
365388 'benchmark' : bench ,
366389 'split' : split }
@@ -397,7 +420,7 @@ def tune(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
397420 params = {** kwargs ,
398421 'yolo_model' : list (yolo_model ),
399422 'reid_model' : list (reid_model ),
400- 'classes' : list (classes ),
423+ 'classes' : parse_classes (classes ),
401424 'source' : src ,
402425 'benchmark' : bench ,
403426 'split' : split }
0 commit comments