Skip to content

Commit c2c6f0d

Browse files
committed
fix various issues with unitary enums and return type conversions in opaque structs
1 parent 137d75b commit c2c6f0d

12 files changed

+218
-82
lines changed

src/binding_types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
class TypeInfo:
22
def __init__(self, is_native_primitive, rust_obj, swift_type, c_ty, is_const, passed_as_ptr, is_ptr, var_name,
3-
arr_len, arr_access, subty=None):
3+
arr_len, arr_access, subty=None, swift_raw_type=None):
44
self.is_native_primitive = is_native_primitive
55
self.rust_obj = rust_obj
66
# self.java_ty = java_ty
77
# self.java_hu_ty = java_hu_ty
88
# self.java_fn_ty_arg = java_fn_ty_arg
99
self.swift_type = swift_type
10+
self.swift_raw_type = swift_raw_type if swift_raw_type is not None else swift_type
1011
self.c_ty = c_ty
1112
self.is_const = is_const
1213
self.passed_as_ptr = passed_as_ptr

src/generators/opaque_struct_generator.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self) -> None:
1313
template = template_handle.read()
1414
self.template = template
1515

16-
def generate_opaque_struct(self, struct_name, struct_details):
16+
def generate_opaque_struct(self, struct_name, struct_details, all_type_details = {}):
1717
# method_names = ['openChannel', 'closeChannel']
1818
# native_method_names = ['ChannelHandler_openChannel', 'ChannelHandler_closeChannel']
1919

@@ -28,15 +28,26 @@ def generate_opaque_struct(self, struct_name, struct_details):
2828
constructor_native_name = constructor_details['name']['native']
2929
swift_arguments = []
3030
native_arguments = []
31+
constructor_argument_prep = ''
3132
for current_argument_details in constructor_details['argument_types']:
3233
argument_name = current_argument_details.var_name
3334
passed_argument_name = argument_name
35+
constructor_argument_conversion_method = None
36+
37+
if current_argument_details.rust_obj is not None and current_argument_details.rust_obj.startswith('LDK') and current_argument_details.swift_type.startswith('['):
38+
constructor_argument_conversion_method = f'let converted_{argument_name} = Bindings.new_{current_argument_details.rust_obj}(array: {argument_name})'
39+
constructor_argument_prep += constructor_argument_conversion_method
40+
passed_argument_name = f'converted_{argument_name}'
41+
elif current_argument_details.rust_obj == 'LDK'+current_argument_details.swift_type:
42+
passed_argument_name += '.cOpaqueStruct!'
3443

3544
swift_arguments.append(f'{argument_name}: {current_argument_details.swift_type}')
3645
native_arguments.append(f'{passed_argument_name}')
3746

