Skip to content

Commit d091f79

Browse files
committed
temp array-api-compat changes to be upstreamed
1 parent cdd1c8d commit d091f79

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

array_api_compat/common/_helpers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,11 +521,15 @@ def your_function(x, y):
521521
import torch
522522
namespaces.add(torch)
523523
elif is_dask_array(x):
524-
if _use_compat:
524+
# dask main namespace is not array APi compatible
525+
# so return namespace from array-api-compat unless
526+
# explicitly requested otherwise
527+
if _use_compat or _use_compat is None:
525528
_check_api_version(api_version)
526529
from ..dask import array as dask_namespace
527530
namespaces.add(dask_namespace)
528-
else:
531+
elif _use_compat is False:
532+
print("why am i false")
529533
import dask.array as da
530534
namespaces.add(da)
531535
elif is_jax_array(x):

0 commit comments

Comments
 (0)