@@ -42,15 +42,15 @@ class EMAConfig(BaseModel):
42
42
def decay_check (cls , v ):
43
43
if v <= 0 or v >= 1 :
44
44
raise ValueError (
45
- f"'decay' should be in (0, 1) when is type of float, but got { v } "
45
+ f"'ema. decay' should be in (0, 1) when is type of float, but got { v } "
46
46
)
47
47
return v
48
48
49
49
@field_validator ("avg_freq" )
50
50
def avg_freq_check (cls , v ):
51
51
if v <= 0 :
52
52
raise ValueError (
53
- "'avg_freq' should be a positive integer when is type of int, "
53
+ "'ema. avg_freq' should be a positive integer when is type of int, "
54
54
f"but got { v } "
55
55
)
56
56
return v
@@ -63,15 +63,17 @@ class SWAConfig(BaseModel):
63
63
@field_validator ("avg_range" )
64
64
def avg_range_check (cls , v , info : ValidationInfo ):
65
65
if isinstance (v , tuple ) and v [0 ] > v [1 ]:
66
- raise ValueError (f"'avg_range' should be a valid range, but got { v } ." )
66
+ raise ValueError (
67
+ f"'swa.avg_range' should be a valid range, but got { v } ."
68
+ )
67
69
if isinstance (v , tuple ) and v [0 ] < 0 :
68
70
raise ValueError (
69
- "The start epoch of 'avg_range' should be a non-negtive integer"
71
+ "The start epoch of 'swa. avg_range' should be a non-negtive integer"
70
72
f" , but got { v [0 ]} ."
71
73
)
72
74
if isinstance (v , tuple ) and v [1 ] > info .data ["epochs" ]:
73
75
raise ValueError (
74
- "The end epoch of 'avg_range' should not be lager than "
76
+ "The end epoch of 'swa. avg_range' should not be lager than "
75
77
f"'epochs'({ info .data ['epochs' ]} ), but got { v [1 ]} ."
76
78
)
77
79
return v
@@ -80,7 +82,7 @@ def avg_range_check(cls, v, info: ValidationInfo):
80
82
def avg_freq_check (cls , v ):
81
83
if v <= 0 :
82
84
raise ValueError (
83
- "'avg_freq' should be a positive integer when is type of int, "
85
+ "'swa. avg_freq' should be a positive integer when is type of int, "
84
86
f"but got { v } "
85
87
)
86
88
return v
@@ -107,7 +109,7 @@ class TrainConfig(BaseModel):
107
109
def epochs_check (cls , v ):
108
110
if v <= 0 :
109
111
raise ValueError (
110
- "'epochs' should be a positive integer when is type of int, "
112
+ "'TRAIN. epochs' should be a positive integer when is type of int, "
111
113
f"but got { v } "
112
114
)
113
115
return v
@@ -116,7 +118,7 @@ def epochs_check(cls, v):
116
118
def iters_per_epoch_check (cls , v ):
117
119
if v <= 0 :
118
120
raise ValueError (
119
- "'iters_per_epoch' should be a positive integer when is type of int"
121
+ "'TRAIN. iters_per_epoch' should be a positive integer when is type of int"
120
122
f", but got { v } "
121
123
)
122
124
return v
@@ -125,7 +127,7 @@ def iters_per_epoch_check(cls, v):
125
127
def update_freq_check (cls , v ):
126
128
if v <= 0 :
127
129
raise ValueError (
128
- "'update_freq' should be a positive integer when is type of int"
130
+ "'TRAIN. update_freq' should be a positive integer when is type of int"
129
131
f", but got { v } "
130
132
)
131
133
return v
@@ -134,7 +136,7 @@ def update_freq_check(cls, v):
134
136
def save_freq_check (cls , v ):
135
137
if v < 0 :
136
138
raise ValueError (
137
- "'save_freq' should be a non-negtive integer when is type of int"
139
+ "'TRAIN. save_freq' should be a non-negtive integer when is type of int"
138
140
f", but got { v } "
139
141
)
140
142
return v
@@ -144,8 +146,8 @@ def start_eval_epoch_check(cls, v, info: ValidationInfo):
144
146
if info .data ["eval_during_train" ]:
145
147
if v <= 0 :
146
148
raise ValueError (
147
- f"'start_eval_epoch' should be a positive integer when "
148
- f"'eval_during_train' is True, but got { v } "
149
+ f"'TRAIN. start_eval_epoch' should be a positive integer when "
150
+ f"'TRAIN. eval_during_train' is True, but got { v } "
149
151
)
150
152
return v
151
153
@@ -154,8 +156,8 @@ def eval_freq_check(cls, v, info: ValidationInfo):
154
156
if info .data ["eval_during_train" ]:
155
157
if v <= 0 :
156
158
raise ValueError (
157
- f"'eval_freq' should be a positive integer when "
158
- f"'eval_during_train' is True, but got { v } "
159
+ f"'TRAIN. eval_freq' should be a positive integer when "
160
+ f"'TRAIN. eval_during_train' is True, but got { v } "
159
161
)
160
162
return v
161
163
@@ -176,6 +178,15 @@ class EvalConfig(BaseModel):
176
178
pretrained_model_path : Optional [str ] = None
177
179
eval_with_no_grad : bool = False
178
180
compute_metric_by_batch : bool = False
181
+ batch_size : Optional [int ] = 256
182
+
183
+ @field_validator ("batch_size" )
184
+ def batch_size_check (cls , v ):
185
+ if isinstance (v , int ) and v <= 0 :
186
+ raise ValueError (
187
+ f"'EVAL.batch_size' should be greater than 0 or None, but got { v } "
188
+ )
189
+ return v
179
190
180
191
class InferConfig (BaseModel ):
181
192
"""
@@ -203,12 +214,12 @@ class InferConfig(BaseModel):
203
214
def engine_check (cls , v , info : ValidationInfo ):
204
215
if v == "tensorrt" and info .data ["device" ] != "gpu" :
205
216
raise ValueError (
206
- "'device' should be 'gpu' when 'engine' is 'tensorrt', "
217
+ "'INFER. device' should be 'gpu' when 'INFER. engine' is 'tensorrt', "
207
218
f"but got '{ info .data ['device' ]} '"
208
219
)
209
220
if v == "mkldnn" and info .data ["device" ] != "cpu" :
210
221
raise ValueError (
211
- "'device' should be 'cpu' when 'engine' is 'mkldnn', "
222
+ "'INFER. device' should be 'cpu' when 'INFER. engine' is 'mkldnn', "
212
223
f"but got '{ info .data ['device' ]} '"
213
224
)
214
225
@@ -218,46 +229,50 @@ def engine_check(cls, v, info: ValidationInfo):
218
229
def min_subgraph_size_check (cls , v ):
219
230
if v <= 0 :
220
231
raise ValueError (
221
- "'min_subgraph_size' should be greater than 0, " f"but got { v } "
232
+ "'INFER.min_subgraph_size' should be greater than 0, "
233
+ f"but got { v } "
222
234
)
223
235
return v
224
236
225
237
@field_validator ("gpu_mem" )
226
238
def gpu_mem_check (cls , v ):
227
239
if v <= 0 :
228
- raise ValueError ("'gpu_mem' should be greater than 0, " f"but got { v } " )
240
+ raise ValueError (
241
+ "'INFER.gpu_mem' should be greater than 0, " f"but got { v } "
242
+ )
229
243
return v
230
244
231
245
@field_validator ("gpu_id" )
232
246
def gpu_id_check (cls , v ):
233
247
if v < 0 :
234
248
raise ValueError (
235
- "'gpu_id' should be greater than or equal to 0, " f"but got { v } "
249
+ "'INFER.gpu_id' should be greater than or equal to 0, "
250
+ f"but got { v } "
236
251
)
237
252
return v
238
253
239
254
@field_validator ("max_batch_size" )
240
255
def max_batch_size_check (cls , v ):
241
256
if v <= 0 :
242
257
raise ValueError (
243
- "'max_batch_size' should be greater than 0, " f"but got { v } "
258
+ "'INFER. max_batch_size' should be greater than 0, " f"but got { v } "
244
259
)
245
260
return v
246
261
247
262
@field_validator ("num_cpu_threads" )
248
263
def num_cpu_threads_check (cls , v ):
249
264
if v < 0 :
250
265
raise ValueError (
251
- "'num_cpu_threads' should be greater than or equal to 0, "
266
+ "'INFER. num_cpu_threads' should be greater than or equal to 0, "
252
267
f"but got { v } "
253
268
)
254
269
return v
255
270
256
271
@field_validator ("batch_size" )
257
272
def batch_size_check (cls , v ):
258
- if v <= 0 :
273
+ if isinstance ( v , int ) and v <= 0 :
259
274
raise ValueError (
260
- "' batch_size' should be greater than 0, " f" but got { v } "
275
+ f"'INFER. batch_size' should be greater than 0 or None, but got { v } "
261
276
)
262
277
return v
263
278
@@ -326,7 +341,8 @@ def use_wandb_check(cls, v, info: ValidationInfo):
326
341
- TRAIN/swa: swa_default <-- 'swa_default' used here
327
342
- EVAL: eval_default <-- 'eval_default' used here
328
343
- INFER: infer_default <-- 'infer_default' used here
329
- - _self_
344
+ - _self_ <-- config defined in current yaml
345
+
330
346
mode: train
331
347
seed: 42
332
348
...
@@ -384,6 +400,7 @@ def use_wandb_check(cls, v, info: ValidationInfo):
384
400
"EVAL.pretrained_model_path" ,
385
401
"EVAL.eval_with_no_grad" ,
386
402
"EVAL.compute_metric_by_batch" ,
403
+ "EVAL.batch_size" ,
387
404
"INFER.pretrained_model_path" ,
388
405
"INFER.export_path" ,
389
406
"INFER.pdmodel_path" ,
0 commit comments