Skip to content

Commit 9896991

Browse files
committed
Fix NvvmArch::all_target_features bugs.
It now does a single filter pass over the enum variants, which is simpler and fixes the sorting issue and the incorrect 'f' suffix results. I removed some comments in the `nvvm_arch_all_target_features` test, because they were low-value. There are now better comments within `all_target_features` that explain what's happening. I also removed the comment about PTX forward-compatibility. It was correct but confusing. This function answers the question "what features are available if I'm targeting a particular NvvmArch?" (backwards compatibility). That comment explained "what GPU CCs will this run on?" (forward compatibility). Also update the relevant section in the guide, where the 'f' details were incorrect. And make the terminology more consistent.
1 parent 6020b60 commit 9896991

File tree

2 files changed

+134
-99
lines changed

2 files changed

+134
-99
lines changed

crates/nvvm/src/lib.rs

Lines changed: 121 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -430,73 +430,62 @@ impl NvvmArch {
430430
}
431431
}
432432

433-
/// Get all target features up to and including this architecture.
433+
/// Gets all target features up to and including this architecture. This effectively answers
434+
/// the question "for a given compilation target, what architectural features can be used?"
434435
///
435-
/// # PTX Forward-Compatibility Rules (per NVIDIA documentation):
436+
/// # Examples
436437
///
437-
/// - **No suffix** (compute_XX): PTX is forward-compatible across all future architectures.
438-
/// Example: compute_70 runs on CC 7.0, 8.x, 9.x, 10.x, 12.x, and all future GPUs.
438+
/// ```
439+
/// # use nvvm::NvvmArch;
440+
/// let features = NvvmArch::Compute53.all_target_features();
441+
/// assert_eq!(
442+
/// features,
443+
/// vec!["compute_35", "compute_37", "compute_50", "compute_52", "compute_53"]
444+
/// );
445+
/// ```
439446
///
440-
/// - **Family-specific 'f' suffix** (compute_XXf): Forward-compatible within the same major
441-
/// version family. Supports devices with same major CC and equal or higher minor CC.
442-
/// Example: compute_100f runs on CC 10.0, 10.3, and future 10.x devices, but NOT on 11.x.
443-
///
444-
/// - **Architecture-specific 'a' suffix** (compute_XXa): The code only runs on GPUs of that
445-
/// specific CC and no others. No forward or backward compatibility whatsoever.
446-
/// These features are primarily related to Tensor Core programming.
447-
/// Example: compute_100a ONLY runs on CC 10.0, not on 10.3, 10.1, 9.0, or any other version.
447+
/// # External resources
448448
///
449449
/// For more details on family and architecture-specific features, see:
450450
/// <https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/>
451451
pub fn all_target_features(&self) -> Vec<String> {
452-
let mut features: Vec<String> = if self.is_architecture_variant() {
453-
// 'a' variants: include all available instructions for the architecture
454-
// This means: all base variants up to same version, all 'f' variants with same major and <= minor, plus itself
455-
let base_features: Vec<String> = NvvmArch::iter()
456-
.filter(|arch| {
457-
arch.is_base_variant() && arch.capability_value() <= self.capability_value()
458-
})
459-
.map(|arch| arch.target_feature())
460-
.collect();
461-
462-
let family_features: Vec<String> = NvvmArch::iter()
463-
.filter(|arch| {
464-
arch.is_family_variant()
465-
&& arch.major_version() == self.major_version()
466-
&& arch.minor_version() <= self.minor_version()
467-
})
468-
.map(|arch| arch.target_feature())
469-
.collect();
452+
// All lower-or-equal baseline features are included.
453+
let included_baseline = |arch: &NvvmArch| {
454+
arch.is_base_variant() && arch.capability_value() <= self.capability_value()
455+
};
470456

471-
base_features
472-
.into_iter()
473-
.chain(family_features)
474-
.chain(std::iter::once(self.target_feature()))
457+
// All lower-or-equal-with-same-major-version family features are included.
458+
let included_family = |arch: &NvvmArch| {
459+
arch.is_family_variant()
460+
&& arch.major_version() == self.major_version()
461+
&& arch.minor_version() <= self.minor_version()
462+
};
463+
464+
if self.is_architecture_variant() {
465+
// Architecture-specific ('a' suffix) features include:
466+
// - all lower-or-equal baseline features
467+
// - all lower-or-equal-with-same-major-version family features
468+
// - itself
469+
NvvmArch::iter()
470+
.filter(|arch| included_baseline(arch) || included_family(arch) || arch == self)
471+
.map(|arch| arch.target_feature())
475472
.collect()
476473
} else if self.is_family_variant() {
477-
// 'f' variants: same major version with equal or higher minor version
474+
// Family-specific ('f' suffix) features include:
475+
// - all lower-or-equal baseline features
476+
// - all lower-or-equal-with-same-major-version family features
478477
NvvmArch::iter()
479-
.filter(|arch| {
480-
// Include base variants with same major and >= minor version
481-
arch.is_base_variant()
482-
&& arch.major_version() == self.major_version()
483-
&& arch.minor_version() >= self.minor_version()
484-
})
478+
.filter(|arch| included_baseline(arch) || included_family(arch))
485479
.map(|arch| arch.target_feature())
486-
.chain(std::iter::once(self.target_feature())) // Add the 'f' variant itself
487480
.collect()
488481
} else {
489-
// Base variants: all base architectures from lower or equal versions
482+
// Baseline (no suffix) features include:
483+
// - all lower-or-equal baseline features
490484
NvvmArch::iter()
491-
.filter(|arch| {
492-
arch.is_base_variant() && arch.capability_value() <= self.capability_value()
493-
})
485+
.filter(|arch| included_baseline(arch))
494486
.map(|arch| arch.target_feature())
495487
.collect()
496-
};
497-
498-
features.sort();
499-
features
488+
}
500489
}
501490

