Skip to content

Commit dea366b

Browse files
committed
fix: issues #294
1 parent aa7f20e commit dea366b

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

tools/export.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,46 @@ def parse_args():
9292
os.environ["LOCAL_RANK"] = str(args.local_rank)
9393
return args
9494

95+
def find_and_sample_images(folder_path, limit=10000, sample_size=100):
96+
import random
97+
98+
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
99+
image_files = []
100+
101+
if not os.path.isdir(folder_path):
102+
raise ValueError(f"The provided image path '{folder_path}' is not a valid directory.")
103+
104+
for root, _, files in os.walk(folder_path):
105+
for file in files:
106+
if os.path.splitext(file)[1].lower() in image_extensions:
107+
image_files.append(os.path.join(root, file))
108+
if len(image_files) >= limit:
109+
break
110+
if len(image_files) >= limit:
111+
break
112+
113+
found = len(image_files)
114+
if found < sample_size:
115+
if found == 0:
116+
raise ValueError(f"No images found in the directory '{folder_path}'.")
117+
print(f"Warning: Found only {found} images, which is less than the requested sample size of {sample_size}.")
118+
sample_size = found
119+
120+
random_sample = random.sample(image_files, sample_size)
121+
122+
return random_sample
95123

96124
def generate_input(images_path, img_shape):
97125
import cv2
98126

99127
res = []
100-
datasets = [
101-
osp.join(images_path, i) for i in os.listdir(images_path) if i.endswith(".jpg")
102-
]
128+
datasets = find_and_sample_images(images_path, limit=10000, sample_size=100)
103129
for ps in datasets[:100]:
104130
img = cv2.imread(ps)
105131
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255
106132
img = cv2.resize(img, (img_shape[0], img_shape[1]))
107133
res.append(img)
134+
res.append(img.astype(np.float32))
108135
return np.asarray(res)
109136

110137

@@ -360,11 +387,7 @@ def export_tflite(onnx_path: str, img_shape, img_path):
360387
converter = tf.lite.TFLiteConverter.from_saved_model(osp.dirname(onnx_path))
361388

362389
def representative_dataset():
363-
datasets = [
364-
osp.join(img_path, i)
365-
for i in os.listdir(img_path)
366-
if i.lower().endswith((".jpg", ".jpeg", ".png"))
367-
]
390+
datasets = find_and_sample_images(img_path, limit=10000, sample_size=300)
368391
for ps in tqdm(datasets[:300]):
369392
img = cv2.imread(ps)
370393
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255

0 commit comments

Comments
 (0)