Skip to content

Commit 90cbb95

Browse files
authored
add pre-commit workflow (#11973)
* add pre-commit workflow * run 'pre-commit run --all-files' * setup python version
1 parent 66b731b commit 90cbb95

22 files changed

+1676
-951
lines changed

PPOCRLabel.py

Lines changed: 1185 additions & 589 deletions
Large diffs are not rendered by default.

gen_ocr_train_val_test.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@ def isCreateOrDeleteFolder(path, flag):
1717
return flagAbsPath
1818

1919

20-
def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_path, train_txt, val_txt, test_txt, flag):
21-
20+
def splitTrainVal(
21+
root,
22+
abs_train_root_path,
23+
abs_val_root_path,
24+
abs_test_root_path,
25+
train_txt,
26+
val_txt,
27+
test_txt,
28+
flag,
29+
):
2230
data_abs_path = os.path.abspath(root)
2331
label_file_name = args.detLabelFileName if flag == "det" else args.recLabelFileName
2432
label_file_path = os.path.join(data_abs_path, label_file_name)
@@ -29,13 +37,15 @@ def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_pa
2937
label_record_len = len(label_file_content)
3038

3139
for index, label_record_info in enumerate(label_file_content):
32-
image_relative_path, image_label = label_record_info.split('\t')
40+
image_relative_path, image_label = label_record_info.split("\t")
3341
image_name = os.path.basename(image_relative_path)
3442

3543
if flag == "det":
3644
image_path = os.path.join(data_abs_path, image_name)
3745
elif flag == "rec":
38-
image_path = os.path.join(data_abs_path, args.recImageDirName, image_name)
46+
image_path = os.path.join(
47+
data_abs_path, args.recImageDirName, image_name
48+
)
3949

4050
train_val_test_ratio = args.trainValTestRatio.split(":")
4151
train_ratio = eval(train_val_test_ratio[0]) / 10
@@ -77,27 +87,46 @@ def genDetRecTrainVal(args):
7787
removeFile(os.path.join(args.recRootPath, "val.txt"))
7888
removeFile(os.path.join(args.recRootPath, "test.txt"))
7989

80-
detTrainTxt = open(os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8")
90+
detTrainTxt = open(
91+
os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8"
92+
)
8193
detValTxt = open(os.path.join(args.detRootPath, "val.txt"), "a", encoding="UTF-8")
8294
detTestTxt = open(os.path.join(args.detRootPath, "test.txt"), "a", encoding="UTF-8")
83-
recTrainTxt = open(os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8")
95+
recTrainTxt = open(
96+
os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8"
97+
)
8498
recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
8599
recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")
86100

87-
splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
88-
detTestTxt, "det")
101+
splitTrainVal(
102+
args.datasetRootPath,
103+
detAbsTrainRootPath,
104+
detAbsValRootPath,
105+
detAbsTestRootPath,
106+
detTrainTxt,
107+
detValTxt,
108+
detTestTxt,
109+
"det",
110+
)
89111

90112
for root, dirs, files in os.walk(args.datasetRootPath):
91113
for dir in dirs:
92-
if dir == 'crop_img':
93-
splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
94-
recTestTxt, "rec")
114+
if dir == "crop_img":
115+
splitTrainVal(
116+
root,
117+
recAbsTrainRootPath,
118+
recAbsValRootPath,
119+
recAbsTestRootPath,
120+
recTrainTxt,
121+
recValTxt,
122+
recTestTxt,
123+
"rec",
124+
)
95125
else:
96126
continue
97127
break
98128

99129

100-
101130
if __name__ == "__main__":
102131
# 功能描述:分别划分检测和识别的训练集、验证集、测试集
103132
# 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注,
@@ -107,40 +136,43 @@ def genDetRecTrainVal(args):
107136
"--trainValTestRatio",
108137
type=str,
109138
default="6:2:2",
110-
help="ratio of trainset:valset:testset")
139+
help="ratio of trainset:valset:testset",
140+
)
111141
parser.add_argument(
112142
"--datasetRootPath",
113143
type=str,
114144
default="../train_data/",
115-
help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
145+
help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3...",
116146
)
117147
parser.add_argument(
118148
"--detRootPath",
119149
type=str,
120150
default="../train_data/det",
121-
help="the path where the divided detection dataset is placed")
151+
help="the path where the divided detection dataset is placed",
152+
)
122153
parser.add_argument(
123154
"--recRootPath",
124155
type=str,
125156
default="../train_data/rec",
126-
help="the path where the divided recognition dataset is placed"
157+
help="the path where the divided recognition dataset is placed",
127158
)
128159
parser.add_argument(
129160
"--detLabelFileName",
130161
type=str,
131162
default="Label.txt",
132-
help="the name of the detection annotation file")
163+
help="the name of the detection annotation file",
164+
)
133165
parser.add_argument(
134166
"--recLabelFileName",
135167
type=str,
136168
default="rec_gt.txt",
137-
help="the name of the recognition annotation file"
169+
help="the name of the recognition annotation file",
138170
)
139171
parser.add_argument(
140172
"--recImageDirName",
141173
type=str,
142174
default="crop_img",
143-
help="the name of the folder where the cropped recognition dataset is located"
175+
help="the name of the folder where the cropped recognition dataset is located",
144176
)
145177
args = parser.parse_args()
146-
genDetRecTrainVal(args)
178+
genDetRecTrainVal(args)