502491
/// Create an iterator over all architectures from Compute35 up to and including this one
@@ -777,19 +766,16 @@ mod tests {
777766
fn nvvm_arch_all_target_features() {
778767
use crate::NvvmArch;
779768

780-
// Compute35 only includes itself
781769
assert_eq!(
782770
NvvmArch::Compute35.all_target_features(),
783771
vec!["compute_35"]
784772
);
785773

786-
// Compute50 includes all lower base capabilities
787774
assert_eq!(
788775
NvvmArch::Compute50.all_target_features(),
789776
vec!["compute_35", "compute_37", "compute_50"],
790777
);
791778

792-
// Compute61 includes all lower base capabilities
793779
assert_eq!(
794780
NvvmArch::Compute61.all_target_features(),
795781
vec![
@@ -803,7 +789,6 @@ mod tests {
803789
]
804790
);
805791

806-
// Compute70 includes all lower base capabilities
807792
assert_eq!(
808793
NvvmArch::Compute70.all_target_features(),
809794
vec![
@@ -819,7 +804,6 @@ mod tests {
819804
]
820805
);
821806

822-
// Compute90 includes lower base capabilities
823807
let compute90_features = NvvmArch::Compute90.all_target_features();
824808
assert_eq!(
825809
compute90_features,
@@ -843,9 +827,6 @@ mod tests {
843827
]
844828
);
845829

846-
// Test 'a' variant - includes all available instructions for the architecture.
847-
// This means: all base variants up to same version, no 'f' variants (90 has none), and the
848-
// 'a' variant.
849830
assert_eq!(
850831
NvvmArch::Compute90a.all_target_features(),
851832
vec![
@@ -869,14 +850,9 @@ mod tests {
869850
]
870851
);
871852

872-
// Test compute100a - should include base variants up to 100, and 100f, and itself,
873-
// but NOT 101f or 103f (higher minor).
874853
assert_eq!(
875854
NvvmArch::Compute100a.all_target_features(),
876855
vec![
877-
"compute_100",
878-
"compute_100a",
879-
"compute_100f",
880856
"compute_35",
881857
"compute_37",
882858
"compute_50",
@@ -893,26 +869,39 @@ mod tests {
893869
"compute_87",
894870
"compute_89",
895871
"compute_90",
872+
"compute_100",
873+
"compute_100f",
874+
"compute_100a",
896875
]
897876
);
898877

899-
// Test 'f' variant with 100f
900878
assert_eq!(
901879
NvvmArch::Compute100f.all_target_features(),
902-
// FIXME: this is wrong
903-
vec!["compute_100", "compute_100f", "compute_101", "compute_103"]
880+
vec![
881+
"compute_35",
882+
"compute_37",
883+
"compute_50",
884+
"compute_52",
885+
"compute_53",
886+
"compute_60",
887+
"compute_61",
888+
"compute_62",
889+
"compute_70",
890+
"compute_72",
891+
"compute_75",
892+
"compute_80",
893+
"compute_86",
894+
"compute_87",
895+
"compute_89",
896+
"compute_90",
897+
"compute_100",
898+
"compute_100f",
899+
]
904900
);
905901

906-
// Test compute101a - should include base variants up to 101, and 100f and 101f, and
907-
// itself, but not 103f (higher minor)
908902
assert_eq!(
909903
NvvmArch::Compute101a.all_target_features(),
910904
vec![
911-
"compute_100",
912-
"compute_100f",
913-
"compute_101",
914-
"compute_101a",
915-
"compute_101f",
916905
"compute_35",
917906
"compute_37",
918907
"compute_50",
@@ -929,22 +918,43 @@ mod tests {
929918
"compute_87",
930919
"compute_89",
931920
"compute_90",
921+
"compute_100",
922+
"compute_100f",
923+
"compute_101",
924+
"compute_101f",
925+
"compute_101a",
932926
]
933927
);
934928

935-
// Test 'f' variant with 101f
936929
assert_eq!(
937930
NvvmArch::Compute101f.all_target_features(),
938-
vec!["compute_101", "compute_101f", "compute_103"],
931+
vec![
932+
"compute_35",
933+
"compute_37",
934+
"compute_50",
935+
"compute_52",
936+
"compute_53",
937+
"compute_60",
938+
"compute_61",
939+
"compute_62",
940+
"compute_70",
941+
"compute_72",
942+
"compute_75",
943+
"compute_80",
944+
"compute_86",
945+
"compute_87",
946+
"compute_89",
947+
"compute_90",
948+
"compute_100",
949+
"compute_100f",
950+
"compute_101",
951+
"compute_101f",
952+
]
939953
);
940954

941955
assert_eq!(
942956
NvvmArch::Compute120.all_target_features(),
943957
vec![
944-
"compute_100",
945-
"compute_101",
946-
"compute_103",
947-
"compute_120",
948958
"compute_35",
949959
"compute_37",
950960
"compute_50",
@@ -961,24 +971,43 @@ mod tests {
961971
"compute_87",
962972
"compute_89",
963973
"compute_90",
974+
"compute_100",
975+
"compute_101",
976+
"compute_103",
977+
"compute_120",
964978
]
965979
);
966980

967981
assert_eq!(
968982
NvvmArch::Compute120f.all_target_features(),
969-
// FIXME: this is wrong
970-
vec!["compute_120", "compute_120f", "compute_121"]
971-
);
972-
973-
assert_eq!(
974-
NvvmArch::Compute120a.all_target_features(),
975983
vec![
984+
"compute_35",
985+
"compute_37",
986+
"compute_50",
987+
"compute_52",
988+
"compute_53",
989+
"compute_60",
990+
"compute_61",
991+
"compute_62",
992+
"compute_70",
993+
"compute_72",
994+
"compute_75",
995+
"compute_80",
996+
"compute_86",
997+
"compute_87",
998+
"compute_89",
999+
"compute_90",
9761000
"compute_100",
9771001
"compute_101",
9781002
"compute_103",
9791003
"compute_120",
980-
"compute_120a",
9811004
"compute_120f",
1005+
]
1006+
);
1007+
1008+
assert_eq!(
1009+
NvvmArch::Compute120a.all_target_features(),
1010+
vec![
9821011
"compute_35",
9831012
"compute_37",
9841013
"compute_50",
@@ -995,6 +1024,12 @@ mod tests {
9951024
"compute_87",
9961025
"compute_89",
9971026
"compute_90",
1027+
"compute_100",
1028+
"compute_101",
1029+
"compute_103",
1030+
"compute_120",
1031+
"compute_120f",
1032+
"compute_120a",
9981033
]
9991034
);
10001035
}

guide/src/guide/compute_capabilities.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ CudaBuilder::new("kernels")
7474
.unwrap();
7575

7676
// In your kernel code:
77-
#[cfg(target_feature = "compute_60")] // ✓ Pass (older compute capability)
78-
#[cfg(target_feature = "compute_70")] // ✓ Pass (current compute capability)
79-
#[cfg(target_feature = "compute_80")] // ✗ Fail (newer compute capability)
77+
#[cfg(target_feature = "compute_60")] // ✓ Pass (lower base variant)
78+
#[cfg(target_feature = "compute_70")] // ✓ Pass (this base variant)
79+
#[cfg(target_feature = "compute_80")] // ✗ Fail (higher base variant)
8080
```
8181

8282
### Family Suffix ('f')
@@ -99,13 +99,13 @@ CudaBuilder::new("kernels")
9999
.unwrap();
100100

101101
// In your kernel code:
102-
#[cfg(target_feature = "compute_100")] // ✗ Fail (10.0 < 10.1)
103-
#[cfg(target_feature = "compute_101")] // ✓ Pass (equal major, equal minor)
104-
#[cfg(target_feature = "compute_103")] // ✓ Pass (equal major, greater minor)
102+
#[cfg(target_feature = "compute_90")] // ✓ Pass (lower base variant)
103+
#[cfg(target_feature = "compute_100")] // ✓ Pass (lower base variant)
104+
#[cfg(target_feature = "compute_100f")] // ✓ Pass (lower 'f' variant)
105+
#[cfg(target_feature = "compute_101")] // ✓ Pass (this base variant)
105106
#[cfg(target_feature = "compute_101f")] // ✓ Pass (the 'f' variant itself)
106-
#[cfg(target_feature = "compute_100f")] // ✗ Fail (other 'f' variant)
107-
#[cfg(target_feature = "compute_90")] // ✗ Fail (different major)
108-
#[cfg(target_feature = "compute_110")] // ✗ Fail (different major)
107+
#[cfg(target_feature = "compute_103")] // ✗ Fail (higher base variant)
108+
#[cfg(target_feature = "compute_110")] // ✗ Fail (higher base variant)
109109
```
110110

111111
### Architecture Suffix ('a')
@@ -130,12 +130,12 @@ CudaBuilder::new("kernels")
130130
.unwrap();
131131

132132
// In your kernel code:
133-
#[cfg(target_feature = "compute_100a")] // ✓ Pass (the 'a' variant itself)
134-
#[cfg(target_feature = "compute_100")] // ✓ Pass (base variant)
135133
#[cfg(target_feature = "compute_90")] // ✓ Pass (lower base variant)
134+
#[cfg(target_feature = "compute_100")] // ✓ Pass (this base variant)
136135
#[cfg(target_feature = "compute_100f")] // ✓ Pass (family variant with same major/minor)
137-
#[cfg(target_feature = "compute_101f")] // ✗ Fail (family variant with higher minor)
138-
#[cfg(target_feature = "compute_110")] // ✗ Fail (higher major version)
136+
#[cfg(target_feature = "compute_100a")] // ✓ Pass (the 'a' variant itself)
137+
#[cfg(target_feature = "compute_101f")] // ✗ Fail (higher family variant)
138+
#[cfg(target_feature = "compute_110")] // ✗ Fail (higher base variant)
139139
```
140140

141141
Note: While the 'a' variant enables all these features during compilation (allowing you to use all available instructions), the generated PTX code will still only run on the exact GPU architecture specified.

0 commit comments

Comments
 (0)