@@ -6985,4 +6985,251 @@ void WG_PS_SD(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n,
69856985 n, scratch,
69866986 std::greater_equal<int64_t >{});
69876987}
6988+
6989+ // 1-element version of subgroup private sort.
6990+ DEVICE_EXTERN_C_INLINE
6991+ void SG_PS_A (p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) {
6992+ private_merge_sort_key_value_close (key, val, 1 , scratch,
6993+ std::less_equal<uint8_t >{});
6994+ }
6995+
6996+ DEVICE_EXTERN_C_INLINE
6997+ void SG_PS_D (p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) {
6998+ private_merge_sort_key_value_close (key, val, 1 , scratch,
6999+ std::greater_equal<uint8_t >{});
7000+ }
7001+
7002+ DEVICE_EXTERN_C_INLINE
7003+ void SG_PS_A (p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) {
7004+ private_merge_sort_key_value_close (key, reinterpret_cast <int8_t *>(val), 1 ,
7005+ scratch, std::less_equal<uint8_t >{});
7006+ }
7007+
7008+ DEVICE_EXTERN_C_INLINE
7009+ void SG_PS_D (p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) {
7010+ private_merge_sort_key_value_close (key, reinterpret_cast <int8_t *>(val), 1 ,
7011+ scratch, std::greater_equal<uint8_t >{});
7012+ }
7013+
7014+ DEVICE_EXTERN_C_INLINE
7015+ void SG_PS_A (p1u8_p1u16_u32_p1i8)(uint8_t *key, uint16_t *val,
7016+ uint8_t *scratch) {
7017+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7018+ std::less_equal<uint8_t >{});
7019+ }
7020+
7021+ DEVICE_EXTERN_C_INLINE
7022+ void SG_PS_D (p1u8_p1u16_u32_p1i8)(uint8_t *key, uint16_t *val,
7023+ uint8_t *scratch) {
7024+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7025+ std::greater_equal<uint8_t >{});
7026+ }
7027+
7028+ DEVICE_EXTERN_C_INLINE
7029+ void SG_PS_A (p1u8_p1i16_u32_p1i8)(uint8_t *key, int16_t *val,
7030+ uint8_t *scratch) {
7031+ private_merge_sort_key_value_close (key, reinterpret_cast <uint16_t *>(val), 1 ,
7032+ scratch, std::less_equal<uint8_t >{});
7033+ }
7034+
7035+ DEVICE_EXTERN_C_INLINE
7036+ void SG_PS_D (p1u8_p1i16_u32_p1i8)(uint8_t *key, int16_t *val,
7037+ uint8_t *scratch) {
7038+ private_merge_sort_key_value_close (key, reinterpret_cast <uint16_t *>(val), 1 ,
7039+ scratch, std::greater_equal<uint8_t >{});
7040+ }
7041+
7042+ DEVICE_EXTERN_C_INLINE
7043+ void SG_PS_A (p1u8_p1u32_u32_p1i8)(uint8_t *key, uint32_t *val,
7044+ uint8_t *scratch) {
7045+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7046+ std::less_equal<uint8_t >{});
7047+ }
7048+
7049+ DEVICE_EXTERN_C_INLINE
7050+ void SG_PS_D (p1u8_p1u32_u32_p1i8)(uint8_t *key, uint32_t *val,
7051+ uint8_t *scratch) {
7052+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7053+ std::greater_equal<uint8_t >{});
7054+ }
7055+
7056+ DEVICE_EXTERN_C_INLINE
7057+ void SG_PS_A (p1u8_p1i32_u32_p1i8)(uint8_t *key, int32_t *val,
7058+ uint8_t *scratch) {
7059+ private_merge_sort_key_value_close (key, reinterpret_cast <uint32_t *>(val), 1 ,
7060+ scratch, std::less_equal<uint8_t >{});
7061+ }
7062+
7063+ DEVICE_EXTERN_C_INLINE
7064+ void SG_PS_D (p1u8_p1i32_u32_p1i8)(uint8_t *key, int32_t *val,
7065+ uint8_t *scratch) {
7066+ private_merge_sort_key_value_close (key, reinterpret_cast <uint32_t *>(val), 1 ,
7067+ scratch, std::greater_equal<uint8_t >{});
7068+ }
7069+
7070+ DEVICE_EXTERN_C_INLINE
7071+ void SG_PS_A (p1u8_p1u64_u32_p1i8)(uint8_t *key, uint64_t *val,
7072+ uint8_t *scratch) {
7073+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7074+ std::less_equal<uint8_t >{});
7075+ }
7076+
7077+ DEVICE_EXTERN_C_INLINE
7078+ void SG_PS_D (p1u8_p1u64_u32_p1i8)(uint8_t *key, uint64_t *val,
7079+ uint8_t *scratch) {
7080+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7081+ std::greater_equal<uint8_t >{});
7082+ }
7083+
7084+ DEVICE_EXTERN_C_INLINE
7085+ void SG_PS_A (p1u8_p1i64_u32_p1i8)(uint8_t *key, int64_t *val,
7086+ uint8_t *scratch) {
7087+ private_merge_sort_key_value_close (key, reinterpret_cast <uint64_t *>(val), 1 ,
7088+ scratch, std::less_equal<uint8_t >{});
7089+ }
7090+
7091+ DEVICE_EXTERN_C_INLINE
7092+ void SG_PS_D (p1u8_p1i64_u32_p1i8)(uint8_t *key, int64_t *val,
7093+ uint8_t *scratch) {
7094+ private_merge_sort_key_value_close (key, reinterpret_cast <uint64_t *>(val), 1 ,
7095+ scratch, std::greater_equal<uint8_t >{});
7096+ }
7097+
7098+ DEVICE_EXTERN_C_INLINE
7099+ void SG_PS_A (p1u8_p1f32_u32_p1i8)(uint8_t *key, float *val, uint8_t *scratch) {
7100+ private_merge_sort_key_value_close (key, reinterpret_cast <uint32_t *>(val), 1 ,
7101+ scratch, std::less_equal<uint8_t >{});
7102+ }
7103+
7104+ DEVICE_EXTERN_C_INLINE
7105+ void SG_PS_D (p1u8_p1f32_u32_p1i8)(uint8_t *key, float *val, uint8_t *scratch) {
7106+ private_merge_sort_key_value_close (key, reinterpret_cast <uint32_t *>(val), 1 ,
7107+ scratch, std::greater_equal<uint8_t >{});
7108+ }
7109+
7110+ DEVICE_EXTERN_C_INLINE
7111+ void SG_PS_A (p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val,
7112+ uint8_t *scratch) {
7113+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7114+ std::less_equal<uint16_t >{});
7115+ }
7116+
7117+ DEVICE_EXTERN_C_INLINE
7118+ void SG_PS_D (p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val,
7119+ uint8_t *scratch) {
7120+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7121+ std::greater_equal<uint16_t >{});
7122+ }
7123+
7124+ DEVICE_EXTERN_C_INLINE
7125+ void SG_PS_A (p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val,
7126+ uint8_t *scratch) {
7127+ private_merge_sort_key_value_close (key, reinterpret_cast <int8_t *>(val), 1 ,
7128+ scratch, std::less_equal<uint16_t >{});
7129+ }
7130+
7131+ DEVICE_EXTERN_C_INLINE
7132+ void SG_PS_D (p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val,
7133+ uint8_t *scratch) {
7134+ private_merge_sort_key_value_close (key, reinterpret_cast <int8_t *>(val), 1 ,
7135+ scratch, std::greater_equal<uint16_t >{});
7136+ }
7137+
7138+ DEVICE_EXTERN_C_INLINE
7139+ void SG_PS_A (p1u16_p1u16_u32_p1i8)(uint16_t *key, uint16_t *val,
7140+ uint8_t *scratch) {
7141+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7142+ std::less_equal<uint16_t >{});
7143+ }
7144+
7145+ DEVICE_EXTERN_C_INLINE
7146+ void SG_PS_D (p1u16_p1u16_u32_p1i8)(uint16_t *key, uint16_t *val,
7147+ uint8_t *scratch) {
7148+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7149+ std::greater_equal<uint16_t >{});
7150+ }
7151+
7152+ DEVICE_EXTERN_C_INLINE
7153+ void SG_PS_A (p1u16_p1i16_u32_p1i8)(uint16_t *key, int16_t *val,
7154+ uint8_t *scratch) {
7155+ private_merge_sort_key_value_close (key, reinterpret_cast <uint16_t *>(val), 1 ,
7156+ scratch, std::less_equal<uint16_t >{});
7157+ }
7158+
7159+ DEVICE_EXTERN_C_INLINE
7160+ void SG_PS_D (p1u16_p1i16_u32_p1i8)(uint16_t *key, int16_t *val,
7161+ uint8_t *scratch) {
7162+ private_merge_sort_key_value_close (key, reinterpret_cast <uint16_t *>(val), 1 ,
7163+ scratch, std::greater_equal<uint16_t >{});
7164+ }
7165+
7166+ DEVICE_EXTERN_C_INLINE
7167+ void SG_PS_A (p1u16_p1u32_u32_p1i8)(uint16_t *key, uint32_t *val,
7168+ uint8_t *scratch) {
7169+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7170+ std::less_equal<uint16_t >{});
7171+ }
7172+
7173+ DEVICE_EXTERN_C_INLINE
7174+ void SG_PS_D (p1u16_p1u32_u32_p1i8)(uint16_t *key, uint32_t *val,
7175+ uint8_t *scratch) {
7176+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7177+ std::greater_equal<uint16_t >{});
7178+ }
7179+
7180+ DEVICE_EXTERN_C_INLINE
7181+ void SG_PS_A (p1u16_p1i32_u32_p1i8)(uint16_t *key, int32_t *val,
7182+ uint8_t *scratch) {
7183+ private_merge_sort_key_value_close (key, reinterpret_cast <uint32_t *>(val), 1 ,
7184+ scratch, std::less_equal<uint16_t >{});
7185+ }
7186+
7187+ DEVICE_EXTERN_C_INLINE
7188+ void SG_PS_D (p1u16_p1i32_u32_p1i8)(uint16_t *key, int32_t *val,
7189+ uint8_t *scratch) {
7190+ private_merge_sort_key_value_close (key, reinterpret_cast <uint32_t *>(val), 1 ,
7191+ scratch, std::greater_equal<uint16_t >{});
7192+ }
7193+
7194+ DEVICE_EXTERN_C_INLINE
7195+ void SG_PS_A (p1u16_p1u64_u32_p1i8)(uint16_t *key, uint64_t *val,
7196+ uint8_t *scratch) {
7197+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7198+ std::less_equal<uint16_t >{});
7199+ }
7200+
7201+ DEVICE_EXTERN_C_INLINE
7202+ void SG_PS_D (p1u16_p1u64_u32_p1i8)(uint16_t *key, uint64_t *val,
7203+ uint8_t *scratch) {
7204+ private_merge_sort_key_value_close (key, val, 1 , scratch,
7205+ std::greater_equal<uint16_t >{});
7206+ }
7207+
7208+ DEVICE_EXTERN_C_INLINE
7209+ void SG_PS_A (p1u16_p1i64_u32_p1i8)(uint16_t *key, int64_t *val,
7210+ uint8_t *scratch) {
7211+ private_merge_sort_key_value_close (key, reinterpret_cast <uint64_t *>(val), 1 ,
7212+ scratch, std::less_equal<uint16_t >{});
7213+ }
7214+
7215+ DEVICE_EXTERN_C_INLINE
7216+ void SG_PS_D (p1u16_p1i64_u32_p1i8)(uint16_t *key, int64_t *val,
7217+ uint8_t *scratch) {
7218+ private_merge_sort_key_value_close (key, reinterpret_cast <uint64_t *>(val), 1 ,
7219+ scratch, std::greater_equal<uint16_t >{});
7220+ }
7221+
7222+ DEVICE_EXTERN_C_INLINE
7223+ void SG_PS_A (p1u16_p1f32_u32_p1i8)(uint16_t *key, float *val,
7224+ uint8_t *scratch) {
7225+ private_merge_sort_key_value_close (key, reinterpret_cast <uint32_t *>(val), 1 ,
7226+ scratch, std::less_equal<uint16_t >{});
7227+ }
7228+
7229+ DEVICE_EXTERN_C_INLINE
7230+ void SG_PS_D (p1u16_p1f32_u32_p1i8)(uint16_t *key, float *val,
7231+ uint8_t *scratch) {
7232+ private_merge_sort_key_value_close (key, reinterpret_cast <uint32_t *>(val), 1 ,
7233+ scratch, std::greater_equal<uint16_t >{});
7234+ }
69887235#endif // __SPIR__ || __SPIRV__
0 commit comments