3847
mutating_output_file_contents = mutating_output_file_contents.replace('swift_constructor_arguments',
3948
', '.join(swift_arguments))
49+
mutating_output_file_contents = mutating_output_file_contents.replace('/* NATIVE_CONSTRUCTOR_PREP */',
50+
constructor_argument_prep)
4051
mutating_output_file_contents = mutating_output_file_contents.replace('native_constructor_arguments',
4152
', '.join(native_arguments))
4253
mutating_output_file_contents = mutating_output_file_contents.replace(
@@ -59,9 +70,23 @@ def generate_opaque_struct(self, struct_name, struct_details):
5970
current_native_method_name = current_method_details['name']['native']
6071
current_method_name = current_method_details['name']['swift']
6172
current_return_type = current_method_details['return_type'].swift_type
62-
# current_method_name = current_native_method_name[len(method_prefix):]
73+
# current_rust_return_type = current_method_details['return_type'].rust_obj
74+
75+
# if current_rust_return_type in all_type_details and all_type_details[current_rust_return_type].type.name == 'UNITARY_ENUM':
76+
# current_return_type = current_rust_return_type
77+
current_method_name = current_native_method_name[len(method_prefix):]
6378

6479
current_replacement = method_template
80+
81+
if current_method_details['return_type'].rust_obj is not None and current_method_details['return_type'].rust_obj.startswith('LDK') and current_method_details['return_type'].swift_type.startswith('['):
82+
return_type_wrapper_prefix = f'Bindings.{current_method_details["return_type"].rust_obj}_to_array(byteType: '
83+
return_type_wrapper_suffix = ')'
84+
current_replacement = current_replacement.replace('return OpaqueStructType_methodName(native_arguments)', f'return {return_type_wrapper_prefix}OpaqueStructType_methodName(native_arguments){return_type_wrapper_suffix}')
85+
elif current_method_details['return_type'].rust_obj == 'LDK' + current_method_details['return_type'].swift_type:
86+
return_type_wrapper_prefix = f'{current_method_details["return_type"].swift_type}(pointer: '
87+
return_type_wrapper_suffix = ')'
88+
current_replacement = current_replacement.replace('return OpaqueStructType_methodName(native_arguments)', f'return {return_type_wrapper_prefix}OpaqueStructType_methodName(native_arguments){return_type_wrapper_suffix}')
89+
6590
current_replacement = current_replacement.replace('func methodName(', f'func {current_method_name}(')
6691

6792
is_clone_method = current_method_details['is_clone']
@@ -106,7 +131,12 @@ def generate_opaque_struct(self, struct_name, struct_details):
106131

107132
if not pass_instance:
108133
swift_arguments.append(f'{argument_name}: {current_argument_details.swift_type}')
109-
native_arguments.append(f'{passed_argument_name}')
134+
135+
# native_arguments.append(f'{passed_argument_name}')
136+
if current_argument_details.rust_obj == 'LDK' + current_argument_details.swift_type and not current_argument_details.is_ptr:
137+
native_arguments.append(f'{passed_argument_name}.cOpaqueStruct!')
138+
else:
139+
native_arguments.append(f'{passed_argument_name}')
110140

111141
current_replacement = current_replacement.replace('swift_arguments', ', '.join(swift_arguments))
112142
if is_clone_method:

src/generators/trait_generator.py

Lines changed: 61 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,70 +19,77 @@ def generate_trait(self, struct_name, struct_details):
1919

2020
swift_struct_name = struct_name[3:]
2121

22-
method_template_regex = re.compile(
23-
"(\/\* STRUCT_METHODS_START \*\/\n)(.*)(\n[\t ]*\/\* STRUCT_METHODS_END \*\/)",
22+
native_callback_template_regex = re.compile(
23+
"(\/\* NATIVE_CALLBACKS_START \*\/\n)(.*)(\n[\t ]*\/\* NATIVE_CALLBACKS_END \*\/)",
2424
flags=re.MULTILINE | re.DOTALL)
25-
method_template = method_template_regex.search(self.template).group(2)
25+
native_callback_template = native_callback_template_regex.search(self.template).group(2)
26+
27+
swift_callback_template_regex = re.compile(
28+
"(\/\* SWIFT_CALLBACKS_START \*\/\n)(.*)(\n[\t ]*\/\* SWIFT_CALLBACKS_END \*\/)",
29+
flags=re.MULTILINE | re.DOTALL)
30+
swift_callback_template = swift_callback_template_regex.search(self.template).group(2)
2631

2732
method_prefix = swift_struct_name + '_'
28-
struct_methods = ''
33+
native_callbacks = ''
34+
swift_callbacks = ''
35+
36+
instantiation_arguments = []
2937

3038
# fill templates
31-
for current_method_details in struct_details.methods:
32-
current_native_method_name = current_method_details['name']['native']
33-
current_method_name = current_method_details['name']['swift']
34-
# current_method_name = current_native_method_name[len(method_prefix):]
35-
36-
current_replacement = method_template
37-
current_replacement = current_replacement.replace('func methodName(', f'func {current_method_name}(')
38-
current_replacement = current_replacement.replace('OpaqueStructType_methodName(',
39-
f'{current_native_method_name}(')
40-
41-
# replace arguments
42-
swift_arguments = []
43-
native_arguments = ['self.cOpaqueStruct']
44-
native_call_prep = ''
45-
for current_argument_details in current_method_details['argument_types']:
46-
argument_name = current_argument_details.var_name
47-
passed_argument_name = argument_name
48-
if argument_name == 'this_ptr':
49-
# we already pass this much more elegantly
50-
continue
51-
52-
if current_argument_details.passed_as_ptr:
53-
passed_argument_name = argument_name+'Pointer'
54-
# let managerPointer = withUnsafePointer(to: self.cChannelManager!) { (pointer: UnsafePointer<LDKChannelManager>) in
55-
# pointer
56-
# }
57-
# the \n\t will add a bunch of extra lines, but this file will be easier to read
58-
current_prep = f'''
59-
\n\t let {passed_argument_name} = withUnsafePointer(to: {argument_name}.cOpaqueStruct!) {{ (pointer: UnsafePointer<{current_argument_details.rust_obj}>) in
60-
\n\t\t pointer
61-
\n\t }}
62-
'''
63-
native_call_prep += current_prep
64-
65-
swift_arguments.append(f'{current_argument_details.java_hu_ty} {argument_name}')
66-
native_arguments.append(f'{passed_argument_name}')
67-
68-
current_replacement = current_replacement.replace('swift_arguments', ', '.join(swift_arguments))
69-
current_replacement = current_replacement.replace('native_arguments', ', '.join(native_arguments))
70-
current_replacement = current_replacement.replace('/* NATIVE_CALL_PREP */', native_call_prep)
71-
72-
struct_methods += '\n' + current_replacement + '\n'
73-
74-
opaque_struct_file = self.template.replace('class OpaqueStructName {', f'class {swift_struct_name} {{')
75-
opaque_struct_file = opaque_struct_file.replace('var cOpaqueStruct: OpaqueStructType?',
76-
f'var cOpaqueStruct: {struct_name}?')
77-
opaque_struct_file = opaque_struct_file.replace('self.cOpaqueStruct = OpaqueStructType()',
78-
f'self.cOpaqueStruct = {struct_name}_new()')
79-
opaque_struct_file = method_template_regex.sub(f'\g<1>{struct_methods}\g<3>', opaque_struct_file)
39+
for current_lambda in struct_details.lambdas:
40+
current_lambda_name = current_lambda['name']
41+
42+
current_native_callback_replacement = native_callback_template
43+
current_native_callback_replacement = current_native_callback_replacement.replace('func methodNameCallback(', f'func {current_lambda_name}Callback(')
44+
current_native_callback_replacement = current_native_callback_replacement.replace('instance: TraitName', f'instance: {swift_struct_name}')
45+
current_native_callback_replacement = current_native_callback_replacement.replace('instance.callbackName(', f'instance.{current_lambda_name}(')
46+
47+
current_swift_callback_replacement = swift_callback_template
48+
current_swift_callback_replacement = current_swift_callback_replacement.replace('func methodName(', f'func {current_lambda_name}(')
49+
50+
instantiation_arguments.append(f'{current_lambda_name}: {current_lambda_name}Callback')
51+
52+
# let's specify the correct return type
53+
swift_raw_return_type = current_lambda['return_type'].swift_raw_type
54+
current_native_callback_replacement = current_native_callback_replacement.replace(') -> Void {', f') -> {swift_raw_return_type} {{')
55+
56+
# let's get the current native arguments, i. e. the arguments we get from C into the native callback
57+
native_arguments = []
58+
swift_callback_arguments = []
59+
swift_argument_string = ''
60+
for current_argument in current_lambda['argument_types']:
61+
native_arguments.append(f'{current_argument.var_name}: {current_argument.swift_raw_type}?')
62+
swift_callback_arguments.append(f'{current_argument.var_name}: {current_argument.var_name}')
63+
if len(native_arguments) > 0:
64+
# add leading comma
65+
swift_argument_string = ', '.join(native_arguments)
66+
native_arguments.insert(0, '')
67+
68+
native_argument_string = ', '.join(native_arguments)
69+
current_native_callback_replacement = current_native_callback_replacement.replace(', native_arguments', native_argument_string)
70+
current_native_callback_replacement = current_native_callback_replacement.replace('swift_callback_arguments', ', '.join(swift_callback_arguments))
71+
current_swift_callback_replacement = current_swift_callback_replacement.replace('swift_arguments', swift_argument_string)
72+
73+
if not current_lambda['is_constant']:
74+
current_native_callback_replacement = current_native_callback_replacement.replace('(pointer: UnsafeRawPointer?', '(pointer: UnsafeMutableRawPointer?')
75+
76+
native_callbacks += '\n' + current_native_callback_replacement + '\n'
77+
swift_callbacks += '\n' + current_swift_callback_replacement + '\n'
78+
79+
trait_file = self.template.replace('class TraitName {', f'class {swift_struct_name} {{')
80+
trait_file = trait_file.replace('var cTrait: TraitType?',
81+
f'var cTrait: {struct_name}?')
82+
trait_file = trait_file.replace('self.cTrait = TraitType(',
83+
f'self.cTrait = {struct_name}(')
84+
trait_file = trait_file.replace('native_callback_instantiation_arguments', ', '.join(instantiation_arguments))
85+
trait_file = native_callback_template_regex.sub(f'\g<1>{native_callbacks}\g<3>', trait_file)
86+
trait_file = swift_callback_template_regex.sub(f'\g<1>{swift_callbacks}\g<3>', trait_file)
8087

8188
# store the output
8289
output_path = f'{Config.OUTPUT_DIRECTORY_PATH}/traits/{swift_struct_name}.swift'
8390
output_directory = os.path.dirname(output_path)
8491
if not os.path.exists(output_directory):
8592
os.makedirs(output_directory)
8693
with open(output_path, "w") as f:
87-
f.write(opaque_struct_file)
94+
f.write(trait_file)
8895
pass

src/generators/tuple_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self) -> None:
1515
def generate_tuple(self, tuple_name, tuple_details):
1616
# method_names = ['openChannel', 'closeChannel']
1717
# native_method_names = ['ChannelHandler_openChannel', 'ChannelHandler_closeChannel']
18-
18+
print(tuple_name)
1919
swift_tuple_name = tuple_name[3:]
2020

2121
mutating_output_file_contents = self.template

src/generators/util_generators/vector_generator.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,41 @@ def __init__(self) -> None:
1212
self.loadTemplate()
1313

1414
def generate_vector(self, vector_name, vector_type_details):
15-
if not vector_type_details.is_primitive:
16-
# TODO: add non-primitive tuple support
17-
return
15+
dimensions = 1
16+
conversion_call = None
17+
is_primitive = vector_type_details.is_primitive
18+
19+
if is_primitive:
20+
swift_primitive = vector_type_details.primitive_swift_counterpart
21+
else:
22+
deepest_iteratee = vector_type_details
23+
while deepest_iteratee.iteratee is not None:
24+
deepest_iteratee = deepest_iteratee.iteratee
25+
dimensions += 1
26+
shallowmost_iteratee = vector_type_details.iteratee
27+
if deepest_iteratee.is_primitive:
28+
swift_primitive = deepest_iteratee.primitive_swift_counterpart
29+
else:
30+
dimensions -= 1
31+
swift_primitive = deepest_iteratee.name
32+
if dimensions > 1:
33+
conversion_call = f'let convertedEntry = {shallowmost_iteratee.name}_to_array(vector: currentEntry)'
34+
1835
mutating_current_vector_methods = self.template
36+
for dim_delta in range(1, dimensions):
37+
mutating_current_vector_methods = mutating_current_vector_methods.replace('[SwiftPrimitive]', '[[SwiftPrimitive]]')
1938
mutating_current_vector_methods = mutating_current_vector_methods.replace('LDKCVec_rust_primitive', vector_name)
39+
40+
if conversion_call is not None:
41+
mutating_current_vector_methods = mutating_current_vector_methods.replace('/* CONVERSION_PREP */', conversion_call)
42+
else:
43+
mutating_current_vector_methods = mutating_current_vector_methods.replace('(convertedEntry)', '(currentEntry)')
44+
45+
if not is_primitive:
46+
mutating_current_vector_methods = mutating_current_vector_methods.replace('/* SWIFT_TO_RUST_START */', '/* SWIFT_TO_RUST_START ')
47+
mutating_current_vector_methods = mutating_current_vector_methods.replace('/* SWIFT_TO_RUST_END */', 'SWIFT_TO_RUST_END */')
48+
2049
mutating_current_vector_methods = mutating_current_vector_methods.replace('SwiftPrimitive',
21-
vector_type_details.primitive_swift_counterpart)
50+
swift_primitive)
51+
2252
self.filled_template += "\n"+mutating_current_vector_methods+"\n"

