Skip to content

Commit e1a81e2

Browse files
authored
fix: register keyGet* function for evaluator (#225)
1 parent c93da99 commit e1a81e2

File tree

5 files changed

+310
-148
lines changed

5 files changed

+310
-148
lines changed

casbin/enforcer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ bool Enforcer::m_enforce(const std::string& matcher, std::vector<std::string>& e
140140
// set to no-match at first
141141
matcher_results[policy_index] = 0;
142142
if (evalator->CheckType() == Type::Bool) {
143-
bool result = evalator->GetBoolen();
143+
bool result = evalator->GetBoolean();
144144
if (result) {
145145
matcher_results[policy_index] = 1;
146146
}
@@ -200,7 +200,7 @@ bool Enforcer::m_enforce(const std::string& matcher, std::vector<std::string>& e
200200
if (!isvalid) {
201201
return false;
202202
}
203-
bool result = evalator->GetBoolen();
203+
bool result = evalator->GetBoolean();
204204

205205
if (result) {
206206
policy_effects[0] = Effect::Allow;

casbin/model/evaluator.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121

2222
namespace casbin {
2323
bool ExprtkEvaluator::Eval(const std::string& expression_string) {
24+
expression.register_symbol_table(symbol_table);
25+
if (enable_get) {
26+
expression.register_symbol_table(glbl_variable_symbol_table);
27+
}
28+
2429
if (this->expression_string_ != expression_string) {
2530
this->expression_string_ = expression_string;
2631
// replace (&& -> and), (|| -> or)
@@ -39,6 +44,15 @@ void ExprtkEvaluator::InitialObject(const std::string& identifier) {
3944
// symbol_table.add_stringvar("");
4045
}
4146

47+
void ExprtkEvaluator::EnableGet(const std::string& identifier) {
48+
enable_get = true;
49+
if (identifier.empty()) {
50+
glbl_variable_symbol_table.add_stringvar("key_get_result", key_get_result);
51+
} else {
52+
glbl_variable_symbol_table.add_stringvar(identifier, key_get_result);
53+
}
54+
}
55+
4256
void ExprtkEvaluator::PushObjectString(const std::string& target, const std::string& proprity, const std::string& var) {
4357
auto identifier = target + "." + proprity;
4458

@@ -57,6 +71,9 @@ void ExprtkEvaluator::LoadFunctions() {
5771
AddFunction("keyMatch4", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyMatch4, 2));
5872
AddFunction("regexMatch", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::RegexMatch, 2));
5973
AddFunction("ipMatch", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::IpMatch, 2));
74+
AddFunction("keyGet", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyGet, 2));
75+
AddFunction("keyGet2", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyGet2, 3));
76+
AddFunction("keyGet3", ExprtkFunctionFactory::GetExprtkFunction(ExprtkFunctionType::KeyGet3, 3));
6077
}
6178

6279
void ExprtkEvaluator::LoadGFunction(std::shared_ptr<RoleManager> rm, const std::string& name, int narg) {
@@ -78,20 +95,26 @@ Type ExprtkEvaluator::CheckType() {
7895
}
7996
}
8097

81-
bool ExprtkEvaluator::GetBoolen() {
98+
bool ExprtkEvaluator::GetBoolean() {
8299
return bool(this->expression);
83100
}
84101

85102
float ExprtkEvaluator::GetFloat() {
86103
return expression.value();
87104
}
88105

106+
std::string ExprtkEvaluator::GetString() {
107+
const numerical_type result = expression.value();
108+
return key_get_result;
109+
}
110+
89111
void ExprtkEvaluator::Clean(AssertionMap& section, bool after_enforce) {
90-
if (after_enforce == false) {
112+
if (!after_enforce) {
91113
return;
92114
}
93115

94116
this->symbol_table.clear();
117+
this->glbl_variable_symbol_table.clear();
95118
this->expression_string_ = "";
96119
this->Functions.clear();
97120
this->identifiers_.clear();

include/casbin/model/evaluator.h

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,45 +49,57 @@ class IEvaluator {
4949

5050
virtual Type CheckType() = 0;
5151

52-
virtual bool GetBoolen() = 0;
52+
virtual bool GetBoolean() = 0;
5353

5454
virtual float GetFloat() = 0;
5555

56+
virtual std::string GetString() = 0;
57+
5658
virtual void Clean(AssertionMap& section, bool after_enforce = true) = 0;
5759
};
5860

5961
class ExprtkEvaluator : public IEvaluator {
6062
private:
6163
std::string expression_string_;
64+
std::string key_get_result;
6265
symbol_table_t symbol_table;
66+
symbol_table_t glbl_variable_symbol_table;
67+
bool enable_get{false};
6368
expression_t expression;
6469
parser_t parser;
6570
std::vector<std::shared_ptr<exprtk_func_t>> Functions;
6671
std::unordered_map<std::string, std::unique_ptr<std::string>> identifiers_;
6772

6873
public:
69-
ExprtkEvaluator() { this->expression.register_symbol_table(this->symbol_table); };
70-
bool Eval(const std::string& expression);
74+
ExprtkEvaluator() {
75+
this->symbol_table.add_constants();
76+
this->expression.register_symbol_table(this->symbol_table);
77+
};
78+
bool Eval(const std::string& expression) override;
79+
80+
void InitialObject(const std::string& target) override;
81+
82+
void EnableGet(const std::string& identifier);
7183

72-
void InitialObject(const std::string& target);
84+
void PushObjectString(const std::string& target, const std::string& proprity, const std::string& var) override;
7385

74-
void PushObjectString(const std::string& target, const std::string& proprity, const std::string& var);
86+
void PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var) override;
7587

76-
void PushObjectJson(const std::string& target, const std::string& proprity, const nlohmann::json& var);
88+
void LoadFunctions() override;
7789

78-
void LoadFunctions();
90+
void LoadGFunction(std::shared_ptr<RoleManager> rm, const std::string& name, int narg) override;
7991

80-
void LoadGFunction(std::shared_ptr<RoleManager> rm, const std::string& name, int narg);
92+
void ProcessFunctions(const std::string& expression) override;
8193

82-
void ProcessFunctions(const std::string& expression);
94+
Type CheckType() override;
8395

84-
Type CheckType();
96+
bool GetBoolean() override;
8597

86-
bool GetBoolen();
98+
float GetFloat() override;
8799

88-
float GetFloat();
100+
std::string GetString();
89101

90-
void Clean(AssertionMap& section, bool after_enforce = true);
102+
void Clean(AssertionMap& section, bool after_enforce = true) override;
91103

92104
void PrintSymbol();
93105

include/casbin/model/exprtk_config.h

Lines changed: 105 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ struct ExprtkGFunction : public exprtk::igeneric_function<numerical_type> {
9494
}
9595
};
9696

97-
struct ExprtkOtherFunction : public exprtk::igeneric_function<numerical_type> {
97+
struct ExprtkMatchFunction : public exprtk::igeneric_function<numerical_type> {
9898
typedef typename exprtk::igeneric_function<numerical_type>::generic_type generic_type;
9999

100100
typedef typename generic_type::scalar_view scalar_t;
@@ -107,9 +107,9 @@ struct ExprtkOtherFunction : public exprtk::igeneric_function<numerical_type> {
107107
casbin::MatchingFunc func_;
108108

109109
public:
110-
ExprtkOtherFunction(const std::string& idenfier, casbin::MatchingFunc func) : exprtk::igeneric_function<numerical_type>(idenfier), func_(func) {}
110+
ExprtkMatchFunction(const std::string& idenfier, casbin::MatchingFunc func) : exprtk::igeneric_function<numerical_type>(idenfier), func_(func) {}
111111

112-
ExprtkOtherFunction() : exprtk::igeneric_function<numerical_type>("ss") {}
112+
ExprtkMatchFunction() : exprtk::igeneric_function<numerical_type>("ss") {}
113113

114114
inline numerical_type operator()(parameter_list_t parameters) {
115115
bool res = false;
@@ -143,6 +143,90 @@ struct ExprtkOtherFunction : public exprtk::igeneric_function<numerical_type> {
143143
}
144144
};
145145

146+
// KeyGet
147+
struct ExprtkGetFunction : public exprtk::igeneric_function<numerical_type> {
148+
typedef exprtk::igeneric_function<numerical_type> igenfunct_t;
149+
typedef typename igenfunct_t::generic_type generic_t;
150+
typedef typename igenfunct_t::parameter_list_t parameter_list_t;
151+
typedef typename generic_t::string_view string_t;
152+
153+
private:
154+
using MatchingFunc = std::function<std::string(const std::string&, const std::string&)>;
155+
MatchingFunc func_;
156+
157+
public:
158+
ExprtkGetFunction(const std::string& idenfier, MatchingFunc func) : igenfunct_t(idenfier, igenfunct_t::e_rtrn_string), func_(func) {}
159+
160+
ExprtkGetFunction() : igenfunct_t("SS", igenfunct_t::e_rtrn_string) {}
161+
162+
inline numerical_type operator()(std::string& result, parameter_list_t parameters) {
163+
result.clear();
164+
165+
// check value cnt
166+
if (parameters.size() != 2) {
167+
return numerical_type(0);
168+
}
169+
170+
// check value type
171+
for (std::size_t i = 0; i < parameters.size(); ++i) {
172+
generic_type& gt = parameters[i];
173+
if (generic_type::e_string != gt.type) {
174+
return numerical_type(0);
175+
}
176+
}
177+
std::string key1 = exprtk::to_str(string_t(parameters[0]));
178+
std::string key2 = exprtk::to_str(string_t(parameters[1]));
179+
180+
if (this->func_ != nullptr) {
181+
result = this->func_(key1, key2);
182+
}
183+
184+
return numerical_type(0);
185+
}
186+
};
187+
188+
struct ExprtkGetWithPathFunction : public exprtk::igeneric_function<numerical_type> {
189+
typedef exprtk::igeneric_function<numerical_type> igenfunct_t;
190+
typedef typename igenfunct_t::generic_type generic_t;
191+
typedef typename igenfunct_t::parameter_list_t parameter_list_t;
192+
typedef typename generic_t::string_view string_t;
193+
194+
private:
195+
using MatchingFunc = std::function<std::string(const std::string&, const std::string&, const std::string&)>;
196+
MatchingFunc func_;
197+
198+
public:
199+
ExprtkGetWithPathFunction(const std::string& idenfier, MatchingFunc func) : igenfunct_t(idenfier, igenfunct_t::e_rtrn_string), func_(func) {}
200+
201+
ExprtkGetWithPathFunction() : igenfunct_t("SSS", igenfunct_t::e_rtrn_string) {}
202+
203+
inline numerical_type operator()(std::string& result, parameter_list_t parameters) {
204+
result.clear();
205+
206+
// check value cnt
207+
if (parameters.size() != 3) {
208+
return numerical_type(0);
209+
}
210+
211+
// check value type
212+
for (std::size_t i = 0; i < parameters.size(); ++i) {
213+
generic_type& gt = parameters[i];
214+
if (generic_type::e_string != gt.type) {
215+
return numerical_type(0);
216+
}
217+
}
218+
std::string key1 = exprtk::to_str(string_t(parameters[0]));
219+
std::string key2 = exprtk::to_str(string_t(parameters[1]));
220+
std::string path_var = exprtk::to_str(string_t(parameters[2]));
221+
222+
if (this->func_ != nullptr) {
223+
result = this->func_(key1, key2, path_var);
224+
}
225+
226+
return numerical_type(0);
227+
}
228+
};
229+
146230
enum class ExprtkFunctionType {
147231
Unknown,
148232
Gfunction,
@@ -152,6 +236,9 @@ enum class ExprtkFunctionType {
152236
KeyMatch4,
153237
RegexMatch,
154238
IpMatch,
239+
KeyGet,
240+
KeyGet2,
241+
KeyGet3,
155242
};
156243

157244
class ExprtkFunctionFactory {
@@ -164,22 +251,31 @@ class ExprtkFunctionFactory {
164251
func = std::make_shared<ExprtkGFunction>(idenfier, rm);
165252
break;
166253
case ExprtkFunctionType::KeyMatch:
167-
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch));
254+
func.reset(new ExprtkMatchFunction(idenfier, KeyMatch));
168255
break;
169256
case ExprtkFunctionType::KeyMatch2:
170-
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch2));
257+
func.reset(new ExprtkMatchFunction(idenfier, KeyMatch2));
171258
break;
172259
case ExprtkFunctionType::KeyMatch3:
173-
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch3));
260+
func.reset(new ExprtkMatchFunction(idenfier, KeyMatch3));
174261
break;
175262
case ExprtkFunctionType::KeyMatch4:
176-
func.reset(new ExprtkOtherFunction(idenfier, KeyMatch4));
263+
func.reset(new ExprtkMatchFunction(idenfier, KeyMatch4));
177264
break;
178265
case ExprtkFunctionType::IpMatch:
179-
func.reset(new ExprtkOtherFunction(idenfier, IPMatch));
266+
func.reset(new ExprtkMatchFunction(idenfier, IPMatch));
180267
break;
181268
case ExprtkFunctionType::RegexMatch:
182-
func.reset(new ExprtkOtherFunction(idenfier, RegexMatch));
269+
func.reset(new ExprtkMatchFunction(idenfier, RegexMatch));
270+
break;
271+
case ExprtkFunctionType::KeyGet:
272+
func.reset(new ExprtkGetFunction(idenfier, KeyGet));
273+
break;
274+
case ExprtkFunctionType::KeyGet2:
275+
func.reset(new ExprtkGetWithPathFunction(idenfier, KeyGet2));
276+
break;
277+
case ExprtkFunctionType::KeyGet3:
278+
func.reset(new ExprtkGetWithPathFunction(idenfier, KeyGet3));
183279
break;
184280
default:
185281
func = nullptr;

0 commit comments

Comments
 (0)