Skip to content

Commit 8b6ea10

Browse files
committed
Refactor
1 parent 7fc3d53 commit 8b6ea10

File tree

5 files changed

+170
-59
lines changed

5 files changed

+170
-59
lines changed

flask_inputfilter/_input_filter.pyx

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ from libcpp.string cimport string
2929

3030
from libcpp.algorithm cimport find
3131

32+
cdef extern from "helper.h":
33+
vector[string] make_default_methods()
34+
35+
cdef cppclass StringConstants:
36+
@staticmethod
37+
const char* get(const string& key) nogil
38+
@staticmethod
39+
bint has(const string& key) nogil
40+
3241
cdef dict _INTERNED_STRINGS = {
3342
"_condition": sys.intern("_condition"),
3443
"_error": sys.intern("_error"),
@@ -47,9 +56,6 @@ cdef dict _INTERNED_STRINGS = {
4756
"validators": sys.intern("validators"),
4857
}
4958

50-
cdef extern from "helper.h":
51-
vector[string] make_default_methods()
52-
5359
T = TypeVar("T")
5460

5561

@@ -403,16 +409,13 @@ cdef class InputFilter:
403409
return {}
404410

405411
cdef:
406-
Py_ssize_t i, n = len(self.fields)
407-
dict result
412+
Py_ssize_t i
413+
dict result = {}
408414
list field_names = list(self.fields.keys())
409415
str field
410416
object field_value
411417

412-
# Pre-allocate dictionary size for better performance
413-
result = {}
414-
415-
for i in range(n):
418+
for i in range(len(self.fields)):
416419
field = field_names[i]
417420
field_value = self.data.get(field)
418421
if field_value is not None:

flask_inputfilter/include/helper.h

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,102 @@ inline std::vector<std::string> make_default_methods() {
1818
};
1919
}
2020

