@@ -65,13 +65,13 @@ __spirv_MemoryBarrier(uint32_t Memory, uint32_t Semantics) {
6565 template <> \
6666 __SYCL_CONVERGENT__ DEVICE_EXTERNAL Type \
6767 __spirv_SubgroupBlockReadINTEL<Type>(const OCL_GLOBAL PType *Ptr) noexcept { \
68- return * Ptr; \
68+ return Ptr[ __spirv_SubgroupLocalInvocationId ()]; \
6969 } \
7070 template <> \
7171 __SYCL_CONVERGENT__ DEVICE_EXTERNAL void \
7272 __spirv_SubgroupBlockWriteINTEL<Type>(PType OCL_GLOBAL * ptr, \
7373 Type v) noexcept { \
74- *( Type *)ptr = v; \
74+ (( Type*)ptr)[ __spirv_SubgroupLocalInvocationId ()] = v; \
7575 }
7676
7777#define DefSubgroupBlockINTEL_vt (Type, VT_name ) \
@@ -92,16 +92,19 @@ template <class T> struct vtypes {
9292DefSubgroupBlockINTEL (uint32_t ) DefSubgroupBlockINTEL(uint64_t )
9393DefSubgroupBlockINTEL(uint8_t ) DefSubgroupBlockINTEL(uint16_t )
9494
// DefineGOp1 -- generates a boolean group-operation entry point named
// __spirv_Group<spir_sfx>(unsigned g, bool val).
//
// This hunk shows the macro in diff form: `-` lines are the old definition,
// `+` lines the new one, unmarked lines are shared context.
//
// Old form: the second parameter was the complete mux builtin name; only that
// one function was declared, and only the Subgroup scope was dispatched.
//
// New form: the second parameter is the bare operation name, which is
// token-pasted into two declarations:
//   __mux_sub_group_<name>_i1(bool)
//   __mux_work_group_<name>_i1(uint32_t id, bool val)
// The generated function then dispatches on the scope argument `g`:
//   Subgroup       -> __mux_sub_group_<name>_i1(val)
//   Workgroup      -> __mux_work_group_<name>_i1(0, val)
//   any other value -> returns false
//
// NOTE(review): the work-group builtin is always called with 0 for its `id`
// parameter; what `id` identifies is not visible here -- confirm 0 is correct.
95- #define DefineGOp1 (spir_sfx, mux_name )\
96- DEVICE_EXTERN_C bool mux_name (bool );\
95+ #define DefineGOp1 (spir_sfx, name )\
96+ DEVICE_EXTERN_C bool __mux_sub_group_##name##_i1(bool );\
97+ DEVICE_EXTERN_C bool __mux_work_group_##name##_i1(uint32_t id, bool val);\
9798DEVICE_EXTERNAL bool __spirv_Group ## spir_sfx(unsigned g, bool val) {\
9899 if (__spv::Scope::Flag::Subgroup == g)\
99- return mux_name (val);\
100+ return __mux_sub_group_##name##_i1 (val);\
101+ else if (__spv::Scope::Flag::Workgroup == g)\
102+ return __mux_work_group_##name##_i1 (0 , val);\
100103 return false ;\
101104}
102105
// Instantiations of DefineGOp1: these expand to the definitions of
// __spirv_GroupAny and __spirv_GroupAll.
// Old calls (`-`) passed the full builtin names __mux_sub_group_any_i1 /
// __mux_sub_group_all_i1; new calls (`+`) pass only `any` / `all`, and the
// macro pastes that name into both the sub-group and work-group builtin names.
103- DefineGOp1 (Any, __mux_sub_group_any_i1 )
104- DefineGOp1(All, __mux_sub_group_all_i1 )
106+ DefineGOp1 (Any, any )
107+ DefineGOp1(All, all )
105108
106109
107110#define DefineGOp (Type, MuxType, spir_sfx, mux_sfx ) \
@@ -184,18 +187,6 @@ DefineBitwiseGroupOp(uint64_t, int64_t, i64)
184187
185188DefineLogicalGroupOp(bool , bool , i1)
186189
// DefineBroadCastImpl -- every line of this macro carries a `-` marker, i.e.
// the change shown in this diff deletes it outright.
//
// What the deleted macro did:
//   * declared __mux_work_group_broadcast_<Sfx> taking int64_t local ids
//     (lidx, lidy, lidz) and __mux_sub_group_broadcast_<Sfx>(val, sg_lid);
//   * defined __spirv_GroupBroadcast(uint32_t g, Type v, IDType l), which
//     forwarded to the sub-group builtin when g was the Subgroup scope and
//     otherwise returned a value-initialized Type() (see its own todo note).
//
// The DefineBroadcastMuxType macro that follows declares the same work-group
// builtin name with uint64_t local ids instead of int64_t.
187- #define DefineBroadCastImpl (Type, Sfx, MuxType, IDType ) \
188- DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \
189- int32_t id, MuxType val, int64_t lidx, int64_t lidy, int64_t lidz); \
190- DEVICE_EXTERN_C MuxType __mux_sub_group_broadcast_##Sfx(MuxType val, \
191- int32_t sg_lid); \
192- DEVICE_EXTERNAL Type __spirv_GroupBroadcast (uint32_t g, Type v, \
193- IDType l) { \
194- if (__spv::Scope::Flag::Subgroup == g) \
195- return __mux_sub_group_broadcast_##Sfx (v, l); \
196- return Type (); /* todo: add support for other flags as they are tested*/ \
197- }
198-
199190#define DefineBroadcastMuxType (Type, Sfx, MuxType, IDType ) \
200191 DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \
201192 int32_t id, MuxType val, uint64_t lidx, uint64_t lidy, uint64_t lidz); \
@@ -216,7 +207,7 @@ DefineLogicalGroupOp(bool, bool, i1)
216207 if (__spv::Scope::Flag::Subgroup == g) \
217208 return __mux_sub_group_broadcast_##Sfx (v, l[0 ]); \
218209 else \
219- return __mux_work_group_broadcast_##Sfx (0 , v, l[0 ], l[0 ], 0 ); \
210+ return __mux_work_group_broadcast_##Sfx (0 , v, l[0 ], l[1 ], 0 ); \
220211 } \
221212 \
222213 DEVICE_EXTERNAL Type __spirv_GroupBroadcast (uint32_t g, Type v, \
0 commit comments