Enable CUDA build for ARM64 #2352
Pull Request Overview
This PR adds support for CUDA-enabled Docker images on ARM64 (aarch64) architecture to support platforms like NVIDIA DGX Spark. It introduces two new workflow jobs for building TensorFlow and PyTorch notebooks with CUDA support on aarch64 and adds these jobs to the dependency chain in the tag-push workflow.
- Adds aarch64-tensorflow-cuda job using CUDA variant
- Adds aarch64-pytorch-cuda12 job using CUDA 12 variant
- Updates tag-push workflow dependencies to include the new aarch64 CUDA jobs
Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.
image: ${{ inputs.image }}
variant: ${{ inputs.variant }}
platform: aarch64
if: ${{ !contains(inputs.variant, 'cuda') }}
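To make the effect of the `if:` condition above concrete, here is a minimal Python sketch of how it evaluates. It assumes only the documented behaviour of the GitHub Actions `contains` function on strings (a case-insensitive substring test). The variant names in the loop are illustrative assumptions, not taken from the workflow files.

```python
def contains(search: str, item: str) -> bool:
    """Model of GitHub Actions `contains` for two strings:
    a case-insensitive substring test."""
    return item.lower() in search.lower()


def step_runs(variant: str) -> bool:
    """Mirror of `if: ${{ !contains(inputs.variant, 'cuda') }}`:
    the step runs only when the variant name does not mention cuda."""
    return not contains(variant, "cuda")


if __name__ == "__main__":
    # Hypothetical variant names, used only for illustration.
    for variant in ["default", "cuda", "cuda11", "cuda12"]:
        print(f"{variant}: step runs = {step_runs(variant)}")
```

With this condition in place, every variant whose name contains `cuda` skips the aarch64 step, which is why removing it affects all CUDA variants at once.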
You either have to make it work for pytorch-cuda11 or not download it here; otherwise it will fail.
Added pytorch-cuda11 to keep things easier to read (and people using GH100 can use this image 😀).
@dictcp it seems we can't build it, so you need to revert the last commit and make an exception here.
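One way to read "make an exception here" is to keep the aarch64 step for every image except the single combination that cannot be built. The Python sketch below models that rule. The image name `pytorch-notebook`, the variant name `cuda11`, and the helper `should_download` are assumptions for illustration; the thread does not show the actual fix.

```python
# Assumption: the only aarch64 combination that cannot be built is the
# PyTorch image with the CUDA 11 variant. Both names are illustrative.
UNBUILDABLE_ON_AARCH64 = {("pytorch-notebook", "cuda11")}


def should_download(image: str, variant: str) -> bool:
    """Return True when the aarch64 artifact for (image, variant)
    is expected to exist and can therefore be downloaded."""
    return (image, variant) not in UNBUILDABLE_ON_AARCH64


# A possible equivalent workflow condition (also an assumption):
#   if: ${{ !(inputs.image == 'pytorch-notebook' && inputs.variant == 'cuda11') }}

if __name__ == "__main__":
    for image, variant in [
        ("pytorch-notebook", "cuda11"),
        ("pytorch-notebook", "cuda12"),
        ("tensorflow-notebook", "cuda"),
    ]:
        print(f"{image}/{variant}: download = {should_download(image, variant)}")
```

Compared with the blanket `cuda` substring check, an explicit exclusion skips only the combination that fails and leaves the other aarch64 CUDA images in the pipeline.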
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
@mathbunnyru I think you need to change the first item in the checklist from
to
Enable CUDA build for ARM64
Since there are ARM64 platforms such as DGX Spark, we need a CUDA build for ARM64.
(Taking DGX Spark as the example: it requires CUDA compute capability 12.1, so the latest CUDA is fine.)
Checklist (especially for first-time contributors)