@@ -850,103 +850,204 @@ void WG_PS_SD(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
850850
// ===== default sub group private sort for signed integer, unsigned integer and floating-point types =====
852852DEVICE_EXTERN_C_INLINE
853- int8_t SG_PS_A (i8 )(int8_t value, uint8_t *scratch) {
853+ int8_t SG_PS_A (i8_p1i8)(int8_t value, uint8_t *scratch) {
854+ return sub_group_merge_sort (value, scratch, std::less<int8_t >{});
855+ }
856+
857+ int8_t SG_PS_A (i8_p3i8)(int8_t value, uint8_t *scratch) {
854858 return sub_group_merge_sort (value, scratch, std::less<int8_t >{});
855859}
856860
857861DEVICE_EXTERN_C_INLINE
858- int16_t SG_PS_A (i16 )(int16_t value, uint8_t *scratch) {
862+ int16_t SG_PS_A (i16_p1i8 )(int16_t value, uint8_t *scratch) {
859863 return sub_group_merge_sort (value, scratch, std::less<int16_t >{});
860864}
861865
862866DEVICE_EXTERN_C_INLINE
863- int32_t SG_PS_A (i32 )(int32_t value, uint8_t *scratch) {
867+ int16_t SG_PS_A (i16_p3i8)(int16_t value, uint8_t *scratch) {
868+ return sub_group_merge_sort (value, scratch, std::less<int16_t >{});
869+ }
870+
871+ DEVICE_EXTERN_C_INLINE
872+ int32_t SG_PS_A (i32_p1i8)(int32_t value, uint8_t *scratch) {
873+ return sub_group_merge_sort (value, scratch, std::less<int32_t >{});
874+ }
875+
876+ DEVICE_EXTERN_C_INLINE
877+ int32_t SG_PS_A (i32_p3i8)(int32_t value, uint8_t *scratch) {
864878 return sub_group_merge_sort (value, scratch, std::less<int32_t >{});
865879}
866880
867881DEVICE_EXTERN_C_INLINE
868- int64_t SG_PS_A (i64 )(int64_t value, uint8_t *scratch) {
882+ int64_t SG_PS_A (i64_p1i8 )(int64_t value, uint8_t *scratch) {
869883 return sub_group_merge_sort (value, scratch, std::less<int64_t >{});
870884}
871885
872886DEVICE_EXTERN_C_INLINE
873- uint8_t SG_PS_A (u8 )(uint8_t value, uint8_t *scratch) {
887+ int64_t SG_PS_A (i64_p3i8)(int64_t value, uint8_t *scratch) {
888+ return sub_group_merge_sort (value, scratch, std::less<int64_t >{});
889+ }
890+
891+ DEVICE_EXTERN_C_INLINE
892+ uint8_t SG_PS_A (u8_p1i8)(uint8_t value, uint8_t *scratch) {
893+ return sub_group_merge_sort (value, scratch, std::less<uint8_t >{});
894+ }
895+
896+ DEVICE_EXTERN_C_INLINE
897+ uint8_t SG_PS_A (u8_p3i8)(uint8_t value, uint8_t *scratch) {
874898 return sub_group_merge_sort (value, scratch, std::less<uint8_t >{});
875899}
876900
877901DEVICE_EXTERN_C_INLINE
878- uint16_t SG_PS_A (u16 )(uint16_t value, uint8_t *scratch) {
902+ uint16_t SG_PS_A (u16_p1i8 )(uint16_t value, uint8_t *scratch) {
879903 return sub_group_merge_sort (value, scratch, std::less<uint16_t >{});
880904}
881905
882906DEVICE_EXTERN_C_INLINE
883- uint32_t SG_PS_A (u32 )(uint32_t value, uint8_t *scratch) {
907+ uint16_t SG_PS_A (u16_p3i8)(uint16_t value, uint8_t *scratch) {
908+ return sub_group_merge_sort (value, scratch, std::less<uint16_t >{});
909+ }
910+
911+ DEVICE_EXTERN_C_INLINE
912+ uint32_t SG_PS_A (u32_p1i8)(uint32_t value, uint8_t *scratch) {
913+ return sub_group_merge_sort (value, scratch, std::less<uint32_t >{});
914+ }
915+
916+ DEVICE_EXTERN_C_INLINE
917+ uint32_t SG_PS_A (u32_p3i8)(uint32_t value, uint8_t *scratch) {
884918 return sub_group_merge_sort (value, scratch, std::less<uint32_t >{});
885919}
886920
887921DEVICE_EXTERN_C_INLINE
888- uint64_t SG_PS_A (u64 )(uint64_t value, uint8_t *scratch) {
922+ uint64_t SG_PS_A (u64_p1i8)(uint64_t value, uint8_t *scratch) {
923+ return sub_group_merge_sort (value, scratch, std::less<uint64_t >{});
924+ }
925+
926+ DEVICE_EXTERN_C_INLINE
927+ uint64_t SG_PS_A (u64_p3i8)(uint64_t value, uint8_t *scratch) {
889928 return sub_group_merge_sort (value, scratch, std::less<uint64_t >{});
890929}
891930
892931DEVICE_EXTERN_C_INLINE
893- float SG_PS_A (f32 )(float value, uint8_t *scratch) {
932+ float SG_PS_A (f32_p1i8 )(float value, uint8_t *scratch) {
894933 return sub_group_merge_sort (value, scratch, std::less<float >{});
895934}
896935
897936DEVICE_EXTERN_C_INLINE
898- _Float16 SG_PS_A (f16 )(_Float16 value, uint8_t *scratch) {
937+ float SG_PS_A (f32_p3i8)(float value, uint8_t *scratch) {
938+ return sub_group_merge_sort (value, scratch, std::less<float >{});
939+ }
940+
941+ DEVICE_EXTERN_C_INLINE
942+ _Float16 SG_PS_A (f16_p1i8)(_Float16 value, uint8_t *scratch) {
943+ return sub_group_merge_sort (value, scratch,
944+ [](_Float16 a, _Float16 b) { return (a < b); });
945+ }
946+
947+ DEVICE_EXTERN_C_INLINE
948+ _Float16 SG_PS_A (f16_p3i8)(_Float16 value, uint8_t *scratch) {
899949 return sub_group_merge_sort (value, scratch,
900950 [](_Float16 a, _Float16 b) { return (a < b); });
901951}
902952
903953DEVICE_EXTERN_C_INLINE
904- int8_t SG_PS_D (i8 )(int8_t value, uint8_t *scratch) {
954+ int8_t SG_PS_D (i8_p1i8 )(int8_t value, uint8_t *scratch) {
905955 return sub_group_merge_sort (value, scratch, std::greater<int8_t >{});
906956}
907957
908958DEVICE_EXTERN_C_INLINE
909- int16_t SG_PS_D (i16 )(int16_t value, uint8_t *scratch) {
959+ int8_t SG_PS_D (i8_p3i8)(int8_t value, uint8_t *scratch) {
960+ return sub_group_merge_sort (value, scratch, std::greater<int8_t >{});
961+ }
962+
963+ DEVICE_EXTERN_C_INLINE
964+ int16_t SG_PS_D (i16_p1i8)(int16_t value, uint8_t *scratch) {
965+ return sub_group_merge_sort (value, scratch, std::greater<int16_t >{});
966+ }
967+
968+ DEVICE_EXTERN_C_INLINE
969+ int16_t SG_PS_D (i16_p3i8)(int16_t value, uint8_t *scratch) {
910970 return sub_group_merge_sort (value, scratch, std::greater<int16_t >{});
911971}
912972
913973DEVICE_EXTERN_C_INLINE
914- int32_t SG_PS_D (i32 )(int32_t value, uint8_t *scratch) {
974+ int32_t SG_PS_D (i32_p1i8)(int32_t value, uint8_t *scratch) {
975+ return sub_group_merge_sort (value, scratch, std::greater<int32_t >{});
976+ }
977+
978+ DEVICE_EXTERN_C_INLINE
979+ int32_t SG_PS_D (i32_p3i8)(int32_t value, uint8_t *scratch) {
915980 return sub_group_merge_sort (value, scratch, std::greater<int32_t >{});
916981}
917982
918983DEVICE_EXTERN_C_INLINE
919- int64_t SG_PS_D (i64 )(int64_t value, uint8_t *scratch) {
984+ int64_t SG_PS_D (i64_p1i8)(int64_t value, uint8_t *scratch) {
985+ return sub_group_merge_sort (value, scratch, std::greater<int64_t >{});
986+ }
987+
988+ DEVICE_EXTERN_C_INLINE
989+ int64_t SG_PS_D (i64_p3i8)(int64_t value, uint8_t *scratch) {
920990 return sub_group_merge_sort (value, scratch, std::greater<int64_t >{});
921991}
922992
923993DEVICE_EXTERN_C_INLINE
924- uint8_t SG_PS_D (u8 )(uint8_t value, uint8_t *scratch) {
994+ uint8_t SG_PS_D (u8_p1i8 )(uint8_t value, uint8_t *scratch) {
925995 return sub_group_merge_sort (value, scratch, std::greater<uint8_t >{});
926996}
927997
928998DEVICE_EXTERN_C_INLINE
929- uint16_t SG_PS_D (u16 )(uint16_t value, uint8_t *scratch) {
999+ uint8_t SG_PS_D (u8_p3i8)(uint8_t value, uint8_t *scratch) {
1000+ return sub_group_merge_sort (value, scratch, std::greater<uint8_t >{});
1001+ }
1002+
1003+ DEVICE_EXTERN_C_INLINE
1004+ uint16_t SG_PS_D (u16_p1i8)(uint16_t value, uint8_t *scratch) {
1005+ return sub_group_merge_sort (value, scratch, std::greater<uint16_t >{});
1006+ }
1007+
1008+ DEVICE_EXTERN_C_INLINE
1009+ uint16_t SG_PS_D (u16_p3i8)(uint16_t value, uint8_t *scratch) {
9301010 return sub_group_merge_sort (value, scratch, std::greater<uint16_t >{});
9311011}
9321012
9331013DEVICE_EXTERN_C_INLINE
934- uint32_t SG_PS_D (u32 )(uint32_t value, uint8_t *scratch) {
1014+ uint32_t SG_PS_D (u32_p1i8)(uint32_t value, uint8_t *scratch) {
1015+ return sub_group_merge_sort (value, scratch, std::greater<uint32_t >{});
1016+ }
1017+
1018+ DEVICE_EXTERN_C_INLINE
1019+ uint32_t SG_PS_D (u32_p3i8)(uint32_t value, uint8_t *scratch) {
9351020 return sub_group_merge_sort (value, scratch, std::greater<uint32_t >{});
9361021}
9371022
9381023DEVICE_EXTERN_C_INLINE
939- uint64_t SG_PS_D (u64 )(uint64_t value, uint8_t *scratch) {
1024+ uint64_t SG_PS_D (u64_p1i8)(uint64_t value, uint8_t *scratch) {
1025+ return sub_group_merge_sort (value, scratch, std::greater<uint64_t >{});
1026+ }
1027+
1028+ DEVICE_EXTERN_C_INLINE
1029+ uint64_t SG_PS_D (u64_p3i8)(uint64_t value, uint8_t *scratch) {
9401030 return sub_group_merge_sort (value, scratch, std::greater<uint64_t >{});
9411031}
9421032
9431033DEVICE_EXTERN_C_INLINE
944- float SG_PS_D (f32 )(float value, uint8_t *scratch) {
1034+ float SG_PS_D (f32_p1i8)(float value, uint8_t *scratch) {
1035+ return sub_group_merge_sort (value, scratch, std::greater<float >{});
1036+ }
1037+
1038+ DEVICE_EXTERN_C_INLINE
1039+ float SG_PS_D (f32_p3i8)(float value, uint8_t *scratch) {
9451040 return sub_group_merge_sort (value, scratch, std::greater<float >{});
9461041}
9471042
9481043DEVICE_EXTERN_C_INLINE
949- _Float16 SG_PS_D (f16 )(_Float16 value, uint8_t *scratch) {
1044+ _Float16 SG_PS_D (f16_p1i8)(_Float16 value, uint8_t *scratch) {
1045+ return sub_group_merge_sort (value, scratch,
1046+ [](_Float16 a, _Float16 b) { return (a > b); });
1047+ }
1048+
1049+ DEVICE_EXTERN_C_INLINE
1050+ _Float16 SG_PS_D (f16_p3i8)(_Float16 value, uint8_t *scratch) {
9501051 return sub_group_merge_sort (value, scratch,
9511052 [](_Float16 a, _Float16 b) { return (a > b); });
9521053}
0 commit comments