Added generic fallback method to to_device
#2362
to_device
# Generic fallback for other types that might need device adaptation
function to_device(device::ClimaComms.AbstractDevice, x)
    return Adapt.adapt(ClimaComms.array_type(device), x)
end
The code for this is identical to what is written above, so it doesn't make sense to add this if everything goes through the same thing anyway.
Also, I don't think this fallback should be added, since there could be a correctness issue if Adapt.adapt didn't throw an error, but the object isn't meant to be put onto the GPU.
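One way this concern can show up is sketched below, assuming only the Adapt package. The adaptor `ToFloat32` and the structs `Opaque` and `Adaptable` are hypothetical names made up for illustration (they are not part of ClimaCore or ClimaComms), and the Float32 conversion merely stands in for a move to a device array type so that the effect is visible without a GPU.

```julia
using Adapt

# Hypothetical adaptor standing in for a device array type: it converts
# floating-point arrays to Float32 so the effect is visible without a GPU.
struct ToFloat32 end
Adapt.adapt_storage(::ToFloat32, x::AbstractArray{<:AbstractFloat}) = Float32.(x)

# Hypothetical wrapper with NO adapt rule defined.
struct Opaque{A}
    data::A
end

# Hypothetical wrapper that opts in to adaptation of its fields.
struct Adaptable{A}
    data::A
end
Adapt.@adapt_structure Adaptable

opaque = Adapt.adapt(ToFloat32(), Opaque([1.0, 2.0]))
adaptable = Adapt.adapt(ToFloat32(), Adaptable([1.0, 2.0]))

# No error is thrown for Opaque, but its array was left untouched.
println(eltype(opaque.data))     # Float64
println(eltype(adaptable.data))  # Float32
```

In this sketch the call on `Opaque` succeeds and returns the object unchanged, so a generic fallback built on `Adapt.adapt` would not signal that nothing was moved.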
The code for this is identical to what is written above, so it doesn't make sense to add this if everything goes through the same thing anyway.
Well, not really. What's written above limits to_device to a union of certain ClimaCore data structures. We could either extend to_device via a generic fallback (as in my PR), or remove the restrictions, if we want to_device to be more versatile.
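The difference between the two options can be sketched with plain Julia dispatch. All type and function names below (`FieldLike`, `SpaceLike`, `ParamsLike`, `restricted`, `with_fallback`) are hypothetical stand-ins, not the real ClimaCore types or the real to_device methods.

```julia
# Hypothetical stand-ins; none of these are the real ClimaCore types.
struct FieldLike
    values::Vector{Float64}
end
struct SpaceLike
    npoints::Int
end
struct ParamsLike
    grav::Float64
end

# Restricted method: only the listed types are accepted.
restricted(x::Union{FieldLike, SpaceLike}) = :adapted

# The same restricted method plus a generic fallback, as the PR proposes.
with_fallback(x::Union{FieldLike, SpaceLike}) = :adapted
with_fallback(x) = :adapted_by_fallback

println(restricted(FieldLike([1.0])))             # adapted
println(applicable(restricted, ParamsLike(9.81))) # false: a call would throw a MethodError
println(with_fallback(ParamsLike(9.81)))          # adapted_by_fallback
```

With the union-only method, an unsupported type fails loudly at the call site; with the fallback, every type is accepted.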
Also, I don't think this fallback should be added, since there could be correctness issue if Adapt.adapt didn't throw an error, but the object isn't meant to be put onto the GPU.
I suppose you are correct to be more cautious, but I submitted this PR due to an issue I faced. Specifically, the full orographic gravity wave pipeline requires loading in an external orography file and doing some preprocessing analysis on this dataset. These steps are done on the CPU.
Now, in one of the tests, see for example the link below, I explicitly move the arrays initialised on the CPU to the GPU via to_device. These GPU arrays are then used in the GPU orographic gravity wave parameterization. One of the obstacles was that the existing to_device would not work to move instances of ThermodynamicsParameters to the GPU, and this PR resolves that issue.
I understand that my integral test is not very idiomatic Clima, but that is because I wanted to integrate CPU preprocessing with GPU ClimaAtmos computations into one integral test. But if you have a better solution to this problem, please let me know! :)
(Aside: @ray-chew We should drop support for the Fields.bycolumn usage in the linked test above; I also think we may be able to replace interp_latlong2cg with the SpaceVaryingInput utility.)
(Aside: @ray-chew We should drop support for the Fields.bycolumn usage in the linked test above; I also think we may be able to replace interp_latlong2cg with the SpaceVaryingInput utility.)
Yes, the whole bycolumn, parent, etc. part of the test is the CPU part I mentioned in my reply to Kevin, which is the reason for me moving between host and device with to_device.
In ClimaAtmos #3867 point 3, I mentioned that we should move all of this to GPU-friendly code. However, if we can already use the existing machinery in a preprocessing step, the cost of refactoring it right now is too high for too little benefit.
@ray-chew @ph-kev and I are a bit confused by this. The thermo_params shouldn't need to be moved to the device because it is already isbits.
For example:
julia> function bar(x, s)
           x + s.grav
       end
bar (generic function with 1 method)
julia> params = ClimaLand.Parameters.LandParameters(Float32).thermo_params
Thermodynamics.Parameters.ThermodynamicsParameters{Float32}(273.16f0, 101325.0f0, 100000.0f0, 1859.0f0, 4181.0f0, 2100.0f0, 2.5008f6, 2.8344f6, 611.657f0, 273.16f0, 273.15f0, 1.0f0, 1000.0f0, 150.0f0, 298.15f0, 6864.8f0, 10513.6f0, 0.2857143f0, 8.31446f0, 0.02897f0, 0.01801528f0, 290.0f0, 220.0f0, 9.81f0, 233.0f0, 1.0f0)
julia> c = CUDA.cu([1.0,2.0])
2-element CuArray{Float32, 1, CUDA.DeviceMemory}:
 1.0
 2.0
julia> bar.(c, params)
2-element CuArray{Float32, 1, CUDA.DeviceMemory}:
 10.81
 11.81
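The isbits property can be checked on the CPU with Base Julia alone. The sketch below uses two hypothetical structs, `FlatParams` and `ArrayParams`, invented for illustration; they are not the real ThermodynamicsParameters type, and the field names are assumptions.

```julia
# Hypothetical parameter struct with only plain numeric fields.
struct FlatParams{FT}
    grav::FT
    T_freeze::FT
end

# Hypothetical struct that also holds a heap-allocated CPU array.
struct ArrayParams{FT}
    grav::FT
    table::Vector{FT}
end

flat = FlatParams{Float32}(9.81f0, 273.15f0)
heavy = ArrayParams{Float32}(9.81f0, Float32[1, 2, 3])

println(isbits(flat))   # true: immutable, and every field is a plain Float32
println(isbits(heavy))  # false: the Vector field is a reference to CPU memory
```

A struct like `FlatParams` carries no references to host memory, which is why a parameter struct of that shape should not need any device adaptation; a struct like `ArrayParams` is the kind that would.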
@ph-kev @imreddyTeja: Confirming that I can replicate your results where c is a ClimaCore.Fields.Field (VIJFH layout) and params has the same type as above. @ray-chew and I looked over the test setup where this issue popped up. It turns out it was related to the "Broadcast space mismatch" error: the test problem involves computing subgrid variables on the CPU on a lat-long grid given some source dataset, and then moving them to the GPU. Inconsistent spaces meant that to_device was being used as a somewhat hacky solution. A better solution seems to be to use Fields.Field(Fields.field_values(x), S) and to ensure that the target space S is always identical (thermo_params don't need additional manipulation). His test case now runs on GPU following this change. We can discuss this further, but I'm closing this issue for now.
Thanks @ray-chew @ph-kev @imreddyTeja.
This allows for to_device to work with generic data structures. Specifically, this extension allows for the following use case:
where params is a parameter struct.