@@ -433,73 +433,62 @@ impl NvvmArch {
433433 }
434434 }
435435
436- /// Get all target features up to and including this architecture.
436+ /// Gets all target features up to and including this architecture. This effectively answers
437+ /// the question "for a given compilation target, what architectural features can be used?"
437438 ///
438- /// # PTX Forward-Compatibility Rules (per NVIDIA documentation):
439+ /// # Examples
439440 ///
440- /// - **No suffix** (compute_XX): PTX is forward-compatible across all future architectures.
441- /// Example: compute_70 runs on CC 7.0, 8.x, 9.x, 10.x, 12.x, and all future GPUs.
441+ /// ```
442+ /// # use nvvm::NvvmArch;
443+ /// let features = NvvmArch::Compute53.all_target_features();
444+ /// assert_eq!(
445+ /// features,
446+ /// vec!["compute_35", "compute_37", "compute_50", "compute_52", "compute_53"]
447+ /// );
448+ /// ```
442449 ///
443- /// - **Family-specific 'f' suffix** (compute_XXf): Forward-compatible within the same major
444- /// version family. Supports devices with same major CC and equal or higher minor CC.
445- /// Example: compute_100f runs on CC 10.0, 10.3, and future 10.x devices, but NOT on 11.x.
446- ///
447- /// - **Architecture-specific 'a' suffix** (compute_XXa): The code only runs on GPUs of that
448- /// specific CC and no others. No forward or backward compatibility whatsoever.
449- /// These features are primarily related to Tensor Core programming.
450- /// Example: compute_100a ONLY runs on CC 10.0, not on 10.3, 10.1, 9.0, or any other version.
450+ /// # External resources
451451 ///
452452 /// For more details on family and architecture-specific features, see:
453453 /// <https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/>
454454 pub fn all_target_features ( & self ) -> Vec < String > {
455- let mut features: Vec < String > = if self . is_architecture_variant ( ) {
456- // 'a' variants: include all available instructions for the architecture
457- // This means: all base variants up to same version, all 'f' variants with same major and <= minor, plus itself
458- let base_features: Vec < String > = NvvmArch :: iter ( )
459- . filter ( |arch| {
460- arch. is_base_variant ( ) && arch. capability_value ( ) <= self . capability_value ( )
461- } )
462- . map ( |arch| arch. target_feature ( ) )
463- . collect ( ) ;
464-
465- let family_features: Vec < String > = NvvmArch :: iter ( )
466- . filter ( |arch| {
467- arch. is_family_variant ( )
468- && arch. major_version ( ) == self . major_version ( )
469- && arch. minor_version ( ) <= self . minor_version ( )
470- } )
471- . map ( |arch| arch. target_feature ( ) )
472- . collect ( ) ;
455+ // All lower-or-equal baseline features are included.
456+ let included_baseline = |arch : & NvvmArch | {
457+ arch. is_base_variant ( ) && arch. capability_value ( ) <= self . capability_value ( )
458+ } ;
473459
474- base_features
475- . into_iter ( )
476- . chain ( family_features)
477- . chain ( std:: iter:: once ( self . target_feature ( ) ) )
460+ // All lower-or-equal-with-same-major-version family features are included.
461+ let included_family = |arch : & NvvmArch | {
462+ arch. is_family_variant ( )
463+ && arch. major_version ( ) == self . major_version ( )
464+ && arch. minor_version ( ) <= self . minor_version ( )
465+ } ;
466+
467+ if self . is_architecture_variant ( ) {
468+ // Architecture-specific ('a' suffix) features include:
469+ // - all lower-or-equal baseline features
470+ // - all lower-or-equal-with-same-major-version family features
471+ // - itself
472+ NvvmArch :: iter ( )
473+ . filter ( |arch| included_baseline ( arch) || included_family ( arch) || arch == self )
474+ . map ( |arch| arch. target_feature ( ) )
478475 . collect ( )
479476 } else if self . is_family_variant ( ) {
480- // 'f' variants: same major version with equal or higher minor version
477+ // Family-specific ('f' suffix) features include:
478+ // - all lower-or-equal baseline features
479+ // - all lower-or-equal-with-same-major-version family features
481480 NvvmArch :: iter ( )
482- . filter ( |arch| {
483- // Include base variants with same major and >= minor version
484- arch. is_base_variant ( )
485- && arch. major_version ( ) == self . major_version ( )
486- && arch. minor_version ( ) >= self . minor_version ( )
487- } )
481+ . filter ( |arch| included_baseline ( arch) || included_family ( arch) )
488482 . map ( |arch| arch. target_feature ( ) )
489- . chain ( std:: iter:: once ( self . target_feature ( ) ) ) // Add the 'f' variant itself
490483 . collect ( )
491484 } else {
492- // Base variants: all base architectures from lower or equal versions
485+ // Baseline (no suffix) features include:
486+ // - all lower-or-equal baseline features
493487 NvvmArch :: iter ( )
494- . filter ( |arch| {
495- arch. is_base_variant ( ) && arch. capability_value ( ) <= self . capability_value ( )
496- } )
488+ . filter ( included_baseline)
497489 . map ( |arch| arch. target_feature ( ) )
498490 . collect ( )
499- } ;
500-
501- features. sort ( ) ;
502- features
491+ }
503492 }
504493
505494 /// Create an iterator over all architectures from Compute35 up to and including this one
@@ -780,19 +769,16 @@ mod tests {
780769 fn nvvm_arch_all_target_features ( ) {
781770 use crate :: NvvmArch ;
782771
783- // Compute35 only includes itself
784772 assert_eq ! (
785773 NvvmArch :: Compute35 . all_target_features( ) ,
786774 vec![ "compute_35" ]
787775 ) ;
788776
789- // Compute50 includes all lower base capabilities
790777 assert_eq ! (
791778 NvvmArch :: Compute50 . all_target_features( ) ,
792779 vec![ "compute_35" , "compute_37" , "compute_50" ] ,
793780 ) ;
794781
795- // Compute61 includes all lower base capabilities
796782 assert_eq ! (
797783 NvvmArch :: Compute61 . all_target_features( ) ,
798784 vec![
@@ -806,7 +792,6 @@ mod tests {
806792 ]
807793 ) ;
808794
809- // Compute70 includes all lower base capabilities
810795 assert_eq ! (
811796 NvvmArch :: Compute70 . all_target_features( ) ,
812797 vec![
@@ -822,7 +807,6 @@ mod tests {
822807 ]
823808 ) ;
824809
825- // Compute90 includes lower base capabilities
826810 let compute90_features = NvvmArch :: Compute90 . all_target_features ( ) ;
827811 assert_eq ! (
828812 compute90_features,
@@ -846,9 +830,6 @@ mod tests {
846830 ]
847831 ) ;
848832
849- // Test 'a' variant - includes all available instructions for the architecture.
850- // This means: all base variants up to same version, no 'f' variants (90 has none), and the
851- // 'a' variant.
852833 assert_eq ! (
853834 NvvmArch :: Compute90a . all_target_features( ) ,
854835 vec![
@@ -872,14 +853,9 @@ mod tests {
872853 ]
873854 ) ;
874855
875- // Test compute100a - should include base variants up to 100, and 100f, and itself,
876- // but NOT 101f or 103f (higher minor).
877856 assert_eq ! (
878857 NvvmArch :: Compute100a . all_target_features( ) ,
879858 vec![
880- "compute_100" ,
881- "compute_100a" ,
882- "compute_100f" ,
883859 "compute_35" ,
884860 "compute_37" ,
885861 "compute_50" ,
@@ -896,26 +872,39 @@ mod tests {
896872 "compute_87" ,
897873 "compute_89" ,
898874 "compute_90" ,
875+ "compute_100" ,
876+ "compute_100f" ,
877+ "compute_100a" ,
899878 ]
900879 ) ;
901880
902- // Test 'f' variant with 100f
903881 assert_eq ! (
904882 NvvmArch :: Compute100f . all_target_features( ) ,
905- // FIXME: this is wrong
906- vec![ "compute_100" , "compute_100f" , "compute_101" , "compute_103" ]
883+ vec![
884+ "compute_35" ,
885+ "compute_37" ,
886+ "compute_50" ,
887+ "compute_52" ,
888+ "compute_53" ,
889+ "compute_60" ,
890+ "compute_61" ,
891+ "compute_62" ,
892+ "compute_70" ,
893+ "compute_72" ,
894+ "compute_75" ,
895+ "compute_80" ,
896+ "compute_86" ,
897+ "compute_87" ,
898+ "compute_89" ,
899+ "compute_90" ,
900+ "compute_100" ,
901+ "compute_100f" ,
902+ ]
907903 ) ;
908904
909- // Test compute101a - should include base variants up to 101, and 100f and 101f, and
910- // itself, but not 103f (higher minor)
911905 assert_eq ! (
912906 NvvmArch :: Compute101a . all_target_features( ) ,
913907 vec![
914- "compute_100" ,
915- "compute_100f" ,
916- "compute_101" ,
917- "compute_101a" ,
918- "compute_101f" ,
919908 "compute_35" ,
920909 "compute_37" ,
921910 "compute_50" ,
@@ -932,22 +921,43 @@ mod tests {
932921 "compute_87" ,
933922 "compute_89" ,
934923 "compute_90" ,
924+ "compute_100" ,
925+ "compute_100f" ,
926+ "compute_101" ,
927+ "compute_101f" ,
928+ "compute_101a" ,
935929 ]
936930 ) ;
937931
938- // Test 'f' variant with 101f
939932 assert_eq ! (
940933 NvvmArch :: Compute101f . all_target_features( ) ,
941- vec![ "compute_101" , "compute_101f" , "compute_103" ] ,
934+ vec![
935+ "compute_35" ,
936+ "compute_37" ,
937+ "compute_50" ,
938+ "compute_52" ,
939+ "compute_53" ,
940+ "compute_60" ,
941+ "compute_61" ,
942+ "compute_62" ,
943+ "compute_70" ,
944+ "compute_72" ,
945+ "compute_75" ,
946+ "compute_80" ,
947+ "compute_86" ,
948+ "compute_87" ,
949+ "compute_89" ,
950+ "compute_90" ,
951+ "compute_100" ,
952+ "compute_100f" ,
953+ "compute_101" ,
954+ "compute_101f" ,
955+ ]
942956 ) ;
943957
944958 assert_eq ! (
945959 NvvmArch :: Compute120 . all_target_features( ) ,
946960 vec![
947- "compute_100" ,
948- "compute_101" ,
949- "compute_103" ,
950- "compute_120" ,
951961 "compute_35" ,
952962 "compute_37" ,
953963 "compute_50" ,
@@ -964,24 +974,43 @@ mod tests {
964974 "compute_87" ,
965975 "compute_89" ,
966976 "compute_90" ,
977+ "compute_100" ,
978+ "compute_101" ,
979+ "compute_103" ,
980+ "compute_120" ,
967981 ]
968982 ) ;
969983
970984 assert_eq ! (
971985 NvvmArch :: Compute120f . all_target_features( ) ,
972- // FIXME: this is wrong
973- vec![ "compute_120" , "compute_120f" , "compute_121" ]
974- ) ;
975-
976- assert_eq ! (
977- NvvmArch :: Compute120a . all_target_features( ) ,
978986 vec![
987+ "compute_35" ,
988+ "compute_37" ,
989+ "compute_50" ,
990+ "compute_52" ,
991+ "compute_53" ,
992+ "compute_60" ,
993+ "compute_61" ,
994+ "compute_62" ,
995+ "compute_70" ,
996+ "compute_72" ,
997+ "compute_75" ,
998+ "compute_80" ,
999+ "compute_86" ,
1000+ "compute_87" ,
1001+ "compute_89" ,
1002+ "compute_90" ,
9791003 "compute_100" ,
9801004 "compute_101" ,
9811005 "compute_103" ,
9821006 "compute_120" ,
983- "compute_120a" ,
9841007 "compute_120f" ,
1008+ ]
1009+ ) ;
1010+
1011+ assert_eq ! (
1012+ NvvmArch :: Compute120a . all_target_features( ) ,
1013+ vec![
9851014 "compute_35" ,
9861015 "compute_37" ,
9871016 "compute_50" ,
@@ -998,6 +1027,12 @@ mod tests {
9981027 "compute_87" ,
9991028 "compute_89" ,
10001029 "compute_90" ,
1030+ "compute_100" ,
1031+ "compute_101" ,
1032+ "compute_103" ,
1033+ "compute_120" ,
1034+ "compute_120f" ,
1035+ "compute_120a" ,
10011036 ]
10021037 ) ;
10031038 }
0 commit comments