@@ -85,7 +85,7 @@ inline void snrt_dma_set_awuser(uint64_t field) {
8585 */
8686inline void snrt_dma_enable_multicast (uint64_t mask ) {
8787 snrt_collective_op_t op ;
88- // op.f.collective_opcode = SNRT_COLLECTIVE_MULTICAST;
88+ op .f .collective_op = SNRT_COLLECTIVE_MULTICAST ;
8989 op .f .mask = mask ;
9090 snrt_dma_set_awuser (op .w );
9191}
@@ -99,10 +99,9 @@ inline void snrt_dma_enable_multicast(uint64_t mask) {
9999 * @param opcode Type of reduction operation
100100 */
101101inline void snrt_dma_enable_reduction (uint64_t mask ,
102- snrt_reduction_opcode_t opcode ) {
102+ snrt_collective_opcode_t opcode ) {
103103 snrt_collective_op_t op ;
104- // op.f.reduction_opcode = opcode;
105- // op.f.collective_opcode = SNRT_COLLECTIVE_OFFLOAD_REDUCTION;
104+ op .f .collective_op = opcode ;
106105 op .f .mask = mask ;
107106 snrt_dma_set_awuser (op .w );
108107}
@@ -129,7 +128,7 @@ inline void snrt_dma_disable_reduction() { snrt_dma_set_awuser(0); }
129128 */
130129static inline uint32_t snrt_dma_start_1d_reduction (
131130 uint64_t dst , uint64_t src , size_t size , uint64_t mask ,
132- snrt_reduction_opcode_t opcode , const uint32_t channel = 0 ) {
131+ snrt_collective_opcode_t opcode , const uint32_t channel = 0 ) {
133132 snrt_dma_enable_reduction (mask , opcode );
134133 uint32_t txid = snrt_dma_start_1d (dst , src , size , channel );
135134 snrt_dma_disable_reduction ();
@@ -162,7 +161,7 @@ static inline uint32_t snrt_dma_start_1d_mcast(uint64_t dst, uint64_t src,
162161 */
163162static inline uint32_t snrt_dma_start_1d_reduction (
164163 volatile void * dst , volatile void * src , size_t size , uint64_t mask ,
165- snrt_reduction_opcode_t opcode , const uint32_t channel = 0 ) {
164+ snrt_collective_opcode_t opcode , const uint32_t channel = 0 ) {
166165 return snrt_dma_start_1d_reduction ((uint64_t )dst , (uint64_t )src , size , mask ,
167166 opcode , channel );
168167}
@@ -423,7 +422,7 @@ inline snrt_dma_txid_t snrt_dma_load_1d_tile_mcast(void *dst, void *src,
423422 */
424423inline snrt_dma_txid_t snrt_dma_reduction_load_1d_tile (
425424 void * dst , void * src , size_t tile_idx , size_t tile_size , uint32_t prec ,
426- uint64_t mask , snrt_reduction_opcode_t opcode ) {
425+ uint64_t mask , snrt_collective_opcode_t opcode ) {
427426 size_t tile_nbytes = tile_size * prec ;
428427 return snrt_dma_start_1d_reduction ((uintptr_t )dst ,
429428 (uintptr_t )src + tile_idx * tile_nbytes ,
0 commit comments