src/lightning_header_parser.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ def __init__(self) -> None:
2424
self.type = CTypes.OPAQUE_STRUCT
2525
self.fields = []
2626
self.methods = []
27+
self.lambdas = []
2728
self.constructor_method = None
2829
self.free_method = None
2930
self.is_primitive = False
3031
self.primitive_swift_counterpart = None
32+
self.iteratee = None
3133

3234

3335
class LightningHeaderParser():
@@ -256,18 +258,22 @@ def populate_type_details(self):
256258
elif vec_ty is not None:
257259
# TODO: vector type (each one needs to be mapped)
258260
self.vec_types.add(struct_name)
259-
vector_type_details = None
261+
# vector_type_details = None
262+
vector_type_details = TypeDetails() # iterator type
263+
vector_type_details.type = CTypes.VECTOR
264+
vector_type_details.name = struct_name
265+
260266
if vec_ty in self.type_details:
261-
vector_type_details = self.type_details[vec_ty]
267+
vectored_type_details = self.type_details[vec_ty]
268+
# vector_type_details.name = struct_name
269+
vector_type_details.is_primitive = False
270+
vector_type_details.iteratee = vectored_type_details
262271
else:
263272
# it's a primitive
264-
vector_type_details = TypeDetails()
265-
vector_type_details.name = vec_ty
266-
vector_type_details.type = CTypes.VECTOR
267273
vector_type_details.is_primitive = True
268274
vector_type_details.primitive_swift_counterpart = self.language_constants.c_type_map[vec_ty]
269275
self.type_details[struct_name] = vector_type_details
270-
pass
276+
# pass
271277
elif is_union_enum:
272278
assert (struct_name.endswith("_Tag"))
273279
struct_name = struct_name[:-4]
@@ -280,11 +286,13 @@ def populate_type_details(self):
280286
pass
281287
elif is_unitary_enum:
282288
self.type_details[struct_name].type = CTypes.UNITARY_ENUM
283-
# TODO: unitary enum
289+
self.unitary_enums.add(struct_name)
290+
# todo: unitary enums are to be used as is
284291
pass
285292
elif len(trait_fn_lines) > 0:
286293
self.trait_structs.add(struct_name)
287-
# TODO: trait
294+
lambdas = self.parse_lambda_details(trait_fn_lines)
295+
current_type_detail.lambdas = lambdas
288296
elif struct_name == "LDKTxOut":
289297
# TODO: why is this even a special case? It's Swift, we dgaf
290298
pass
@@ -296,7 +304,7 @@ def populate_type_details(self):
296304
print('irregular byte array struct type:', struct_name)
297305
continue
298306

