@@ -92,19 +92,46 @@ def parse_args():
9292 os .environ ["LOCAL_RANK" ] = str (args .local_rank )
9393 return args
9494
95+ def find_and_sample_images (folder_path , limit = 10000 , sample_size = 100 ):
96+ import random
97+
98+ image_extensions = {'.jpg' , '.jpeg' , '.png' , '.gif' , '.bmp' , '.tiff' , '.webp' }
99+ image_files = []
100+
101+ if not os .path .isdir (folder_path ):
102+ raise ValueError (f"The provided image path '{ folder_path } ' is not a valid directory." )
103+
104+ for root , _ , files in os .walk (folder_path ):
105+ for file in files :
106+ if os .path .splitext (file )[1 ].lower () in image_extensions :
107+ image_files .append (os .path .join (root , file ))
108+ if len (image_files ) >= limit :
109+ break
110+ if len (image_files ) >= limit :
111+ break
112+
113+ found = len (image_files )
114+ if found < sample_size :
115+ if found == 0 :
116+ raise ValueError (f"No images found in the directory '{ folder_path } '." )
117+ print (f"Warning: Found only { found } images, which is less than the requested sample size of { sample_size } ." )
118+ sample_size = found
119+
120+ random_sample = random .sample (image_files , sample_size )
121+
122+ return random_sample
95123
96124def generate_input (images_path , img_shape ):
97125 import cv2
98126
99127 res = []
100- datasets = [
101- osp .join (images_path , i ) for i in os .listdir (images_path ) if i .endswith (".jpg" )
102- ]
128+ datasets = find_and_sample_images (images_path , limit = 10000 , sample_size = 100 )
103129 for ps in datasets [:100 ]:
104130 img = cv2 .imread (ps )
105131 img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB ) / 255
106132 img = cv2 .resize (img , (img_shape [0 ], img_shape [1 ]))
107133 res .append (img )
134+ res .append (img .astype (np .float32 ))
108135 return np .asarray (res )
109136
110137
@@ -360,11 +387,7 @@ def export_tflite(onnx_path: str, img_shape, img_path):
360387 converter = tf .lite .TFLiteConverter .from_saved_model (osp .dirname (onnx_path ))
361388
362389 def representative_dataset ():
363- datasets = [
364- osp .join (img_path , i )
365- for i in os .listdir (img_path )
366- if i .lower ().endswith ((".jpg" , ".jpeg" , ".png" ))
367- ]
390+ datasets = find_and_sample_images (img_path , limit = 10000 , sample_size = 300 )
368391 for ps in tqdm (datasets [:300 ]):
369392 img = cv2 .imread (ps )
370393 img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB ) / 255
0 commit comments