Skip to content

Commit b079feb

Browse files
committed
add subgroup key-value sort for 1-element for u8/u16 key
Signed-off-by: jinge90 <[email protected]>
1 parent c1e3aaa commit b079feb

File tree

1 file changed

+247
-0
lines changed

1 file changed

+247
-0
lines changed

libdevice/fallback-gsort.cpp

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6985,4 +6985,251 @@ void WG_PS_SD(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n,
69856985
n, scratch,
69866986
std::greater_equal<int64_t>{});
69876987
}
6988+
6989+
// 1-element version of subgroup private sort.
6990+
DEVICE_EXTERN_C_INLINE
6991+
void SG_PS_A(p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) {
6992+
private_merge_sort_key_value_close(key, val, 1, scratch,
6993+
std::less_equal<uint8_t>{});
6994+
}
6995+
6996+
DEVICE_EXTERN_C_INLINE
6997+
void SG_PS_D(p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) {
6998+
private_merge_sort_key_value_close(key, val, 1, scratch,
6999+
std::greater_equal<uint8_t>{});
7000+
}
7001+
7002+
DEVICE_EXTERN_C_INLINE
7003+
void SG_PS_A(p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) {
7004+
private_merge_sort_key_value_close(key, reinterpret_cast<int8_t *>(val), 1,
7005+
scratch, std::less_equal<uint8_t>{});
7006+
}
7007+
7008+
DEVICE_EXTERN_C_INLINE
7009+
void SG_PS_D(p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) {
7010+
private_merge_sort_key_value_close(key, reinterpret_cast<int8_t *>(val), 1,
7011+
scratch, std::greater_equal<uint8_t>{});
7012+
}
7013+
7014+
DEVICE_EXTERN_C_INLINE
7015+
void SG_PS_A(p1u8_p1u16_u32_p1i8)(uint8_t *key, uint16_t *val,
7016+
uint8_t *scratch) {
7017+
private_merge_sort_key_value_close(key, val, 1, scratch,
7018+
std::less_equal<uint8_t>{});
7019+
}
7020+
7021+
DEVICE_EXTERN_C_INLINE
7022+
void SG_PS_D(p1u8_p1u16_u32_p1i8)(uint8_t *key, uint16_t *val,
7023+
uint8_t *scratch) {
7024+
private_merge_sort_key_value_close(key, val, 1, scratch,
7025+
std::greater_equal<uint8_t>{});
7026+
}
7027+
7028+
DEVICE_EXTERN_C_INLINE
7029+
void SG_PS_A(p1u8_p1i16_u32_p1i8)(uint8_t *key, int16_t *val,
7030+
uint8_t *scratch) {
7031+
private_merge_sort_key_value_close(key, reinterpret_cast<uint16_t *>(val), 1,
7032+
scratch, std::less_equal<uint8_t>{});
7033+
}
7034+
7035+
DEVICE_EXTERN_C_INLINE
7036+
void SG_PS_D(p1u8_p1i16_u32_p1i8)(uint8_t *key, int16_t *val,
7037+
uint8_t *scratch) {
7038+
private_merge_sort_key_value_close(key, reinterpret_cast<uint16_t *>(val), 1,
7039+
scratch, std::greater_equal<uint8_t>{});
7040+
}
7041+
7042+
DEVICE_EXTERN_C_INLINE
7043+
void SG_PS_A(p1u8_p1u32_u32_p1i8)(uint8_t *key, uint32_t *val,
7044+
uint8_t *scratch) {
7045+
private_merge_sort_key_value_close(key, val, 1, scratch,
7046+
std::less_equal<uint8_t>{});
7047+
}
7048+
7049+
DEVICE_EXTERN_C_INLINE
7050+
void SG_PS_D(p1u8_p1u32_u32_p1i8)(uint8_t *key, uint32_t *val,
7051+
uint8_t *scratch) {
7052+
private_merge_sort_key_value_close(key, val, 1, scratch,
7053+
std::greater_equal<uint8_t>{});
7054+
}
7055+
7056+
DEVICE_EXTERN_C_INLINE
7057+
void SG_PS_A(p1u8_p1i32_u32_p1i8)(uint8_t *key, int32_t *val,
7058+
uint8_t *scratch) {
7059+
private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
7060+
scratch, std::less_equal<uint8_t>{});
7061+
}
7062+
7063+
DEVICE_EXTERN_C_INLINE
7064+
void SG_PS_D(p1u8_p1i32_u32_p1i8)(uint8_t *key, int32_t *val,
7065+
uint8_t *scratch) {
7066+
private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
7067+
scratch, std::greater_equal<uint8_t>{});
7068+
}
7069+
7070+
DEVICE_EXTERN_C_INLINE
7071+
void SG_PS_A(p1u8_p1u64_u32_p1i8)(uint8_t *key, uint64_t *val,
7072+
uint8_t *scratch) {
7073+
private_merge_sort_key_value_close(key, val, 1, scratch,
7074+
std::less_equal<uint8_t>{});
7075+
}
7076+
7077+
DEVICE_EXTERN_C_INLINE
7078+
void SG_PS_D(p1u8_p1u64_u32_p1i8)(uint8_t *key, uint64_t *val,
7079+
uint8_t *scratch) {
7080+
private_merge_sort_key_value_close(key, val, 1, scratch,
7081+
std::greater_equal<uint8_t>{});
7082+
}
7083+
7084+
DEVICE_EXTERN_C_INLINE
7085+
void SG_PS_A(p1u8_p1i64_u32_p1i8)(uint8_t *key, int64_t *val,
7086+
uint8_t *scratch) {
7087+
private_merge_sort_key_value_close(key, reinterpret_cast<uint64_t *>(val), 1,
7088+
scratch, std::less_equal<uint8_t>{});
7089+
}
7090+
7091+
DEVICE_EXTERN_C_INLINE
7092+
void SG_PS_D(p1u8_p1i64_u32_p1i8)(uint8_t *key, int64_t *val,
7093+
uint8_t *scratch) {
7094+
private_merge_sort_key_value_close(key, reinterpret_cast<uint64_t *>(val), 1,
7095+
scratch, std::greater_equal<uint8_t>{});
7096+
}
7097+
7098+
DEVICE_EXTERN_C_INLINE
7099+
void SG_PS_A(p1u8_p1f32_u32_p1i8)(uint8_t *key, float *val, uint8_t *scratch) {
7100+
private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
7101+
scratch, std::less_equal<uint8_t>{});
7102+
}
7103+
7104+
DEVICE_EXTERN_C_INLINE
7105+
void SG_PS_D(p1u8_p1f32_u32_p1i8)(uint8_t *key, float *val, uint8_t *scratch) {
7106+
private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
7107+
scratch, std::greater_equal<uint8_t>{});
7108+
}
7109+
7110+
DEVICE_EXTERN_C_INLINE
7111+
void SG_PS_A(p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val,
7112+
uint8_t *scratch) {
7113+
private_merge_sort_key_value_close(key, val, 1, scratch,
7114+
std::less_equal<uint16_t>{});
7115+
}
7116+
7117+
DEVICE_EXTERN_C_INLINE
7118+
void SG_PS_D(p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val,
7119+
uint8_t *scratch) {
7120+
private_merge_sort_key_value_close(key, val, 1, scratch,
7121+
std::greater_equal<uint16_t>{});
7122+
}
7123+
7124+
DEVICE_EXTERN_C_INLINE
7125+
void SG_PS_A(p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val,
7126+
uint8_t *scratch) {
7127+
private_merge_sort_key_value_close(key, reinterpret_cast<int8_t *>(val), 1,
7128+
scratch, std::less_equal<uint16_t>{});
7129+
}
7130+
7131+
DEVICE_EXTERN_C_INLINE
7132+
void SG_PS_D(p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val,
7133+
uint8_t *scratch) {
7134+
private_merge_sort_key_value_close(key, reinterpret_cast<int8_t *>(val), 1,
7135+
scratch, std::greater_equal<uint16_t>{});
7136+
}
7137+
7138+
DEVICE_EXTERN_C_INLINE
7139+
void SG_PS_A(p1u16_p1u16_u32_p1i8)(uint16_t *key, uint16_t *val,
7140+
uint8_t *scratch) {
7141+
private_merge_sort_key_value_close(key, val, 1, scratch,
7142+
std::less_equal<uint16_t>{});
7143+
}
7144+
7145+
DEVICE_EXTERN_C_INLINE
7146+
void SG_PS_D(p1u16_p1u16_u32_p1i8)(uint16_t *key, uint16_t *val,
7147+
uint8_t *scratch) {
7148+
private_merge_sort_key_value_close(key, val, 1, scratch,
7149+
std::greater_equal<uint16_t>{});
7150+
}
7151+
7152+
DEVICE_EXTERN_C_INLINE
7153+
void SG_PS_A(p1u16_p1i16_u32_p1i8)(uint16_t *key, int16_t *val,
7154+
uint8_t *scratch) {
7155+
private_merge_sort_key_value_close(key, reinterpret_cast<uint16_t *>(val), 1,
7156+
scratch, std::less_equal<uint16_t>{});
7157+
}
7158+
7159+
DEVICE_EXTERN_C_INLINE
7160+
void SG_PS_D(p1u16_p1i16_u32_p1i8)(uint16_t *key, int16_t *val,
7161+
uint8_t *scratch) {
7162+
private_merge_sort_key_value_close(key, reinterpret_cast<uint16_t *>(val), 1,
7163+
scratch, std::greater_equal<uint16_t>{});
7164+
}
7165+
7166+
DEVICE_EXTERN_C_INLINE
7167+
void SG_PS_A(p1u16_p1u32_u32_p1i8)(uint16_t *key, uint32_t *val,
7168+
uint8_t *scratch) {
7169+
private_merge_sort_key_value_close(key, val, 1, scratch,
7170+
std::less_equal<uint16_t>{});
7171+
}
7172+
7173+
DEVICE_EXTERN_C_INLINE
7174+
void SG_PS_D(p1u16_p1u32_u32_p1i8)(uint16_t *key, uint32_t *val,
7175+
uint8_t *scratch) {
7176+
private_merge_sort_key_value_close(key, val, 1, scratch,
7177+
std::greater_equal<uint16_t>{});
7178+
}
7179+
7180+
DEVICE_EXTERN_C_INLINE
7181+
void SG_PS_A(p1u16_p1i32_u32_p1i8)(uint16_t *key, int32_t *val,
7182+
uint8_t *scratch) {
7183+
private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
7184+
scratch, std::less_equal<uint16_t>{});
7185+
}
7186+
7187+
DEVICE_EXTERN_C_INLINE
7188+
void SG_PS_D(p1u16_p1i32_u32_p1i8)(uint16_t *key, int32_t *val,
7189+
uint8_t *scratch) {
7190+
private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
7191+
scratch, std::greater_equal<uint16_t>{});
7192+
}
7193+
7194+
DEVICE_EXTERN_C_INLINE
7195+
void SG_PS_A(p1u16_p1u64_u32_p1i8)(uint16_t *key, uint64_t *val,
7196+
uint8_t *scratch) {
7197+
private_merge_sort_key_value_close(key, val, 1, scratch,
7198+
std::less_equal<uint16_t>{});
7199+
}
7200+
7201+
DEVICE_EXTERN_C_INLINE
7202+
void SG_PS_D(p1u16_p1u64_u32_p1i8)(uint16_t *key, uint64_t *val,
7203+
uint8_t *scratch) {
7204+
private_merge_sort_key_value_close(key, val, 1, scratch,
7205+
std::greater_equal<uint16_t>{});
7206+
}
7207+
7208+
DEVICE_EXTERN_C_INLINE
7209+
void SG_PS_A(p1u16_p1i64_u32_p1i8)(uint16_t *key, int64_t *val,
7210+
uint8_t *scratch) {
7211+
private_merge_sort_key_value_close(key, reinterpret_cast<uint64_t *>(val), 1,
7212+
scratch, std::less_equal<uint16_t>{});
7213+
}
7214+
7215+
DEVICE_EXTERN_C_INLINE
7216+
void SG_PS_D(p1u16_p1i64_u32_p1i8)(uint16_t *key, int64_t *val,
7217+
uint8_t *scratch) {
7218+
private_merge_sort_key_value_close(key, reinterpret_cast<uint64_t *>(val), 1,
7219+
scratch, std::greater_equal<uint16_t>{});
7220+
}
7221+
7222+
DEVICE_EXTERN_C_INLINE
7223+
void SG_PS_A(p1u16_p1f32_u32_p1i8)(uint16_t *key, float *val,
7224+
uint8_t *scratch) {
7225+
private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
7226+
scratch, std::less_equal<uint16_t>{});
7227+
}
7228+
7229+
DEVICE_EXTERN_C_INLINE
7230+
void SG_PS_D(p1u16_p1f32_u32_p1i8)(uint16_t *key, float *val,
7231+
uint8_t *scratch) {
7232+
private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
7233+
scratch, std::greater_equal<uint16_t>{});
7234+
}
69887235
#endif // __SPIR__ || __SPIRV__

0 commit comments

Comments
 (0)