-
Notifications
You must be signed in to change notification settings - Fork 63
[BLOG]: Supporting dask arrays in scipy via the Python Array API standard #904
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
500994e
bb35e68
5e5d493
f4d1dbe
f0dd049
035831f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,225 @@ | ||||||||||||||
| --- | ||||||||||||||
| title: 'Supporting dask arrays in scipy via the Python Array API standard' | ||||||||||||||
| authors: [thomas-li] | ||||||||||||||
| published: May 26, 2025 | ||||||||||||||
| description: 'Extending Array API Standard support in scipy to distributed arrays via dask.array' | ||||||||||||||
| category: [Array API, PyData ecosystem] | ||||||||||||||
| featuredImage: | ||||||||||||||
| src: /posts/dask-array-api-scipy/scipy_logo_img_featured.png | ||||||||||||||
| alt: 'Scipy logo' | ||||||||||||||
| hero: | ||||||||||||||
| imageSrc: /posts/dask-array-api-scipy/scipy_logo_img_hero.png | ||||||||||||||
| imageAlt: 'Scipy logo' | ||||||||||||||
| --- | ||||||||||||||
|
|
||||||||||||||
| In this post, I describe my journey getting SciPy to work with Dask arrays natively via the array API and the current | ||||||||||||||
| limitations and future outlook. | ||||||||||||||
|
Comment on lines
+15
to
+16
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| ## Introduction: A quick refresher of the Python Array API standard | ||||||||||||||
|
|
||||||||||||||
| For those unfamiliar, the [Python Array API standard](https://data-apis.org/array-api/latest/API_specification/), | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think it'll be nicer to start with a simpler statement |
||||||||||||||
| is a specification aimed at unifying the various APIs of different array libraries (e.g. Numpy, PyTorch, JAX, Dask, etc.). | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I see there are several different capitalizations for various libraries through the blog. Could you please do a find and replace to have one style for all? I think these are the capitalizations: NumPy, SciPy, Dask, PyTorch, pandas, JAX, and CuPy, unless you're referring to the API, in which case it's all lowercase and presented as inline code like |
||||||||||||||
| While many libraries (e.g. JAX, cupy), implement a numpy-like interface, it is challenging (and sometimes undesirable) | ||||||||||||||
| for some libaries to support all the features/quirks of numpy, especially for those that run on hardware such as GPUs/TPUs | ||||||||||||||
| (e.g. MLX, cupy). Because of this, many array libraries subtly differ from numpy in various way, making it impossible for | ||||||||||||||
| users to treat arbritrary array objects as numpy arrays via duck typing. | ||||||||||||||
|
|
||||||||||||||
| Today, [array api support](https://scipy.github.io/devdocs/dev/api-dev/array_api.html) in scipy has progressed a long | ||||||||||||||
| way since mid 2023 when array API support was first experimentally introduced within the libary. While the array API | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| still remains experimental (one has to set `SCIPY_ARRAY_API=1` to use it), several modules (e.g. `scipy.special`, | ||||||||||||||
| `scipy.stats`), have either been fully or mostly ported to support the array API standard and successfully work with | ||||||||||||||
| input arrays from `PyTorch`, `cupy` in addition to `numpy`, allowing users to speed up their scipy code by passing in | ||||||||||||||
| e.g. GPU arrays to scipy functions that support the array API (see this [blog post](https://labs.quansight.org/blog/scipy-array-api) | ||||||||||||||
| for more details). | ||||||||||||||
|
Comment on lines
+31
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Having clear descriptions for where the links lead to is more accessible for readers. |
||||||||||||||
|
|
||||||||||||||
| Now that a good chunk of scipy supports Array API inputs, one interesting area to study is whether array API compatible | ||||||||||||||
| functions in scipy can support lazy input such as dask or jax arrays. This is one area that has been purposefully left out | ||||||||||||||
| of the current array API spec as an open question, and investigating the feasibility of using lazy arrays via array API | ||||||||||||||
| will help us better evolve the spec to help users get the most performance benefits out of these libraries. In addition, | ||||||||||||||
| supporting dask (the focus of this blog post) in scipy, will allow users to take advantage of dask's out-of-core | ||||||||||||||
| capabilities for handling larger than RAM datasets, and distributed computing capabilities for scaling code to multiple | ||||||||||||||
| compute nodes. | ||||||||||||||
|
|
||||||||||||||
| ## Supporting `dask.array` within scipy | ||||||||||||||
|
|
||||||||||||||
| Like many other array libraries (e.g. PyTorch), `dask.array` is not fully array API compatible out of the box. | ||||||||||||||
| To work around these differences, scipy uses the [array-api-compat](https://github.com/data-apis/array-api-compat) | ||||||||||||||
| libary, which acts as a portability layer that wraps functions in various array namespaces to make them array API | ||||||||||||||
| compatible. | ||||||||||||||
|
|
||||||||||||||
| The wrapping done for dask can be found [here](https://github.com/data-apis/array-api-compat/pull/76). | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Per my previous comment, adding more informative link description |
||||||||||||||
| While this resolved many incompabilities of existing `dask.array` functions with the array API standard, we were still | ||||||||||||||
| unable to implement some array API functionality since some functions (mostly in the linear algebra extension) | ||||||||||||||
| were unimplemented by dask. | ||||||||||||||
| (a list of failures can be found [here](https://github.com/data-apis/array-api-compat/blob/main/dask-xfails.txt)) | ||||||||||||||
|
|
||||||||||||||
| With this, we were now ready to begin testing of dask arrays in scipy, which turned out to be a mixed bag initially. | ||||||||||||||
| While some functions worked perfectly with zero code change, others crashed completely due to incompabilities with | ||||||||||||||
| dask's lazy design or resulted in wrong answers. | ||||||||||||||
|
|
||||||||||||||
| After going through the failures, we were able to categorize them into 3 main areas. | ||||||||||||||
|
|
||||||||||||||
| ### Dask failure modes | ||||||||||||||
|
|
||||||||||||||
| 1. Data-dependent output shapes | ||||||||||||||
|
|
||||||||||||||
| For the uninitiated [data-dependent output shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html) | ||||||||||||||
| occur when the shape of the resulting array from an array API operation is unknown, because it depends on the | ||||||||||||||
| values of an input array. | ||||||||||||||
|
|
||||||||||||||
| For example, when calling the [`unique_values`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_values.html) | ||||||||||||||
| function on an input array, the length of the output array will be dependent on the number of unique values appearing | ||||||||||||||
| in the input array. Another common case of data-dependent output shapes, and the one we see the most in scipy, comes | ||||||||||||||
| from the indexing with a boolean mask in numpy. When writing idiomatic numpy code, it is common to use a boolean mask | ||||||||||||||
| for selecting a subset of items via `__getitem__` (i.e. `x[mask]`), or to set items according to the mask (i.e. `x[mask] = y`). | ||||||||||||||
| Because the number of items to get/sets depends on the number of True values in the mask, this is a data-dependent operation. | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| While some libraries like numpy or PyTorch can handle data-dependent output shapes, jitted/lazy libraries such as dask, | ||||||||||||||
| jax, or MLX do not support these operations, and it is necessary to rewrite functions to avoid these operations if | ||||||||||||||
| possible. | ||||||||||||||
|
|
||||||||||||||
| In scipy, the lack of data-dependent output shapes makes up the majority of failures for dask, so fixing this issue | ||||||||||||||
| is essential to have good support for dask arrays in scipy. | ||||||||||||||
|
|
||||||||||||||
| ** Avoiding data-dependent output shapes ** | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Currently, the |
||||||||||||||
|
|
||||||||||||||
| In order to avoid using data-dependent output shapes when setting array elements, | ||||||||||||||
| one can use the [where](https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html) | ||||||||||||||
| function from the array API spec. | ||||||||||||||
|
|
||||||||||||||
| e.g. | ||||||||||||||
|
|
||||||||||||||
| ```diff | ||||||||||||||
| - x[mask] = y | ||||||||||||||
| + xp = array_namespace(x) | ||||||||||||||
| + x = xp.where(mask, y, x) | ||||||||||||||
| ``` | ||||||||||||||
|
|
||||||||||||||
| Avoiding data-dependent output shapes in other cases (e.g. `__getitem__` with a boolean mask) is more non-trivial, | ||||||||||||||
| and remains an open question to be resolved. | ||||||||||||||
|
|
||||||||||||||
| 2. Use of non array API features | ||||||||||||||
|
|
||||||||||||||
| A good chunk of tests also failed due to the use of operations that were not part | ||||||||||||||
| of the Array API. | ||||||||||||||
|
|
||||||||||||||
| One example of this was in modules using C routines, that would rely | ||||||||||||||
| on implicit conversions to numpy via the `__array__` dunder when calling | ||||||||||||||
| a numpy function on an input array. This didn't always work for dask, as sometimes dask would | ||||||||||||||
| call its own implementation of a numpy function via [NEP18 dispatching](https://numpy.org/neps/nep-0018-array-function-protocol.html), | ||||||||||||||
| leading to a crash later on when a dask array was passed to a C extension module that expected numpy arrays. | ||||||||||||||
| Fortunately, the fix to this issue was simple enough - by manually casting to numpy via `np.asarray`, we were able to | ||||||||||||||
| prevent NEP18 dispatching, and make the previously failing tests pass. | ||||||||||||||
|
|
||||||||||||||
| Another failure case occured due to lack of support for numpy specific keywords in functions. | ||||||||||||||
| Some scipy modules (e.g. `scipy.ndimage`) allow a user to | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems incomplete? |
||||||||||||||
|
|
||||||||||||||
| 3. Miscellaneous failures (e.g. spurious warnings, incompatible/wrong tests) | ||||||||||||||
|
|
||||||||||||||
| Finally, a small portion of tests failed due to miscellaneous issues such as `RuntimeWarning`s emitted by | ||||||||||||||
| dask for NaNs/invalid values in data, and bad interactions with other libraries tested in the test suite | ||||||||||||||
| (e.g. matplotlib). | ||||||||||||||
|
|
||||||||||||||
| ### Current support | ||||||||||||||
|
|
||||||||||||||
| | Module | Status | Notes | | ||||||||||||||
| | ------------------- | ------ | -------------------------------------------------------------------------------------------- | | ||||||||||||||
| | scipy.cluster | 🚧 | | | ||||||||||||||
| | scipy.constants | ✅ | | | ||||||||||||||
| | scipy.datasets | ✅ | | | ||||||||||||||
| | scipy.differentiate | ❌ | | | ||||||||||||||
| | scipy.fft | ✅ | | | ||||||||||||||
| | scipy.io | ✅ | | ||||||||||||||
| | scipy.integrate | ❌ | | ||||||||||||||
| | scipy.linalg | ✅ | All current Array API compatible functions pass. Not all functions are dispatched right now. | | ||||||||||||||
| | scipy.optimize | ❌ | | ||||||||||||||
| | scipy.ndimage | 🚧 | | ||||||||||||||
| | scipy.signal\* | 🚧 | | ||||||||||||||
| | scipy.special | ✅ | | | ||||||||||||||
| | scipy.stats\* | 🚧 | | | ||||||||||||||
|
|
||||||||||||||
| `*` - Some public API functions/methods in this module have not yet been ported to the Array API standard. | ||||||||||||||
| (Status refers to the status of dask.array with ) | ||||||||||||||
| See [here](https://scipy.github.io/devdocs/dev/api-dev/array_api.html#currently-supported-functionality) | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| for a list of supported functions/methods. | ||||||||||||||
|
|
||||||||||||||
| As of today, the `scipy.fft/special/stats` modules have the best support for dask arrays today, and are able to | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| output lazy dask arrays without forcing computation, as they can be fully expressed using array API operations. | ||||||||||||||
|
|
||||||||||||||
| While dask arrays are mostly supported in modules such as `scipy.ndimage/signal`, the bulk of the computation in those | ||||||||||||||
| modules is done by C routines (which requires a conversion to numpy that forces computation of the lazy dask array early). | ||||||||||||||
|
|
||||||||||||||
| In the next section, we will take a look more closely at how array API compatibility enables better performance with | ||||||||||||||
| dask arrays within the `scipy.stats` module. | ||||||||||||||
|
|
||||||||||||||
| ## Example | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Really like this! |
||||||||||||||
|
|
||||||||||||||
| We'll now explore the lazy capabilities of dask in `scipy.stats` by doing some statistical tests | ||||||||||||||
| on the NYC Taxi Dataset. In this basic analysis, let's do a t-test to | ||||||||||||||
| check if fares for trips with multiple passengers differes from fares for trips with just 1 passenger. | ||||||||||||||
|
|
||||||||||||||
| Our null and alternate hypotheses are thus: | ||||||||||||||
|
|
||||||||||||||
| $H_0$: The average fare for trips with multiple passengers is the same as the average fare for trips with a single passenger. | ||||||||||||||
|
|
||||||||||||||
| $H_a$: The average fare for trips with multiple passengers different from the average fare for trips with a single passenger. | ||||||||||||||
|
|
||||||||||||||
| We'll perform this test at a significance level of $\alpha = 0.05$ | ||||||||||||||
|
|
||||||||||||||
| First, let's load in our data into a dask dataframe. We also set the | ||||||||||||||
| environment variable `SCIPY_ARRAY_API=1` to opt in to scipy's array API capabilities. | ||||||||||||||
|
|
||||||||||||||
| ```python | ||||||||||||||
| %env SCIPY_ARRAY_API=1 | ||||||||||||||
|
|
||||||||||||||
| import dask.dataframe as dd | ||||||||||||||
|
|
||||||||||||||
| ddf = dd.read_parquet( | ||||||||||||||
| # Original data found here | ||||||||||||||
| #"https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-01.parquet", | ||||||||||||||
| "yellow_tripdata_2025-01.parquet", | ||||||||||||||
| ) | ||||||||||||||
| ``` | ||||||||||||||
|
|
||||||||||||||
| ```python | ||||||||||||||
| import scipy.stats as stats | ||||||||||||||
|
|
||||||||||||||
| onepass_fares = ddf[ddf["passenger_count"] == 1.0]["fare_amount"].to_dask_array().compute_chunk_sizes() | ||||||||||||||
| multpass_fares = ddf[ddf["passenger_count"] > 1.0]["fare_amount"].to_dask_array().compute_chunk_sizes() | ||||||||||||||
| res = stats.ttest_ind(a=onepass_fares, b=multpass_fares, equal_var=False) | ||||||||||||||
| print(f"T-statistic: {res.statistic.compute()}") | ||||||||||||||
| print(f"P-value: {res.pvalue.compute()}") | ||||||||||||||
|
|
||||||||||||||
| T-statistic: -5.5699390653688985 | ||||||||||||||
| P-value: 2.5485619336211492e-08 | ||||||||||||||
| ``` | ||||||||||||||
|
|
||||||||||||||
| From this p-value, we can reject our null hypothesis that the average fare for trips with one passenger is the same as the average fare for trips with multiple passengers. | ||||||||||||||
|
|
||||||||||||||
| While we weren't entirely able to avoid computation in the middle (dask still struggles with unknown shapes which we get through our boolean masking on the dataframe), we were able to entirely keep the computation in dask. This is a big improvement over the pre-Array API behavior where the input dask arrays would be cast to numpy arrays (forcing computation and storage of intermediate results in one worker which can lead to performance degredation and out-of-memory errors) | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| ## Future Work | ||||||||||||||
|
|
||||||||||||||
| While we've made great progress in enabling support for `dask.array` within scipy, a lot of work remains | ||||||||||||||
| to be done to fully enable scipy support for dask arrays. In particular, we'd like to circle back and fix modules, | ||||||||||||||
| such as `scipy.integrate` and `scipy.differentiate` that were skipped in the initial port of array API compatible | ||||||||||||||
| modules to support dask.array. | ||||||||||||||
|
|
||||||||||||||
| Looking forward, we'd also like to enable `dask.array` support via the Array API in other Array API | ||||||||||||||
| compatible libraries, most notably scikit-learn. A previous | ||||||||||||||
| [attempt](https://github.com/scikit-learn/scikit-learn/pull/28588) to add array API support within scikit-learn stalled | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| despite initial successes e.g. with getting functions like `sklearn.model_selection.train_test_split` working with dask | ||||||||||||||
| due to poor/missing support for features that scikit-learn used heavily such as data-dependent output shapes and sorting. | ||||||||||||||
| Given the lessons learned from getting dask to work with scipy, it should be possible to revisit that PR and support using | ||||||||||||||
| dask with more of the scikit-learn API via the Array API specification. | ||||||||||||||
|
|
||||||||||||||
| ## Acknowledgments | ||||||||||||||
|
|
||||||||||||||
| I'd like to thank [Ralf Gommers](https://github.com/rgommers) for introducing me to the Array API | ||||||||||||||
| standard and guiding me on contributing to the adoption of the Array API standard. I'd also like to thank | ||||||||||||||
| [Lucas Colley](https://github.com/lucascolley), [Guido Imperiale](https://github.com/crusaderky), and | ||||||||||||||
| [Aaron Meurer](https://github.com/asmeurer) for providing feedback and reviewing my PRs to | ||||||||||||||
| scipy and array-api-compat. | ||||||||||||||
|
|
||||||||||||||
| This work was supported by a grant from NASA to Pandas, scikit-learn, SciPy and | ||||||||||||||
| NumPy under the NASA ROSES 2020 program. | ||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just noting that we'll need to update this before merging. :)