
Conversation

@askorikov

This PR adds support for the CUDA backend of ASTRA, including support for CuPy inputs (partially addressing #701). As you can see, the code becomes quite a bit messier, but that is because it has to cover several scenarios at the same time:

  1. CPU input + CPU execution
  2. CPU input + GPU execution (by setting either engine or projector_type to cuda)
  3. GPU input + GPU execution (including zero-copy GPU data exchange, which ASTRA currently only supports for 3D data, so we resort to a hack of representing the 2D geometry as 3D with a single slice; see the sketch after this list).
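
For reference, the single-slice trick from scenario 3 looks roughly like this in ASTRA terms (geometry construction only; the zero-copy linking of the CuPy buffers is omitted here):

import numpy as np
import astra

nx, ny, ndet = 128, 128, 192
angles = np.linspace(0.0, np.pi, 180, endpoint=False)

# native 2D parallel-beam setup (fine for the CPU path)
vol_geom_2d = astra.create_vol_geom(ny, nx)
proj_geom_2d = astra.create_proj_geom("parallel", 1.0, ndet, angles)

# the same setup expressed as 3D with a single slice / single detector row,
# so that GPU-linked (zero-copy) data objects can be used
vol_geom_3d = astra.create_vol_geom(ny, nx, 1)
proj_geom_3d = astra.create_proj_geom("parallel3d", 1.0, 1.0, 1, ndet, angles)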

With this, I have the following observations/questions:

  1. The adjoint of the CUDA backend of ASTRA is much less well matched to the forward operator, so the tests produce >10% relative error. We do expect the adjoint to be mismatched, but this error is still way too high; I will look into it on the ASTRA side. In the meantime, the tests with the CUDA backend are skipped.
  2. Do we need to assert that the input is on CPU for cpu engine and on GPU for cuda engine?
  3. In principle, this operator almost works with JAX on both CPU and GPU, except for making sure arrays end up on the appropriate device. What is currently the ambition for adding JAX support to PyLops operators?
  4. The default projector_type now (strip) favors accuracy at the expense of performance. For experimental data the accuracy difference is usually not noticeable, so the faster linear projector is more commonly used; it is also the only projector available in the CUDA backend (see the snippet after this list). Let me know which you think fits the expected audience better.
  5. ASTRA only supports the float32 dtype internally at the moment. How should we handle that?
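
For context on question 4, the choice boils down to which ASTRA projector gets created; a rough sketch with an arbitrary parallel-beam geometry:

import numpy as np
import astra

angles = np.linspace(0.0, np.pi, 180, endpoint=False)
vol_geom = astra.create_vol_geom(128, 128)
proj_geom = astra.create_proj_geom("parallel", 1.0, 192, angles)

proj_strip = astra.create_projector("strip", proj_geom, vol_geom)    # exact area weighting: accurate, slower (CPU)
proj_linear = astra.create_projector("linear", proj_geom, vol_geom)  # interpolating kernel: faster (CPU)
proj_cuda = astra.create_projector("cuda", proj_geom, vol_geom)      # the projector used by the 2D CUDA backend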

@askorikov force-pushed the update-astra-integration branch from 00fe264 to e3996ca on October 16, 2025 12:40
@mrava87 (Collaborator) commented on Oct 18, 2025

Thanks @askorikov!

Let me reply to the questions below and then I will do a more general review of the code.

This PR adds support for the CUDA backend of ASTRA, including support for CuPy inputs (partially addressing #701). As you can see, the code becomes quite a bit messier, but that is because it has to cover several scenarios at the same time:

  1. CPU input + CPU execution
  2. CPU input + GPU execution (by setting either engine or projector_type to cuda)
  3. GPU input + GPU execution (including zero-copy GPU data exchange, which ASTRA currently only supports for 3D data, so we resort to a hack of representing the 2D geometry as 3D with a single slice).

It is kind of expected that the code becomes a bit more complicated for operators that have to handle multiple backends when it is not a pure numpy/cupy switch, so in principle I have no problem with that.

So far PyLops' philosophy has been that we don't want operators to be magic black boxes, so we expect users to always provide NumPy arrays when the operator works on CPU and CuPy (or JAX) arrays when it works on GPU. In special cases where you cannot solve the entire problem on the GPU, we have an auxiliary operator called ToCupy that can be chained like any other PyLops operator and used to lift data in and out of the GPU as needed. An example for your case would be to split the CT operator into a stack of N CT operators, each with a portion of the angles, and put them into a VStack; this way, even if the entire dataset does not fit on the GPU, one can still apply each block on the GPU and move the partial data out before moving on to the next block (a rough sketch of this pattern is below). Now, if your GPU operator does fancier things (like streaming), I am happy to keep option 2, but if you just move everything to the GPU and then apply the operator, probably having options 1 and 3 is enough, and more in line with the rest of the PyLops operators?
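
To make this concrete, a rough sketch of that pattern, with pylops.MatrixMult standing in for a per-angle-block CT operator (in the real case each block would be the ASTRA-based operator built with its own subset of angles):

import numpy as np
import cupy as cp
import pylops

n, nblocks = 16, 4
rng = np.random.default_rng(0)

blocks = []
for _ in range(nblocks):
    A = cp.asarray(rng.standard_normal((n, n)), dtype="float32")
    Op = pylops.MatrixMult(A, dtype="float32")       # GPU operator for one block of angles
    Tin = pylops.ToCupy(Op.dims, dtype="float32")    # host -> device before the block
    Tout = pylops.ToCupy(Op.dimsd, dtype="float32")  # device -> host after the block
    blocks.append(Tout.H @ Op @ Tin)

Cop = pylops.VStack(blocks)            # full operator, applied block by block on the GPU
y = Cop @ np.ones(n, dtype="float32")  # x and y stay NumPy arrays on the host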

With this, I have the following observations/questions:

  1. The adjoint of the CUDA backend of ASTRA is much less well matched to the forward operator, so the tests produce >10% relative error. We do expect the adjoint to be mismatched, but this error is still way too high; I will look into it on the ASTRA side. In the meantime, the tests with the CUDA backend are skipped.

Well, in general we strive for forward-adjoint pairs that pass the dot test - quite tightly for fp64 (say atol=1e-6) and of course much less so for fp32... If you know that your operator isn't perfectly matched, we can still add it and write a test with the threshold you expect; I would prefer that to skipping (or, if you do skip it, I would at least like the tests to run both the forward and the adjoint somehow, to make sure they at least run 😄).
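
For instance, something along these lines - a sketch only, with pylops.Diagonal standing in for the CT operator and assuming the rtol/raiseerror keywords of pylops.utils.dottest:

import numpy as np
import pylops
from pylops.utils import dottest

Op = pylops.Diagonal(np.arange(1, 33, dtype="float32"))  # stand-in for the CT operator under test
dottest(Op, *Op.shape, rtol=1e-1, raiseerror=True)       # relaxed threshold instead of skipping the test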

  2. Do we need to assert that the input is on CPU for cpu engine and on GPU for cuda engine?

See above; in general we do not assert, but we have this as a convention. In other words, if one expects to pass a CuPy array, the other input parameters of the operator should also be passed as CuPy arrays (tiny example below) - though I don't think this really applies to CT2D, as you dispatch everything to the ASTRA operator...
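
The convention in code, using Diagonal as a small example (data and operator parameter both on the GPU):

import cupy as cp
import pylops

d = cp.ones(10, dtype="float32")
Dop = pylops.Diagonal(d)                # operator parameter is a CuPy array...
y = Dop @ cp.ones(10, dtype="float32")  # ...and so is the input; the output stays on the GPU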

  3. In principle, this operator almost works with JAX on both CPU and GPU, except for making sure arrays end up on the appropriate device. What is currently the ambition for adding JAX support to PyLops operators?

We do not aim at 100% coverage with JAX - see this table: https://pylops.readthedocs.io/en/latest/gpu.html#supported-operators. So if you can support it, great; if not, we just need to make sure the row for CT2D is updated accordingly.

  4. The default projector_type now (strip) favors accuracy at the expense of performance. For experimental data the accuracy difference is usually not noticeable, so the faster linear projector is more commonly used; it is also the only projector available in the CUDA backend. Let me know which you think fits the expected audience better.

Mmh I guess this was me making this choice... I trust your judgement of what you think is best. Maybe we can just follow the ASTRA default?

  5. ASTRA only supports the float32 dtype internally at the moment. How should we handle that?

I think this is fine. I would suggest we do something like this:

xdtype = x.dtype
...
# cast the float32 result from ASTRA back to the input dtype
y = y.astype(xdtype)
return y

so that when we chain operators and someone wants to use fp64, we don't break the chain by all of a sudden passing out fp32. We can add a note to the docs saying that even if you pass fp64, the internal operations are done in fp32 (see the small illustration below).
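
A tiny NumPy-only illustration of the idea (the scaling below just stands in for the float32 ASTRA projection):

import numpy as np

def forward(x):
    xdtype = x.dtype
    y = 2.0 * x.astype("float32")  # stand-in for the internal fp32 ASTRA call
    return y.astype(xdtype)        # cast back so chained operators keep their dtype

print(forward(np.ones(8)).dtype)   # float64 in, float64 out, fp32 internally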

@mrava87 (Collaborator) left a comment

Very nice addition!

Overall the code changes look great, just left a few minor suggestions 😄

@askorikov force-pushed the update-astra-integration branch from baa4d61 to 0a5666d on October 22, 2025 14:59
@askorikov marked this pull request as draft on October 22, 2025 15:05
* Separate basic functionality and adjointness tests, disable adjointness tests for CUDA backend of ASTRA
* Add tests for data that is not natively compatible with ASTRA
* Remove projector_type tests (too messy for the purpose, and CUDA backend doesn't support them at all)
* Use fixture pattern
* Probably related to float32 dtype used by ASTRA
2.2+ is needed for NumPy 2 support, 2.3+ for GPU data exchange support
@askorikov (Author)

A couple more questions:

  1. I've added casting of the output to the dtype of the input (potentially with a warning), but what are the semantics of the dtype argument in __init__? Does it need to force the dtype as well?
  2. I forgot that for GPU input we also require the array to be contiguous. I have now added this check, but while testing I discovered that the .dot method of LinearOperator already (implicitly) makes the input contiguous here, potentially at the expense of copying the array:
    x = x.ravel()
    Is this intended behavior? If so, I can remove the redundant check.
  3. I guess it would be nice to set engine="cuda" when using CuPy. Is checking pylops.utils.deps.cupy_enabled a good way to determine this default?
  4. Is JAX with CPU backend relevant, or do you expect mostly GPU?

@mrava87 (Collaborator) commented on Oct 24, 2025

A couple more questions:

  1. I've added casting of the output to the dtype of the input (potentially with a warning), but what are the semantics of the dtype argument in __init__? Does it need to force the dtype as well?

So dtype is there for a bit of a historical reason... when we started PyLops we were subclassing from scipy.sparse.linalg.LinearOperator, and there dtype was mandatory. In principle we use it to keep track of the overall dtype of an operator when we chain/combine multiple operators, e.g.:

import numpy as np
from pylops import Diagonal

d = np.ones(10)
D1 = Diagonal(d.astype(np.float32), dtype="float32")
D2 = Diagonal(d*2, dtype="float64")
D = D1 @ D2
print(D)
> <10x10 _ProductLinearOperator with dtype=float64>

but in practice we are not so strict, in the sense that the example below does not really respect the dtype of the operator...

d = np.ones(10, dtype="float32")
D1 = Diagonal(d, dtype="float32")
D2 = Diagonal(d*2, dtype="float32")
D = D1 @ D2

x = np.ones(10, dtype="float64")
y = D @ x

print(D, y.dtype)
> <10x10 _ProductLinearOperator with dtype=float32>
> dtype('float64')

I have always been tempted to make this stricter, but i) it would require a major version bump, and ii) it would require some conventions to be set (i.e., if the input and operator dtypes do not match, which one wins), or being very strict and raising errors whenever there is a mismatch... so I am not sure at this point 😉

  2. I forgot that for GPU input we also require the array to be contiguous. I have now added this check, but while testing I discovered that the .dot method of LinearOperator already (implicitly) makes the input contiguous here, potentially at the expense of copying the array:
    x = x.ravel()
    Is this intended behavior? If so, I can remove the redundant check.

Yes and no. dot is invoked if you do D @ x, but one could also call D.matvec(x) directly (we tend to do that inside solvers, for example, as it avoids a bunch of checks that make dot a bit slower than just invoking matvec/rmatvec directly). So if you really need contiguous arrays, I would suggest keeping your internal check 😄 - a minimal sketch of such a check is below.
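
A minimal sketch of such a check (hypothetical helper, CuPy input assumed):

import cupy as cp

def _ensure_contiguous(x):
    # matvec/rmatvec can be called directly, bypassing dot()'s ravel(),
    # so contiguity cannot be assumed; copy only when actually needed
    return x if x.flags.c_contiguous else cp.ascontiguousarray(x)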

  3. I guess it would be nice to set engine="cuda" when using CuPy. Is checking pylops.utils.deps.cupy_enabled a good way to determine this default?

You mean if a user does not set engine="cuda" but then passes a CuPy array? We tend to avoid outsmarting users; we expect users to be smart... so if they do something silly, I would rather an error be raised than have something changed under the hood for them. pylops.utils.deps.cupy_enabled (and similar flags) is something we only use to check whether a library is present and import it (or import some method) - see the snippet below - not for the check you want to do (if I understand correctly what you want to do).
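
For reference, the typical use of that flag is just to gate an optional import, roughly:

from pylops.utils import deps

if deps.cupy_enabled:
    import cupy as cp  # only imported when CuPy is actually installed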

  4. Is JAX with CPU backend relevant, or do you expect mostly GPU?

Yeah, why not... we see JAX as a replacement for the numpy/cupy pair, since JAX claims to be a library where you write the code once and run it on different hardware 😉

Hope this makes sense?
