247
247
'stdio' ,
248
248
"inttypes" ,
249
249
'stdbool' ,
250
- 'assert' ]}
250
+ 'assert' ,
251
+ 'stc/cstr' ]}
251
252
252
253
import_header_guard_prefix = {'stc/hset' : '_TOOLS_SET' ,
253
254
'stc/vec' : '_TOOLS_LIST' ,
@@ -309,7 +310,6 @@ class CCodePrinter(CodePrinter):
309
310
(PrimitiveIntegerType (),8 ) : LiteralString ("%" ) + CMacro ('PRId64' ),
310
311
(PrimitiveIntegerType (),2 ) : LiteralString ("%" ) + CMacro ('PRId16' ),
311
312
(PrimitiveIntegerType (),1 ) : LiteralString ("%" ) + CMacro ('PRId8' ),
312
- StringType () : '%s' ,
313
313
}
314
314
315
315
def __init__ (self , filename , prefix_module = None ):
@@ -1039,7 +1039,7 @@ def _print_Import(self, expr):
1039
1039
f'#include <{ stc_extension_mapping [source ]} .h>\n ' ,
1040
1040
f'#endif // { header_guard } \n \n ' ))
1041
1041
return code
1042
- elif source .startswith ('stc/' ):
1042
+ elif source != 'stc/cstr' and ( source .startswith ('stc/' ) or source in import_header_guard_prefix ):
1043
1043
code = ''
1044
1044
for t in expr .target :
1045
1045
class_type = t .object .class_type
@@ -1130,6 +1130,13 @@ def get_print_format_and_arg(self, var):
1130
1130
except KeyError :
1131
1131
errors .report (f"Printing { var .dtype } type is not supported currently" , severity = 'fatal' )
1132
1132
arg = self ._print (var )
1133
+ elif isinstance (var .dtype , StringType ):
1134
+ if isinstance (var , Variable ):
1135
+ var_obj = self ._print (ObjectAddress (var ))
1136
+ arg = f'cstr_str({ var_obj } )'
1137
+ else :
1138
+ arg = self ._print (var )
1139
+ arg_format = '%s'
1133
1140
else :
1134
1141
try :
1135
1142
arg_format = self .type_to_format [var .dtype ]
@@ -1302,6 +1309,9 @@ def get_c_type(self, dtype):
1302
1309
i_type = f'{ container_type } _{ key_type } _{ val_type } '
1303
1310
self .add_import (Import (f'stc/{ container_type } ' , AsName (VariableTypeAnnotation (dtype ), i_type )))
1304
1311
return i_type
1312
+ elif isinstance (dtype , StringType ):
1313
+ self .add_import (c_imports ['stc/cstr' ])
1314
+ return 'cstr'
1305
1315
else :
1306
1316
key = dtype
1307
1317
@@ -1377,11 +1387,11 @@ def get_declare_type(self, expr):
1377
1387
rank = expr .rank
1378
1388
1379
1389
if rank > 0 :
1380
- if isinstance (expr . class_type , (HomogeneousSetType , HomogeneousListType , DictType )):
1381
- dtype = self .get_c_type (expr . class_type )
1382
- elif isinstance (expr . class_type , CStackArray ):
1383
- return self .get_c_type (expr . class_type .element_type )
1384
- elif isinstance (expr . class_type , (HomogeneousTupleType , NumpyNDArrayType )):
1390
+ if isinstance (class_type , (HomogeneousSetType , HomogeneousListType , DictType , StringType )):
1391
+ dtype = self .get_c_type (class_type )
1392
+ elif isinstance (class_type , CStackArray ):
1393
+ return self .get_c_type (class_type .element_type )
1394
+ elif isinstance (class_type , (HomogeneousTupleType , NumpyNDArrayType )):
1385
1395
if expr .rank > 15 :
1386
1396
errors .report (UNSUPPORTED_ARRAY_RANK , symbol = expr , severity = 'fatal' )
1387
1397
self .add_import (c_imports ['ndarrays' ])
@@ -1699,15 +1709,18 @@ def _print_PyccelArrayShapeElement(self, expr):
1699
1709
c_type = self .get_c_type (arg .class_type )
1700
1710
arg_code = self ._print (ObjectAddress (arg ))
1701
1711
return f'{ c_type } _size({ arg_code } )'
1712
+ elif isinstance (arg .class_type , StringType ):
1713
+ arg_code = self ._print (ObjectAddress (arg ))
1714
+ return f'cstr_size({ arg_code } )'
1702
1715
else :
1703
1716
raise NotImplementedError (f"Don't know how to represent shape of object of type { arg .class_type } " )
1704
1717
1705
1718
def _print_Allocate (self , expr ):
1706
1719
free_code = ''
1707
1720
variable = expr .variable
1708
- if isinstance (variable .class_type , (HomogeneousListType , HomogeneousSetType , DictType )):
1721
+ if isinstance (variable .class_type , (HomogeneousListType , HomogeneousSetType , DictType , StringType )):
1709
1722
if expr .status in ('allocated' , 'unknown' ):
1710
- free_code = f'{ self ._print (Deallocate (variable ))} \n '
1723
+ free_code = f'{ self ._print (Deallocate (variable ))} '
1711
1724
if expr .shape [0 ] is None :
1712
1725
return free_code
1713
1726
size = self ._print (expr .shape [0 ])
@@ -1718,7 +1731,7 @@ def _print_Allocate(self, expr):
1718
1731
elif expr .alloc_type == 'resize' :
1719
1732
return f'{ container_type } _resize({ variable_address } , { size } , { 0 } );\n '
1720
1733
return free_code
1721
- if isinstance (variable .class_type , (NumpyNDArrayType , HomogeneousTupleType )):
1734
+ elif isinstance (variable .class_type , (NumpyNDArrayType , HomogeneousTupleType )):
1722
1735
#free the array if its already allocated and checking if its not null if the status is unknown
1723
1736
if (expr .status == 'unknown' ):
1724
1737
shape_var = DottedVariable (VoidType (), 'shape' , lhs = variable )
@@ -1755,7 +1768,7 @@ def _print_Allocate(self, expr):
1755
1768
raise NotImplementedError (f"Allocate not implemented for { variable .class_type } " )
1756
1769
1757
1770
def _print_Deallocate (self , expr ):
1758
- if isinstance (expr .variable .class_type , (HomogeneousListType , HomogeneousSetType , DictType )):
1771
+ if isinstance (expr .variable .class_type , (HomogeneousListType , HomogeneousSetType , DictType , StringType )):
1759
1772
if expr .variable .is_alias :
1760
1773
return ''
1761
1774
variable_address = self ._print (ObjectAddress (expr .variable ))
@@ -2306,11 +2319,13 @@ def _print_Assign(self, expr):
2306
2319
return self .copy_NumpyArray_Data (expr )
2307
2320
if isinstance (rhs , (NumpyFull )):
2308
2321
return self .arrayFill (expr )
2309
- lhs = self ._print (expr . lhs )
2322
+ lhs_code = self ._print (lhs )
2310
2323
if isinstance (rhs , (PythonList , PythonSet , PythonDict )):
2311
2324
return self .init_stc_container (rhs , expr )
2312
- rhs = self ._print (expr .rhs )
2313
- return f'{ lhs } = { rhs } ;\n '
2325
+ rhs_code = self ._print (rhs )
2326
+ if isinstance (rhs , LiteralString ):
2327
+ rhs_code = f'cstr_lit({ rhs_code } )'
2328
+ return f'{ lhs_code } = { rhs_code } ;\n '
2314
2329
2315
2330
def _print_AliasAssign (self , expr ):
2316
2331
lhs_var = expr .lhs
0 commit comments