Skip to content

Commit 10dd833

Browse files
committed
Fix generated type hints for function params in _binaryninjacore.py
1 parent 9cc7fe0 commit 10dd833

File tree

1 file changed

+56
-25
lines changed

1 file changed

+56
-25
lines changed

python/generator.cpp

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ map<string, string> g_pythonKeywordReplacements = {
6565
};
6666

6767

68-
void OutputType(FILE* out, Type* type, bool isReturnType = false, bool isCallback = false)
68+
void OutputType(FILE* out, Type* type, bool isReturnType = false, bool isCallback = false, bool isTypeHint = false)
6969
{
7070
switch (type->GetClass())
7171
{
@@ -129,30 +129,47 @@ void OutputType(FILE* out, Type* type, bool isReturnType = false, bool isCallbac
129129
else if ((type->GetChildType()->GetClass() == IntegerTypeClass) && (type->GetChildType()->GetWidth() == 1)
130130
&& (type->GetChildType()->IsSigned()))
131131
{
132-
if (isReturnType)
132+
if (isTypeHint)
133+
fprintf(out, "ctypes._Pointer[ctypes.c_byte]");
134+
else if (isReturnType)
133135
fprintf(out, "ctypes.POINTER(ctypes.c_byte)");
134136
else
135137
fprintf(out, "ctypes.c_char_p");
136138
break;
137139
}
138140
else if (type->GetChildType()->GetClass() == FunctionTypeClass)
139141
{
140-
fprintf(out, "ctypes.CFUNCTYPE(");
141-
OutputType(out, type->GetChildType()->GetChildType().GetValue(), true, true);
142+
if (isTypeHint)
143+
fprintf(out, "ctypes.CFUNCTYPE[");
144+
else
145+
fprintf(out, "ctypes.CFUNCTYPE(");
146+
OutputType(out, type->GetChildType()->GetChildType().GetValue(), true, true, isTypeHint);
142147
for (auto& i : type->GetChildType()->GetParameters())
143148
{
144149
fprintf(out, ", ");
145-
OutputType(out, i.type.GetValue());
150+
OutputType(out, i.type.GetValue(), false, false, isTypeHint);
146151
}
147-
fprintf(out, ")");
152+
153+
if (isTypeHint)
154+
fprintf(out, "]");
155+
else
156+
fprintf(out, ")");
148157
break;
149158
}
150-
fprintf(out, "ctypes.POINTER(");
151-
OutputType(out, type->GetChildType().GetValue());
152-
fprintf(out, ")");
159+
if (isTypeHint)
160+
fprintf(out, "ctypes._Pointer[");
161+
else
162+
fprintf(out, "ctypes.POINTER(");
163+
164+
OutputType(out, type->GetChildType().GetValue(), false, false, isTypeHint);
165+
166+
if (isTypeHint)
167+
fprintf(out, "]");
168+
else
169+
fprintf(out, ")");
153170
break;
154171
case ArrayTypeClass:
155-
OutputType(out, type->GetChildType().GetValue());
172+
OutputType(out, type->GetChildType().GetValue(), false, false, isTypeHint);
156173
fprintf(out, " * %" PRId64, type->GetElementCount());
157174
break;
158175
default:
@@ -162,7 +179,7 @@ void OutputType(FILE* out, Type* type, bool isReturnType = false, bool isCallbac
162179
}
163180

164181

165-
void OutputSwizzledType(FILE* out, Type* type)
182+
void OutputSwizzledType(FILE* out, Type* type, bool isTypeHint = false)
166183
{
167184
switch (type->GetClass())
168185
{
@@ -202,22 +219,35 @@ void OutputSwizzledType(FILE* out, Type* type)
202219
}
203220
else if (type->GetChildType()->GetClass() == FunctionTypeClass)
204221
{
205-
fprintf(out, "ctypes.CFUNCTYPE(");
206-
OutputType(out, type->GetChildType()->GetChildType().GetValue(), true, true);
222+
if (isTypeHint)
223+
fprintf(out, "ctypes.CFUNCTYPE[");
224+
else
225+
fprintf(out, "ctypes.CFUNCTYPE(");
226+
OutputType(out, type->GetChildType()->GetChildType().GetValue(), true, true, isTypeHint);
207227
for (auto& i : type->GetChildType()->GetParameters())
208228
{
209229
fprintf(out, ", ");
210-
OutputType(out, i.type.GetValue());
230+
OutputType(out, i.type.GetValue(), false, false, isTypeHint);
211231
}
212-
fprintf(out, ")");
232+
if (isTypeHint)
233+
fprintf(out, "]");
234+
else
235+
fprintf(out, ")");
236+
213237
break;
214238
}
215-
fprintf(out, "ctypes.POINTER(");
216-
OutputType(out, type->GetChildType().GetValue());
217-
fprintf(out, ")");
239+
if (isTypeHint)
240+
fprintf(out, "ctypes._Pointer[");
241+
else
242+
fprintf(out, "ctypes.POINTER(");
243+
OutputType(out, type->GetChildType().GetValue(), false, false, isTypeHint);
244+
if (isTypeHint)
245+
fprintf(out, "]");
246+
else
247+
fprintf(out, ")");
218248
break;
219249
case ArrayTypeClass:
220-
OutputType(out, type->GetChildType().GetValue());
250+
OutputType(out, type->GetChildType().GetValue(), false, false, isTypeHint);
221251
fprintf(out, " * %" PRId64, type->GetElementCount());
222252
break;
223253
default:
@@ -522,20 +552,21 @@ int main(int argc, char* argv[])
522552
if (argN > 0)
523553
fprintf(out, ", ");
524554
fprintf(out, "\n\t\t");
525-
fprintf(out, "%s: ", argName.c_str());
555+
fprintf(out, "%s: '", argName.c_str());
526556
if (swizzleArgs)
527-
OutputSwizzledType(out, arg.type.GetValue());
557+
OutputSwizzledType(out, arg.type.GetValue(), true);
528558
else
529-
OutputType(out, arg.type.GetValue());
559+
OutputType(out, arg.type.GetValue(), false, false, true);
560+
fprintf(out, "'");
530561
argN++;
531562
}
532563
}
533564
fprintf(out, "\n\t\t) -> ");
534565
if (stringResult || pointerResult)
535-
fprintf(out, "Optional[");
536-
OutputSwizzledType(out, i.second->GetChildType().GetValue());
566+
fprintf(out, "Optional['");
567+
OutputSwizzledType(out, i.second->GetChildType().GetValue(), true);
537568
if (stringResult || pointerResult)
538-
fprintf(out, "]");
569+
fprintf(out, "']");
539570
fprintf(out, ":\n");
540571

541572
string stringArgFuncCall = funcName + "(";

0 commit comments

Comments
 (0)