Skip to content

Commit 54837db

Browse files
authored
refactored splitTrainVal and added multiOS path support (#11069)
1 parent 26925a9 commit 54837db

File tree

1 file changed

+38
-43
lines changed

1 file changed

+38
-43
lines changed

gen_ocr_train_val_test.py

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,43 @@ def isCreateOrDeleteFolder(path, flag):
1717
return flagAbsPath
1818

1919

20-
def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
21-
# 按照指定的比例划分训练集、验证集、测试集
22-
dataAbsPath = os.path.abspath(root)
23-
24-
if flag == "det":
25-
labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName)
26-
elif flag == "rec":
27-
labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName)
28-
29-
labelFileRead = open(labelFilePath, "r", encoding="UTF-8")
30-
labelFileContent = labelFileRead.readlines()
31-
random.shuffle(labelFileContent)
32-
labelRecordLen = len(labelFileContent)
33-
34-
for index, labelRecordInfo in enumerate(labelFileContent):
35-
imageRelativePath = labelRecordInfo.split('\t')[0]
36-
imageLabel = labelRecordInfo.split('\t')[1]
37-
imageName = os.path.basename(imageRelativePath)
38-
39-
if flag == "det":
40-
imagePath = os.path.join(dataAbsPath, imageName)
41-
elif flag == "rec":
42-
imagePath = os.path.join(dataAbsPath, "{}\\{}".format(args.recImageDirName, imageName))
43-
44-
# 按预设的比例划分训练集、验证集、测试集
45-
trainValTestRatio = args.trainValTestRatio.split(":")
46-
trainRatio = eval(trainValTestRatio[0]) / 10
47-
valRatio = trainRatio + eval(trainValTestRatio[1]) / 10
48-
curRatio = index / labelRecordLen
49-
50-
if curRatio < trainRatio:
51-
imageCopyPath = os.path.join(absTrainRootPath, imageName)
52-
shutil.copy(imagePath, imageCopyPath)
53-
trainTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
54-
elif curRatio >= trainRatio and curRatio < valRatio:
55-
imageCopyPath = os.path.join(absValRootPath, imageName)
56-
shutil.copy(imagePath, imageCopyPath)
57-
valTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
58-
else:
59-
imageCopyPath = os.path.join(absTestRootPath, imageName)
60-
shutil.copy(imagePath, imageCopyPath)
61-
testTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
20+
def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_path, train_txt, val_txt, test_txt, flag):
21+
22+
data_abs_path = os.path.abspath(root)
23+
label_file_name = args.detLabelFileName if flag == "det" else args.recLabelFileName
24+
label_file_path = os.path.join(data_abs_path, label_file_name)
25+
26+
with open(label_file_path, "r", encoding="UTF-8") as label_file:
27+
label_file_content = label_file.readlines()
28+
random.shuffle(label_file_content)
29+
label_record_len = len(label_file_content)
30+
31+
for index, label_record_info in enumerate(label_file_content):
32+
image_relative_path, image_label = label_record_info.split('\t')
33+
image_name = os.path.basename(image_relative_path)
34+
35+
if flag == "det":
36+
image_path = os.path.join(data_abs_path, image_name)
37+
elif flag == "rec":
38+
image_path = os.path.join(data_abs_path, args.recImageDirName, image_name)
39+
40+
train_val_test_ratio = args.trainValTestRatio.split(":")
41+
train_ratio = eval(train_val_test_ratio[0]) / 10
42+
val_ratio = train_ratio + eval(train_val_test_ratio[1]) / 10
43+
cur_ratio = index / label_record_len
44+
45+
if cur_ratio < train_ratio:
46+
image_copy_path = os.path.join(abs_train_root_path, image_name)
47+
shutil.copy(image_path, image_copy_path)
48+
train_txt.write("{}\t{}\n".format(image_copy_path, image_label))
49+
elif cur_ratio >= train_ratio and cur_ratio < val_ratio:
50+
image_copy_path = os.path.join(abs_val_root_path, image_name)
51+
shutil.copy(image_path, image_copy_path)
52+
val_txt.write("{}\t{}\n".format(image_copy_path, image_label))
53+
else:
54+
image_copy_path = os.path.join(abs_test_root_path, image_name)
55+
shutil.copy(image_path, image_copy_path)
56+
test_txt.write("{}\t{}\n".format(image_copy_path, image_label))
6257

6358

6459
# 删掉存在的文件
@@ -148,4 +143,4 @@ def genDetRecTrainVal(args):
148143
help="the name of the folder where the cropped recognition dataset is located"
149144
)
150145
args = parser.parse_args()
151-
genDetRecTrainVal(args)
146+
genDetRecTrainVal(args)

0 commit comments

Comments
 (0)