21+
class HttpMethodSet {
22+
private:
23+
std::unordered_set<std::string> methods;
24+
25+
public:
26+
HttpMethodSet() = default;
27+
28+
HttpMethodSet(const HttpMethodSet& other) : methods(other.methods) {}
29+
30+
HttpMethodSet(HttpMethodSet&& other) noexcept : methods(std::move(other.methods)) {}
31+
32+
HttpMethodSet& operator=(const HttpMethodSet& other) {
33+
if (this != &other) {
34+
methods = other.methods;
35+
}
36+
return *this;
37+
}
38+
39+
HttpMethodSet& operator=(HttpMethodSet&& other) noexcept {
40+
if (this != &other) {
41+
methods = std::move(other.methods);
42+
}
43+
return *this;
44+
}
45+
46+
explicit HttpMethodSet(const std::vector<std::string>& method_list) {
47+
methods.reserve(method_list.size());
48+
for (const auto& method : method_list) {
49+
methods.insert(method);
50+
}
51+
}
52+
53+
bool contains(const std::string& method) const {
54+
return methods.find(method) != methods.end();
55+
}
56+
57+
void add(const std::string& method) {
58+
methods.insert(method);
59+
}
60+
61+
void clear() {
62+
methods.clear();
63+
}
64+
65+
size_t size() const {
66+
return methods.size();
67+
}
68+
69+
void from_vector(const std::vector<std::string>& method_list) {
70+
clear();
71+
methods.reserve(method_list.size());
72+
for (const auto& method : method_list) {
73+
methods.insert(method);
74+
}
75+
}
76+
};
77+
78+
class StringConstants {
79+
private:
80+
std::unordered_map<std::string, const char*> constants;
81+
static StringConstants& get_instance() {
82+
static StringConstants instance;
83+
return instance;
84+
}
85+
86+
StringConstants() {
87+
constants["_condition"] = "_condition";
88+
constants["_error"] = "_error";
89+
constants["copy"] = "copy";
90+
constants["default"] = "default";
91+
constants["DELETE"] = "DELETE";
92+
constants["external_api"] = "external_api";
93+
constants["fallback"] = "fallback";
94+
constants["filters"] = "filters";
95+
constants["GET"] = "GET";
96+
constants["PATCH"] = "PATCH";
97+
constants["POST"] = "POST";
98+
constants["PUT"] = "PUT";
99+
constants["required"] = "required";
100+
constants["steps"] = "steps";
101+
constants["validators"] = "validators";
102+
}
103+
104+
public:
105+
static const char* get(const std::string& key) {
106+
auto& instance = get_instance();
107+
auto it = instance.constants.find(key);
108+
return (it != instance.constants.end()) ? it->second : nullptr;
109+
}
110+
111+
static bool has(const std::string& key) {
112+
auto& instance = get_instance();
113+
return instance.constants.find(key) != instance.constants.end();
114+
}
115+
};
116+
21117
namespace string_ops {
22118
inline bool fast_startswith(const std::string& str, const std::string& prefix) {
23119
return str.size() >= prefix.size() &&

flask_inputfilter/mixins/data_mixin/_data_mixin.pyx

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ from flask_inputfilter.exceptions import ValidationError
1010
from flask_inputfilter.mixins.cimports cimport ValidationMixin
1111
from flask_inputfilter.models.cimports cimport BaseFilter, BaseValidator, FieldModel, BaseCondition, InputFilter
1212

13+
from libcpp.vector cimport vector
14+
from libcpp.string cimport string
15+
from libcpp.algorithm cimport find
16+
1317
DEF LARGE_DATASET_THRESHOLD = 100
1418

1519

@@ -36,19 +40,9 @@ cdef class DataMixin:
3640
if not data and fields:
3741
return True
3842

39-
cdef set field_set
40-
41-
# Use set operations for faster lookup when there are many fields
42-
if len(fields) > LARGE_DATASET_THRESHOLD:
43-
field_set = set(fields.keys())
44-
for field_name in data.keys():
45-
if field_name not in field_set:
46-
return True
47-
else:
48-
# Use direct dict lookup for smaller field counts
49-
for field_name in data.keys():
50-
if field_name not in fields:
51-
return True
43+
for field_name in data.keys():
44+
if field_name not in fields:
45+
return True
5246

5347
return False
5448

@@ -73,13 +67,19 @@ cdef class DataMixin:
7367
"""
7468
cdef:
7569
dict[str, Any] filtered_data = {}
76-
Py_ssize_t i, n = len(data) if data else 0
77-
list keys = list(data.keys()) if n > 0 else []
78-
list values = list(data.values()) if n > 0 else []
70+
Py_ssize_t i
71+
list keys
72+
list values
7973
str field_name
8074
object field_value
8175

82-
for i in range(n):
76+
if not data:
77+
return filtered_data
78+
79+
keys = list(data.keys())
80+
values = list(data.values())
81+
82+
for i in range(len(data)):
8383
field_name = keys[i]
8484
field_value = values[i]
8585

@@ -153,6 +153,8 @@ cdef class DataMixin:
153153
dict source_inputs = source_filter.get_inputs()
154154
list keys = list(source_inputs.keys()) if source_inputs else []
155155
list new_fields = list(source_inputs.values()) if source_inputs else []
156+
str method
157+
bytes encoded_method
156158

157159
n = len(keys)
158160
for i in range(n):
@@ -161,12 +163,12 @@ cdef class DataMixin:
161163
target_filter.conditions.extend(source_filter.conditions)
162164

163165
DataMixin._merge_component_list(
164-
target_filter.global_filters,
166+
target_filter.global_filters,
165167
source_filter.global_filters
166168
)
167169

168170
DataMixin._merge_component_list(
169-
target_filter.global_validators,
171+
target_filter.global_validators,
170172
source_filter.global_validators
171173
)
172174

flask_inputfilter/mixins/validation_mixin/_validation_mixin.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ cdef class ValidationMixin:
1818
cdef void check_conditions(list[BaseCondition] conditions, dict[str, Any] validated_data) except *
1919

2020
@staticmethod
21-
cdef object check_for_required(str field_name, FieldModel field_info, object value)
21+
cdef inline object check_for_required(str field_name, FieldModel field_info, object value)
2222

2323
@staticmethod
2424
cdef tuple validate_fields(
@@ -29,4 +29,4 @@ cdef class ValidationMixin:
2929
)
3030

3131
@staticmethod
32-
cdef object get_field_value(str field_name, FieldModel field_info, dict[str, Any] data, dict[str, Any] validated_data)
32+
cdef inline object get_field_value(str field_name, FieldModel field_info, dict[str, Any] data, dict[str, Any] validated_data)

flask_inputfilter/mixins/validation_mixin/_validation_mixin.pyx

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,18 @@ cdef class ValidationMixin:
4242
return None
4343

4444
cdef:
45-
Py_ssize_t i, n
45+
Py_ssize_t i
4646
BaseFilter current_filter
4747

48-
n = len(filters1) if filters1 else 0
49-
for i in range(n):
50-
current_filter = filters1[i]
51-
value = current_filter.apply(value)
48+
if filters1:
49+
for i in range(len(filters1)):
50+
current_filter = filters1[i]
51+
value = current_filter.apply(value)
5252

53-
n = len(filters2) if filters2 else 0
54-
for i in range(n):
55-
current_filter = filters2[i]
56-
value = current_filter.apply(value)
53+
if filters2:
54+
for i in range(len(filters2)):
55+
current_filter = filters2[i]
56+
value = current_filter.apply(value)
5757

5858
return value
5959

@@ -95,12 +95,15 @@ cdef class ValidationMixin:
9595
if value is None:
9696
return None
9797

98+
if not steps:
99+
return value
100+
98101
cdef:
99-
Py_ssize_t i, n = len(steps) if steps else 0
102+
Py_ssize_t i
100103
object current_step
101104

102105
try:
103-
for i in range(n):
106+
for i in range(len(steps)):
104107
current_step = steps[i]
105108
if isinstance(current_step, BaseFilter):
106109
value = current_step.apply(value)
@@ -132,11 +135,14 @@ cdef class ValidationMixin:
132135
- **validated_data** (*dict[str, Any]*):
133136
The validated data to check against the conditions.
134137
"""
138+
if not conditions:
139+
return
140+
135141
cdef:
136-
Py_ssize_t i, n = len(conditions) if conditions else 0
142+
Py_ssize_t i
137143
object current_condition
138144

139-
for i in range(n):
145+
for i in range(len(conditions)):
140146
current_condition = conditions[i]
141147
if not current_condition.check(validated_data):
142148
raise ValidationError(
@@ -215,19 +221,19 @@ cdef class ValidationMixin:
215221
return None
216222

217223
cdef:
218-
Py_ssize_t i, n
224+
Py_ssize_t i
219225
BaseValidator current_validator
220226

221227
try:
222-
n = len(validators1) if validators1 else 0
223-
for i in range(n):
224-
current_validator = validators1[i]
225-
current_validator.validate(value)
226-
227-
n = len(validators2) if validators2 else 0
228-
for i in range(n):
229-
current_validator = validators2[i]
230-
current_validator.validate(value)
228+
if validators1:
229+
for i in range(len(validators1)):
230+
current_validator = validators1[i]
231+
current_validator.validate(value)
232+
233+
if validators2:
234+
for i in range(len(validators2)):
235+
current_validator = validators2[i]
236+
current_validator.validate(value)
231237
except ValidationError:
232238
if fallback is None:
233239
raise
@@ -269,16 +275,20 @@ cdef class ValidationMixin:
269275
cdef:
270276
dict[str, Any] validated_data = {}
271277
dict[str, str] errors = {}
272-
Py_ssize_t i, n = len(fields) if fields else 0
273-
274-
cdef:
275-
list field_names = list(fields.keys()) if n > 0 else []
276-
list field_infos = list(fields.values()) if n > 0 else []
278+
Py_ssize_t i
279+
list field_names
280+
list field_infos
277281
str field_name
278282
FieldModel field_info
279283
object value
280284

281-
for i in range(n):
285+
if not fields:
286+
return validated_data, errors
287+
288+
field_names = list(fields.keys())
289+
field_infos = list(fields.values())
290+
291+
for i in range(len(fields)):
282292
field_name = field_names[i]
283293
field_info = field_infos[i]
284294

0 commit comments

Comments
 (0)