Skip to content

Commit a4e8572

Browse files
committed
correct function name for subgroup private sorting
Signed-off-by: jinge90 <[email protected]>
1 parent b26e1a5 commit a4e8572

File tree

1 file changed

+121
-20
lines changed

1 file changed

+121
-20
lines changed

libdevice/fallback-gsort.cpp

Lines changed: 121 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -850,103 +850,204 @@ void WG_PS_SD(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
850850

851851
//===== default sub group private sort for integer and floating point types =====
852852
// Sub-group private sort for int8_t: SG_PS_A variant with suffix i8_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<int8_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
853-
int8_t SG_PS_A(i8)(int8_t value, uint8_t *scratch) {
853+
int8_t SG_PS_A(i8_p1i8)(int8_t value, uint8_t *scratch) {
854+
return sub_group_merge_sort(value, scratch, std::less<int8_t>{});
855+
}
856+
857+
int8_t SG_PS_A(i8_p3i8)(int8_t value, uint8_t *scratch) {
854858
return sub_group_merge_sort(value, scratch, std::less<int8_t>{});
855859
}
856860

857861
// Sub-group private sort for int16_t: SG_PS_A variant with suffix i16_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<int16_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
858-
int16_t SG_PS_A(i16)(int16_t value, uint8_t *scratch) {
862+
int16_t SG_PS_A(i16_p1i8)(int16_t value, uint8_t *scratch) {
859863
return sub_group_merge_sort(value, scratch, std::less<int16_t>{});
860864
}
861865

862866
// Sub-group private sort for int16_t: SG_PS_A variant with suffix i16_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<int16_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
863-
int32_t SG_PS_A(i32)(int32_t value, uint8_t *scratch) {
867+
int16_t SG_PS_A(i16_p3i8)(int16_t value, uint8_t *scratch) {
868+
return sub_group_merge_sort(value, scratch, std::less<int16_t>{});
869+
}
870+
871+
// Sub-group private sort for int32_t: SG_PS_A variant with suffix i32_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<int32_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
872+
int32_t SG_PS_A(i32_p1i8)(int32_t value, uint8_t *scratch) {
873+
return sub_group_merge_sort(value, scratch, std::less<int32_t>{});
874+
}
875+
876+
// Sub-group private sort for int32_t: SG_PS_A variant with suffix i32_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<int32_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
877+
int32_t SG_PS_A(i32_p3i8)(int32_t value, uint8_t *scratch) {
864878
return sub_group_merge_sort(value, scratch, std::less<int32_t>{});
865879
}
866880

867881
// Sub-group private sort for int64_t: SG_PS_A variant with suffix i64_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<int64_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
868-
int64_t SG_PS_A(i64)(int64_t value, uint8_t *scratch) {
882+
int64_t SG_PS_A(i64_p1i8)(int64_t value, uint8_t *scratch) {
869883
return sub_group_merge_sort(value, scratch, std::less<int64_t>{});
870884
}
871885

872886
// Sub-group private sort for int64_t: SG_PS_A variant with suffix i64_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<int64_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
873-
uint8_t SG_PS_A(u8)(uint8_t value, uint8_t *scratch) {
887+
int64_t SG_PS_A(i64_p3i8)(int64_t value, uint8_t *scratch) {
888+
return sub_group_merge_sort(value, scratch, std::less<int64_t>{});
889+
}
890+
891+
// Sub-group private sort for uint8_t: SG_PS_A variant with suffix u8_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<uint8_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
892+
uint8_t SG_PS_A(u8_p1i8)(uint8_t value, uint8_t *scratch) {
893+
return sub_group_merge_sort(value, scratch, std::less<uint8_t>{});
894+
}
895+
896+
// Sub-group private sort for uint8_t: SG_PS_A variant with suffix u8_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<uint8_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
897+
uint8_t SG_PS_A(u8_p3i8)(uint8_t value, uint8_t *scratch) {
874898
return sub_group_merge_sort(value, scratch, std::less<uint8_t>{});
875899
}
876900

877901
// Sub-group private sort for uint16_t: SG_PS_A variant with suffix u16_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<uint16_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
878-
uint16_t SG_PS_A(u16)(uint16_t value, uint8_t *scratch) {
902+
uint16_t SG_PS_A(u16_p1i8)(uint16_t value, uint8_t *scratch) {
879903
return sub_group_merge_sort(value, scratch, std::less<uint16_t>{});
880904
}
881905

882906
// Sub-group private sort for uint16_t: SG_PS_A variant with suffix u16_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<uint16_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
883-
uint32_t SG_PS_A(u32)(uint32_t value, uint8_t *scratch) {
907+
uint16_t SG_PS_A(u16_p3i8)(uint16_t value, uint8_t *scratch) {
908+
return sub_group_merge_sort(value, scratch, std::less<uint16_t>{});
909+
}
910+
911+
// Sub-group private sort for uint32_t: SG_PS_A variant with suffix u32_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<uint32_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
912+
uint32_t SG_PS_A(u32_p1i8)(uint32_t value, uint8_t *scratch) {
913+
return sub_group_merge_sort(value, scratch, std::less<uint32_t>{});
914+
}
915+
916+
// Sub-group private sort for uint32_t: SG_PS_A variant with suffix u32_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<uint32_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
917+
uint32_t SG_PS_A(u32_p3i8)(uint32_t value, uint8_t *scratch) {
884918
return sub_group_merge_sort(value, scratch, std::less<uint32_t>{});
885919
}
886920

887921
// Sub-group private sort for uint64_t: SG_PS_A variant with suffix u64_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<uint64_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
888-
uint64_t SG_PS_A(u64)(uint64_t value, uint8_t *scratch) {
922+
uint64_t SG_PS_A(u64_p1i8)(uint64_t value, uint8_t *scratch) {
923+
return sub_group_merge_sort(value, scratch, std::less<uint64_t>{});
924+
}
925+
926+
// Sub-group private sort for uint64_t: SG_PS_A variant with suffix u64_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<uint64_t> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
927+
uint64_t SG_PS_A(u64_p3i8)(uint64_t value, uint8_t *scratch) {
889928
return sub_group_merge_sort(value, scratch, std::less<uint64_t>{});
890929
}
891930

892931
// Sub-group private sort for float: SG_PS_A variant with suffix f32_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<float> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
893-
float SG_PS_A(f32)(float value, uint8_t *scratch) {
932+
float SG_PS_A(f32_p1i8)(float value, uint8_t *scratch) {
894933
return sub_group_merge_sort(value, scratch, std::less<float>{});
895934
}
896935

897936
// Sub-group private sort for float: SG_PS_A variant with suffix f32_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::less<float> comparator and returns whatever that call returns.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
898-
_Float16 SG_PS_A(f16)(_Float16 value, uint8_t *scratch) {
937+
float SG_PS_A(f32_p3i8)(float value, uint8_t *scratch) {
938+
return sub_group_merge_sort(value, scratch, std::less<float>{});
939+
}
940+
941+
// Sub-group private sort for _Float16: SG_PS_A variant with suffix f16_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort and returns
// whatever that call returns. The comparator is a lambda evaluating `a < b`
// rather than a std::less instance.
// NOTE(review): "A" presumably means ascending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
942+
_Float16 SG_PS_A(f16_p1i8)(_Float16 value, uint8_t *scratch) {
943+
return sub_group_merge_sort(value, scratch,
944+
[](_Float16 a, _Float16 b) { return (a < b); });
945+
}
946+
947+
// Sub-group private sort for _Float16: SG_PS_A variant with suffix f16_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort and returns
// whatever that call returns. The comparator is a lambda evaluating `a < b`
// rather than a std::less instance.
// NOTE(review): "A" presumably means ascending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_A's macro.
DEVICE_EXTERN_C_INLINE
948+
_Float16 SG_PS_A(f16_p3i8)(_Float16 value, uint8_t *scratch) {
899949
return sub_group_merge_sort(value, scratch,
900950
[](_Float16 a, _Float16 b) { return (a < b); });
901951
}
902952

903953
// Sub-group private sort for int8_t: SG_PS_D variant with suffix i8_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<int8_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
904-
int8_t SG_PS_D(i8)(int8_t value, uint8_t *scratch) {
954+
int8_t SG_PS_D(i8_p1i8)(int8_t value, uint8_t *scratch) {
905955
return sub_group_merge_sort(value, scratch, std::greater<int8_t>{});
906956
}
907957

908958
// Sub-group private sort for int8_t: SG_PS_D variant with suffix i8_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<int8_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
909-
int16_t SG_PS_D(i16)(int16_t value, uint8_t *scratch) {
959+
int8_t SG_PS_D(i8_p3i8)(int8_t value, uint8_t *scratch) {
960+
return sub_group_merge_sort(value, scratch, std::greater<int8_t>{});
961+
}
962+
963+
// Sub-group private sort for int16_t: SG_PS_D variant with suffix i16_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<int16_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
964+
int16_t SG_PS_D(i16_p1i8)(int16_t value, uint8_t *scratch) {
965+
return sub_group_merge_sort(value, scratch, std::greater<int16_t>{});
966+
}
967+
968+
// Sub-group private sort for int16_t: SG_PS_D variant with suffix i16_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<int16_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
969+
int16_t SG_PS_D(i16_p3i8)(int16_t value, uint8_t *scratch) {
910970
return sub_group_merge_sort(value, scratch, std::greater<int16_t>{});
911971
}
912972

913973
// Sub-group private sort for int32_t: SG_PS_D variant with suffix i32_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<int32_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
914-
int32_t SG_PS_D(i32)(int32_t value, uint8_t *scratch) {
974+
int32_t SG_PS_D(i32_p1i8)(int32_t value, uint8_t *scratch) {
975+
return sub_group_merge_sort(value, scratch, std::greater<int32_t>{});
976+
}
977+
978+
// Sub-group private sort for int32_t: SG_PS_D variant with suffix i32_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<int32_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
979+
int32_t SG_PS_D(i32_p3i8)(int32_t value, uint8_t *scratch) {
915980
return sub_group_merge_sort(value, scratch, std::greater<int32_t>{});
916981
}
917982

918983
// Sub-group private sort for int64_t: SG_PS_D variant with suffix i64_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<int64_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
919-
int64_t SG_PS_D(i64)(int64_t value, uint8_t *scratch) {
984+
int64_t SG_PS_D(i64_p1i8)(int64_t value, uint8_t *scratch) {
985+
return sub_group_merge_sort(value, scratch, std::greater<int64_t>{});
986+
}
987+
988+
// Sub-group private sort for int64_t: SG_PS_D variant with suffix i64_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<int64_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
989+
int64_t SG_PS_D(i64_p3i8)(int64_t value, uint8_t *scratch) {
920990
return sub_group_merge_sort(value, scratch, std::greater<int64_t>{});
921991
}
922992

923993
// Sub-group private sort for uint8_t: SG_PS_D variant with suffix u8_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<uint8_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
924-
uint8_t SG_PS_D(u8)(uint8_t value, uint8_t *scratch) {
994+
uint8_t SG_PS_D(u8_p1i8)(uint8_t value, uint8_t *scratch) {
925995
return sub_group_merge_sort(value, scratch, std::greater<uint8_t>{});
926996
}
927997

928998
// Sub-group private sort for uint8_t: SG_PS_D variant with suffix u8_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<uint8_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
929-
uint16_t SG_PS_D(u16)(uint16_t value, uint8_t *scratch) {
999+
uint8_t SG_PS_D(u8_p3i8)(uint8_t value, uint8_t *scratch) {
1000+
return sub_group_merge_sort(value, scratch, std::greater<uint8_t>{});
1001+
}
1002+
1003+
// Sub-group private sort for uint16_t: SG_PS_D variant with suffix u16_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<uint16_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
1004+
uint16_t SG_PS_D(u16_p1i8)(uint16_t value, uint8_t *scratch) {
1005+
return sub_group_merge_sort(value, scratch, std::greater<uint16_t>{});
1006+
}
1007+
1008+
// Sub-group private sort for uint16_t: SG_PS_D variant with suffix u16_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<uint16_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
1009+
uint16_t SG_PS_D(u16_p3i8)(uint16_t value, uint8_t *scratch) {
9301010
return sub_group_merge_sort(value, scratch, std::greater<uint16_t>{});
9311011
}
9321012

9331013
// Sub-group private sort for uint32_t: SG_PS_D variant with suffix u32_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<uint32_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
934-
uint32_t SG_PS_D(u32)(uint32_t value, uint8_t *scratch) {
1014+
uint32_t SG_PS_D(u32_p1i8)(uint32_t value, uint8_t *scratch) {
1015+
return sub_group_merge_sort(value, scratch, std::greater<uint32_t>{});
1016+
}
1017+
1018+
// Sub-group private sort for uint32_t: SG_PS_D variant with suffix u32_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<uint32_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
1019+
uint32_t SG_PS_D(u32_p3i8)(uint32_t value, uint8_t *scratch) {
9351020
return sub_group_merge_sort(value, scratch, std::greater<uint32_t>{});
9361021
}
9371022

9381023
// Sub-group private sort for uint64_t: SG_PS_D variant with suffix u64_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<uint64_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
939-
uint64_t SG_PS_D(u64)(uint64_t value, uint8_t *scratch) {
1024+
uint64_t SG_PS_D(u64_p1i8)(uint64_t value, uint8_t *scratch) {
1025+
return sub_group_merge_sort(value, scratch, std::greater<uint64_t>{});
1026+
}
1027+
1028+
// Sub-group private sort for uint64_t: SG_PS_D variant with suffix u64_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<uint64_t> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
1029+
uint64_t SG_PS_D(u64_p3i8)(uint64_t value, uint8_t *scratch) {
9401030
return sub_group_merge_sort(value, scratch, std::greater<uint64_t>{});
9411031
}
9421032

9431033
// Sub-group private sort for float: SG_PS_D variant with suffix f32_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<float> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
944-
float SG_PS_D(f32)(float value, uint8_t *scratch) {
1034+
float SG_PS_D(f32_p1i8)(float value, uint8_t *scratch) {
1035+
return sub_group_merge_sort(value, scratch, std::greater<float>{});
1036+
}
1037+
1038+
// Sub-group private sort for float: SG_PS_D variant with suffix f32_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort with a
// std::greater<float> comparator and returns whatever that call returns.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
1039+
float SG_PS_D(f32_p3i8)(float value, uint8_t *scratch) {
9451040
return sub_group_merge_sort(value, scratch, std::greater<float>{});
9461041
}
9471042

9481043
// Sub-group private sort for _Float16: SG_PS_D variant with suffix f16_p1i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort and returns
// whatever that call returns. The comparator is a lambda evaluating `a > b`
// rather than a std::greater instance.
// NOTE(review): "D" presumably means descending and "p1i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
949-
_Float16 SG_PS_D(f16)(_Float16 value, uint8_t *scratch) {
1044+
_Float16 SG_PS_D(f16_p1i8)(_Float16 value, uint8_t *scratch) {
1045+
return sub_group_merge_sort(value, scratch,
1046+
[](_Float16 a, _Float16 b) { return (a > b); });
1047+
}
1048+
1049+
// Sub-group private sort for _Float16: SG_PS_D variant with suffix f16_p3i8.
// Passes `value` and `scratch` straight to sub_group_merge_sort and returns
// whatever that call returns. The comparator is a lambda evaluating `a > b`
// rather than a std::greater instance.
// NOTE(review): "D" presumably means descending and "p3i8" presumably encodes
// the scratch pointer kind in the exported name — confirm in SG_PS_D's macro.
DEVICE_EXTERN_C_INLINE
1050+
_Float16 SG_PS_D(f16_p3i8)(_Float16 value, uint8_t *scratch) {
9501051
return sub_group_merge_sort(value, scratch,
9511052
[](_Float16 a, _Float16 b) { return (a > b); });
9521053
}

0 commit comments

Comments
 (0)