299-
print('byte array struct:', struct_name)
307+
# print('byte array struct:', struct_name)
300308
self.byte_arrays.add(struct_name)
301309
self.type_details[struct_name].type = CTypes.BYTE_ARRAY
302310

@@ -351,6 +359,30 @@ def populate_type_details(self):
351359
# self.global_methods.add(method_details)
352360
pass
353361

362+
def parse_lambda_details(self, trait_fn_lines):
363+
lambdas = []
364+
for fn_line in trait_fn_lines:
365+
ret_ty_info = swift_type_mapper.map_types_to_swift(fn_line.group(2).strip() + " ret", None, False,
366+
self.tuple_types, self.unitary_enums,
367+
self.language_constants)
368+
is_const = fn_line.group(4) is not None
369+
370+
arg_tys = []
371+
for idx, arg in enumerate(fn_line.group(5).split(',')):
372+
if arg == "":
373+
continue
374+
arg_conv_info = swift_type_mapper.map_types_to_swift(arg, None, False, self.tuple_types,
375+
self.unitary_enums,
376+
self.language_constants)
377+
arg_tys.append(arg_conv_info)
378+
lambdas.append({
379+
'name': fn_line.group(3),
380+
'is_constant': is_const,
381+
'return_type': ret_ty_info,
382+
'argument_types': arg_tys
383+
})
384+
return lambdas
385+
354386
def parse_function_details(self, line, re_match, ret_arr_len, c_call_string):
355387
method_return_type = re_match.group(1)
356388
method_name = re_match.group(2)

0 commit comments

Comments
 (0)