libs/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version_info__ = ('1', '0', '0')
2-
__version__ = '.'.join(__version_info__)
1+
__version_info__ = ("1", "0", "0")
2+
__version__ = ".".join(__version_info__)

libs/autoDialog.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,40 +29,53 @@ def __init__(self, ocr, mImgList, mainThread, model):
2929
self.mImgList = mImgList
3030
self.mainThread = mainThread
3131
self.model = model
32-
self.setStackSize(1024*1024)
32+
self.setStackSize(1024 * 1024)
3333

3434
def run(self):
3535
try:
3636
findex = 0
3737
for Imgpath in self.mImgList:
3838
if self.handle == 0:
3939
self.listValue.emit(Imgpath)
40-
if self.model == 'paddle':
41-
h, w, _ = cv2.imdecode(np.fromfile(Imgpath, dtype=np.uint8), 1).shape
40+
if self.model == "paddle":
41+
h, w, _ = cv2.imdecode(
42+
np.fromfile(Imgpath, dtype=np.uint8), 1
43+
).shape
4244
if h > 32 and w > 32:
43-
self.result_dic = self.ocr.ocr(Imgpath, cls=True, det=True)[0]
45+
self.result_dic = self.ocr.ocr(Imgpath, cls=True, det=True)[
46+
0
47+
]
4448
else:
45-
print('The size of', Imgpath, 'is too small to be recognised')
49+
print(
50+
"The size of", Imgpath, "is too small to be recognised"
51+
)
4652
self.result_dic = None
4753

4854
# 结果保存
4955
if self.result_dic is None or len(self.result_dic) == 0:
50-
print('Can not recognise file', Imgpath)
56+
print("Can not recognise file", Imgpath)
5157
pass
5258
else:
53-
strs = ''
59+
strs = ""
5460
for res in self.result_dic:
5561
chars = res[1][0]
5662
cond = res[1][1]
5763
posi = res[0]
58-
strs += "Transcription: " + chars + " Probability: " + str(cond) + \
59-
" Location: " + json.dumps(posi) +'\n'
64+
strs += (
65+
"Transcription: "
66+
+ chars
67+
+ " Probability: "
68+
+ str(cond)
69+
+ " Location: "
70+
+ json.dumps(posi)
71+
+ "\n"
72+
)
6073
# Sending large amounts of data repeatedly through pyqtSignal may affect the program efficiency
6174
self.listValue.emit(strs)
6275
self.mainThread.result_dic = self.result_dic
6376
self.mainThread.filePath = Imgpath
6477
# 保存
65-
self.mainThread.saveFile(mode='Auto')
78+
self.mainThread.saveFile(mode="Auto")
6679
findex += 1
6780
self.progressBarValue.emit(findex)
6881
else:
@@ -75,8 +88,9 @@ def run(self):
7588

7689

7790
class AutoDialog(QDialog):
78-
79-
def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=None, lenbar=0):
91+
def __init__(
92+
self, text="Enter object label", parent=None, ocr=None, mImgList=None, lenbar=0
93+
):
8094
super(AutoDialog, self).__init__(parent)
8195
self.setFixedWidth(1000)
8296
self.parent = parent
@@ -89,13 +103,13 @@ def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=No
89103

90104
layout = QVBoxLayout()
91105
layout.addWidget(self.pb)
92-
self.model = 'paddle'
106+
self.model = "paddle"
93107
self.listWidget = QListWidget(self)
94108
layout.addWidget(self.listWidget)
95109

96110
self.buttonBox = bb = BB(BB.Ok | BB.Cancel, Qt.Horizontal, self)
97-
bb.button(BB.Ok).setIcon(newIcon('done'))
98-
bb.button(BB.Cancel).setIcon(newIcon('undo'))
111+
bb.button(BB.Ok).setIcon(newIcon("done"))
112+
bb.button(BB.Cancel).setIcon(newIcon("undo"))
99113
bb.accepted.connect(self.validate)
100114
bb.rejected.connect(self.reject)
101115
layout.addWidget(bb)
@@ -107,7 +121,7 @@ def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=No
107121

108122
# self.setWindowFlags(Qt.WindowCloseButtonHint)
109123

110-
self.thread_1 = Worker(self.ocr, self.mImgList, self.parent, 'paddle')
124+
self.thread_1 = Worker(self.ocr, self.mImgList, self.parent, "paddle")
111125
self.thread_1.progressBarValue.connect(self.handleProgressBarSingal)
112126
self.thread_1.listValue.connect(self.handleListWidgetSingal)
113127
self.thread_1.endsignal.connect(self.handleEndsignalSignal)
@@ -117,8 +131,14 @@ def handleProgressBarSingal(self, i):
117131
self.pb.setValue(i)
118132

119133
# calculate time left of auto labeling
120-
avg_time = (time.time() - self.time_start) / i # Use average time to prevent time fluctuations
121-
time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(".")[0] # Remove microseconds
134+
avg_time = (
135+
time.time() - self.time_start
136+
) / i # Use average time to prevent time fluctuations
137+
time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(
138+
"."
139+
)[
140+
0
141+
] # Remove microseconds
122142
self.setWindowTitle("PPOCRLabel -- " + f"Time Left: {time_left}") # show
123143

124144
def handleListWidgetSingal(self, i):

0 commit comments

Comments
 (0)