Skip to content

Commit f17d55b

Browse files
authored
Fix boolean cast in Fortran code (pyccel#1789)
Add missing cast when creating an array from an object with a different datatype. Fixes pyccel#1785 In order to call the cast on the argument passed to `np.array` (which may be a `InhomogeneousTupleType`) fix the type of the cast functions. Modify `InhomogeneousTupleType.datatype` to return a `FixedSizeType` if the datatypes of all elements are equivalent. Also fix some minor bugs after pyccel#1756
1 parent d44bf3a commit f17d55b

File tree

11 files changed

+64
-193
lines changed

11 files changed

+64
-193
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ All notable changes to this project will be documented in this file.
2424
- #1792 : Fix array unpacking.
2525
- #1795 : Fix bug when returning slices in C.
2626
- #1732 : Fix multidimensional list indexing in Python.
27+
- #1785 : Add missing cast when creating an array of booleans from non-boolean values.
2728

2829
### Changed
2930
- #1720 : functions with the `@inline` decorator are no longer exposed to Python in the shared library.

pyccel/ast/builtins.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -513,11 +513,6 @@ def __init__(self, *args):
513513

514514
# Get possible datatypes
515515
dtypes = [a.class_type.datatype for a in args]
516-
# Extract all dtypes inside any inhomogeneous tuples
517-
while any(isinstance(d, InhomogeneousTupleType) for d in dtypes):
518-
dtypes = [di for d in dtypes for di in ((d_elem.datatype for d_elem in d)
519-
if isinstance(d, InhomogeneousTupleType)
520-
else [d])]
521516
# Create a set of dtypes using the same key for compatible types
522517
dtypes = set((d.primitive_type, d.precision) if isinstance(d, FixedSizeNumericType) else d for d in dtypes)
523518

pyccel/ast/datatypes.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -714,10 +714,14 @@ class InhomogeneousTupleType(ContainerType, TupleType, metaclass = ArgumentSingl
714714
*args : tuple of DataTypes
715715
The datatypes stored in the inhomogeneous tuple.
716716
"""
717-
__slots__ = ('_element_types',)
717+
__slots__ = ('_element_types', '_datatype')
718718

719719
def __init__(self, *args):
720720
self._element_types = args
721+
722+
possible_types = set(t.datatype for t in self._element_types)
723+
dtype = possible_types.pop()
724+
self._datatype = dtype if all(d == dtype for d in possible_types) else self
721725
super().__init__()
722726

723727
def __str__(self):
@@ -754,13 +758,11 @@ def datatype(self):
754758
"""
755759
The datatype of the object.
756760
757-
The datatype of the object.
761+
The datatype of the object. For an inhomogeneous tuple the datatype is the type
762+
of the tuple unless the tuple is comprised of containers which are all based on
763+
compatible data types. In this case one of the underlying types is returned.
758764
"""
759-
possible_types = set(t.datatype for t in self._element_types)
760-
if len(possible_types) == 1:
761-
return possible_types.pop()
762-
else:
763-
return self
765+
return self._datatype
764766

765767
class DictType(ContainerType, metaclass = ArgumentSingleton):
766768
"""

pyccel/ast/numpyext.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,14 @@ class NumpyFloat(PythonFloat):
184184
The argument passed to the function.
185185
"""
186186
__slots__ = ('_rank','_shape','_order','_class_type')
187+
_static_type = NumpyFloat64Type()
187188
name = 'float'
189+
188190
def __init__(self, arg):
189191
self._shape = arg.shape
190192
self._rank = arg.rank
191193
self._order = arg.order
192-
self._class_type = arg.class_type.switch_basic_type(self.static_type())
194+
self._class_type = NumpyNDArrayType(self.static_type()) if self._rank else self.static_type()
193195
super().__init__(arg)
194196

195197
@property
@@ -250,7 +252,7 @@ def __init__(self, arg):
250252
self._shape = arg.shape
251253
self._rank = arg.rank
252254
self._order = arg.order
253-
self._class_type = arg.class_type.switch_basic_type(self.static_type())
255+
self._class_type = NumpyNDArrayType(self.static_type()) if self._rank else self.static_type()
254256
super().__init__(arg)
255257

256258
@property
@@ -277,12 +279,14 @@ class NumpyInt(PythonInt):
277279
The argument passed to the function.
278280
"""
279281
__slots__ = ('_shape','_rank','_order','_class_type')
282+
_static_type = numpy_precision_map[(PrimitiveIntegerType(), PythonInt._static_type.precision)]
280283
name = 'int'
284+
281285
def __init__(self, arg=None, base=10):
282286
self._shape = arg.shape
283287
self._rank = arg.rank
284288
self._order = arg.order
285-
self._class_type = arg.class_type.switch_basic_type(self.static_type())
289+
self._class_type = NumpyNDArrayType(self.static_type()) if self._rank else self.static_type()
286290
super().__init__(arg)
287291

288292
@property
@@ -374,7 +378,10 @@ class NumpyReal(PythonReal):
374378
name = 'real'
375379
def __new__(cls, arg):
376380
if isinstance(arg.dtype, PythonNativeBool):
377-
return NumpyInt(arg)
381+
if arg.rank:
382+
return NumpyInt(arg)
383+
else:
384+
return PythonInt(arg)
378385
else:
379386
return super().__new__(cls, arg)
380387

@@ -452,14 +459,16 @@ class NumpyComplex(PythonComplex):
452459
_real_cast = NumpyReal
453460
_imag_cast = NumpyImag
454461
__slots__ = ('_rank','_shape','_order','_class_type')
462+
_static_type = NumpyComplex128Type()
455463
name = 'complex'
464+
456465
def __init__(self, arg0, arg1 = None):
457466
if arg1 is not None:
458467
raise NotImplementedError("Use builtin complex function not deprecated np.complex")
459468
self._shape = arg0.shape
460469
self._rank = arg0.rank
461470
self._order = arg0.order
462-
self._class_type = arg0.class_type.switch_basic_type(self.static_type())
471+
self._class_type = NumpyNDArrayType(self.static_type()) if self._rank else self.static_type()
463472
super().__init__(arg0)
464473

465474
@property

pyccel/codegen/printing/ccode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,7 @@ def get_declare_type(self, expr):
12421242
else:
12431243
errors.report(PYCCEL_RESTRICTION_TODO+' (rank>0)', symbol=expr, severity='fatal')
12441244
elif not isinstance(class_type, CustomDataType):
1245-
dtype = self.find_in_dtype_registry(class_type)
1245+
dtype = self.find_in_dtype_registry(expr.dtype)
12461246
else:
12471247
dtype = self._print(expr.class_type)
12481248

pyccel/codegen/printing/fcode.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,12 +1190,17 @@ def _print_NumpyWhere(self, expr):
11901190
return stmt
11911191

11921192
def _print_NumpyArray(self, expr):
1193-
expr_args = (expr.arg,) if isinstance(expr.arg, Variable) else expr.arg
11941193
order = expr.order
1194+
1195+
try :
1196+
cast_func = DtypePrecisionToCastFunction[expr.dtype]
1197+
except KeyError:
1198+
errors.report(PYCCEL_RESTRICTION_TODO, severity='fatal')
1199+
arg = expr.arg if expr.arg.dtype == expr.dtype else cast_func(expr.arg)
11951200
# If Numpy array is stored with column-major ordering, transpose values
11961201
# use reshape with order for rank > 2
11971202
if expr.rank <= 2:
1198-
rhs_code = self._print(expr.arg)
1203+
rhs_code = self._print(arg)
11991204
if expr.arg.order and expr.arg.order != expr.order:
12001205
rhs_code = f'transpose({rhs_code})'
12011206
if expr.arg.rank < expr.rank:
@@ -1205,6 +1210,8 @@ def _print_NumpyArray(self, expr):
12051210
shape_code = ', '.join(self._print(i) for i in expr.shape[::-1])
12061211
rhs_code = f"reshape({rhs_code}, [{shape_code}])"
12071212
else:
1213+
expr_args = (expr.arg,) if isinstance(expr.arg, Variable) else expr.arg
1214+
expr_args = tuple(a if a.dtype == expr.dtype else cast_func(a) for a in expr_args)
12081215
new_args = []
12091216
inv_order = 'C' if order == 'F' else 'F'
12101217
for a in expr_args:

tests/epyccel/modules/python_annotations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@ def array_int32_2d_F_add( x:'int32[:,:](order=F)', y:'int32[:,:](order=F)' ):
1212
def array_int_1d_scalar_add( x:'int[:]', a:'int' ):
1313
x[:] += a
1414

15-
def array_real_1d_scalar_add( x:'real[:]', a:'real' ):
15+
def array_float_1d_scalar_add( x:'float[:]', a:'float' ):
1616
x[:] += a
1717

18-
def array_real_2d_F_scalar_add( x:'real[:,:](order=F)', a:'real' ):
18+
def array_float_2d_F_scalar_add( x:'float[:,:](order=F)', a:'float' ):
1919
x[:,:] += a
2020

21-
def array_real_2d_F_add( x:'real[:,:](order=F)', y:'real[:,:](order=F)' ):
21+
def array_float_2d_F_add( x:'float[:,:](order=F)', y:'float[:,:](order=F)' ):
2222
x[:,:] += y
2323

2424
def array_int32_2d_F_complex_3d_expr( x:'int32[:,:](order=F)', y:'int32[:,:](order=F)' ):
2525
from numpy import full, int32
2626
z = full((2,3),5,order='F', dtype=int32)
2727
x[:] = (x // y) * x + z
2828

29-
def array_real_1d_complex_3d_expr( x:'real[:]', y:'real[:]' ):
29+
def array_float_1d_complex_3d_expr( x:'float[:]', y:'float[:]' ):
3030
from numpy import full
3131
z = full(3,5)
3232
x[:] = (x // y) * x + z

tests/epyccel/recognised_functions/test_numpy_funcs.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2160,25 +2160,26 @@ def create_array_tuple_ref(a : 'int[:,:]'):
21602160
tmp_arr = np.ones((3,4), dtype=int)
21612161
assert np.allclose(array_tuple_ref(tmp_arr), create_array_tuple_ref(tmp_arr))
21622162

2163-
@pytest.mark.parametrize( 'language', (
2164-
pytest.param("fortran", marks = pytest.mark.fortran),
2165-
pytest.param("c", marks = [
2166-
pytest.mark.skip(reason="Changing dtype is broken in C. See #1641"),
2167-
pytest.mark.c]
2168-
),
2169-
pytest.param("python", marks = pytest.mark.python)
2170-
)
2171-
)
21722163
def test_array_new_dtype(language):
21732164
def create_float_array_tuple_ref(a : 'int[:,:]'):
21742165
from numpy import array
21752166
b = (a[0,:], a[1,:])
21762167
c = array(b, dtype=float)
21772168
return c
2169+
def create_bool_array_tuple_ref(a : 'int[:,:]'):
2170+
from numpy import array
2171+
b = (a[0,:], a[1,:])
2172+
c = array(b, dtype=bool)
2173+
return c
2174+
21782175
array_float_tuple_ref = epyccel(create_float_array_tuple_ref, language = language)
21792176
tmp_arr = np.ones((3,4), dtype=int)
21802177
assert np.allclose(array_float_tuple_ref(tmp_arr), create_float_array_tuple_ref(tmp_arr))
21812178

2179+
array_bool_tuple_ref = epyccel(create_float_array_tuple_ref, language = language)
2180+
tmp_arr = np.ones((3,4), dtype=int)
2181+
assert np.allclose(array_bool_tuple_ref(tmp_arr), create_bool_array_tuple_ref(tmp_arr))
2182+
21822183
@pytest.mark.parametrize( 'language', (
21832184
pytest.param("fortran", marks = pytest.mark.fortran),
21842185
pytest.param("c", marks = [

tests/epyccel/test_array_as_func_args.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ def array_int_1d_scalar_add(x : 'T[:]', a : 'T', x_len : int):
3131

3232
assert np.array_equal( x1, x2 )
3333

34-
def test_array_real_1d_scalar_add(language):
35-
@template('T', ['float32', 'double'])
36-
def array_real_1d_scalar_add(x : 'T[:]', a : 'T', x_len : int):
34+
def test_array_float_1d_scalar_add(language):
35+
@template('T', ['float32', 'float'])
36+
def array_float_1d_scalar_add(x : 'T[:]', a : 'T', x_len : int):
3737
for i in range(x_len):
3838
x[i] += a
39-
f1 = array_real_1d_scalar_add
39+
f1 = array_float_1d_scalar_add
4040
f2 = epyccel(f1, language=language)
4141

4242
for t in float_types:
@@ -92,13 +92,13 @@ def array_int_2d_scalar_add( x : 'T[:,:]', a : 'T', d1 : int, d2 : int):
9292

9393
assert np.array_equal( x1, x2 )
9494

95-
def test_array_real_2d_scalar_add(language):
96-
@template('T', ['float32', 'double'])
97-
def array_real_2d_scalar_add(x : 'T[:,:]', a : 'T', d1 : int, d2 : int):
95+
def test_array_float_2d_scalar_add(language):
96+
@template('T', ['float32', 'float'])
97+
def array_float_2d_scalar_add(x : 'T[:,:]', a : 'T', d1 : int, d2 : int):
9898
for i in range(d1):
9999
for j in range(d2):
100100
x[i, j] += a
101-
f1 = array_real_2d_scalar_add
101+
f1 = array_float_2d_scalar_add
102102
f2 = epyccel(f1, language=language)
103103

104104
for t in float_types:

0 commit comments

Comments
 (0)