Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions skcuda/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,13 @@ def _fft(x_gpu, y_gpu, plan, direction, scale=None):
raise ValueError('can only compute in-place transform of complex data')

if direction == cufft.CUFFT_FORWARD and \
plan.in_dtype in np.sctypes['complex'] and \
plan.out_dtype in np.sctypes['float']:
plan.in_dtype in [np.dtype(t).type for t in np.typecodes['Complex']] and \
plan.out_dtype in [np.dtype(t).type for t in np.typecodes['Float']]:
raise ValueError('cannot compute forward complex -> real transform')

if direction == cufft.CUFFT_INVERSE and \
plan.in_dtype in np.sctypes['float'] and \
plan.out_dtype in np.sctypes['complex']:
plan.in_dtype in [np.dtype(t).type for t in np.typecodes['Float']] and \
plan.out_dtype in [np.dtype(t).type for t in np.typecodes['Complex']]:
raise ValueError('cannot compute inverse real -> complex transform')

if plan.fft_type in [cufft.CUFFT_C2C, cufft.CUFFT_Z2Z]:
Expand Down
2 changes: 1 addition & 1 deletion skcuda/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def diff(x_gpu):


# List of available numerical types provided by numpy:
num_types = [np.sctypeDict[t] for t in \
num_types = [np.dtype(t).type for t in \
np.typecodes['AllInteger']+np.typecodes['AllFloat']]

# Numbers of bytes occupied by each numerical type:
Expand Down