Fixes #780: Decoupled device and dtype selections#800
Fixes #780: Decoupled device and dtype selections#800Varun-sai-500 wants to merge 2 commits intokornia:mainfrom
Conversation
|
@sidd-27 Can you please review it, if anything is missing, I will work on it to make it better, thanks! |
sidd-27
left a comment
There was a problem hiding this comment.
Review: Incomplete Implementation
While this PR addresses a legitimate architectural issue, the implementation is incomplete and doesn't deliver what's promised.
✅ What Works
- Successfully decouples device and dtype selection into separate functions
- Clean separation of concerns
❌ Critical Issues
1. Missing Compute Capability Checks
The PR description claims to add "checks where it will check with compute compatibility" but this is not implemented. The code still only checks if device.is_cuda() - the same level of hardware awareness as before.
2. Undocumented BF16 → F16 Change
The code silently switches from DType::BF16 to DType::F16 for CUDA devices without explanation. This:
- Isn't mentioned in the PR description
- May impact model accuracy/performance
- Doesn't solve the original problem (not all GPUs support F16 efficiently either)
3. No Hardware Capability Detection
Issue #780 specifically requests runtime capability checks for GPU compatibility. The new code still assumes all CUDA devices support F16.
Required Changes
The select_dtype function needs actual hardware capability detection:
fn select_dtype(device: &Device) -> DType {
#[cfg(feature = "cuda")]
if device.is_cuda() {
// Check actual compute capability
if let Ok(capability) = device.compute_capability() {
if capability >= (8, 0) { // Ampere+ supports BF16
return DType::BF16;
} else if capability >= (5, 3) { // Maxwell+ supports F16
return DType::F16;
}
}
return DType::F32; // Fallback for older GPUs
}
DType::F32
}Additional Requirements
- Document the BF16 → F16 change and rationale
- Add tests for different GPU capabilities
- Update PR description to match actual implementation
This creates good architectural foundation but needs the core feature - hardware capability detection - to be truly useful.
📝 Description
Fixes/Relates to: #780
🛠️ Changes Made
Incorrect hardware assumption as not all CUDA GPUs support BF16, So I have added checks where it will check with compute compatibility, so it will check before randomly accomodating bf16 to cuda devices
Overall I have added clean code architecture to the device selection and dtype selection
🧪 How Was This Tested?
🕵️ AI Usage Disclosure
Check one of the following:
🚦 Checklist
💭 Additional Context
Add any other context or screenshots about the pull request here.