Skip to content

Commit 0c7ff98

Browse files
committed
sw: Align collective support with hardware
1 parent e752b8d commit 0c7ff98

File tree

5 files changed

+18
-38
lines changed

5 files changed

+18
-38
lines changed

sw/runtime/api/sync_decls.h

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ typedef struct {
2323
typedef snrt_comm_info_t *snrt_comm_t;
2424

2525
typedef enum {
26-
SNRT_REDUCTION_NONE = 0,
26+
SNRT_COLLECTIVE_UNICAST = 0,
27+
SNRT_COLLECTIVE_MULTICAST = 1,
2728
SNRT_REDUCTION_BARRIER = 2,
2829
SNRT_REDUCTION_FADD = 4,
2930
SNRT_REDUCTION_FMUL = 5,
@@ -35,21 +36,13 @@ typedef enum {
3536
SNRT_REDUCTION_MAX = 11,
3637
SNRT_REDUCTION_MINU = 14,
3738
SNRT_REDUCTION_MAXU = 15
38-
} snrt_reduction_opcode_t;
39-
40-
typedef enum {
41-
SNRT_COLLECTIVE_UNICAST = 0,
42-
SNRT_COLLECTIVE_MULTICAST = 1,
43-
SNRT_COLLECTIVE_PARALLEL_REDUCTION = 2,
44-
SNRT_COLLECTIVE_OFFLOAD_REDUCTION = 3
4539
} snrt_collective_opcode_t;
4640

4741
typedef union {
4842
struct __attribute__((__packed__)) {
49-
// snrt_reduction_opcode_t reduction_opcode : SNRT_REDUCTION_OPCODE_WIDTH;
50-
// snrt_collective_opcode_t collective_opcode
51-
// : SNRT_COLLECTIVE_OPCODE_WIDTH;
52-
uint64_t mask : SNRT_COLLECTIVE_MASK_WIDTH;
43+
snrt_collective_opcode_t collective_op
44+
: SNRT_COLLECTIVE_OPCODE_WIDTH;
45+
uint64_t mask : (64 - SNRT_COLLECTIVE_OPCODE_WIDTH);
5346
} f;
5447
uint64_t w;
5548
} snrt_collective_op_t;
@@ -79,7 +72,7 @@ inline void snrt_enable_multicast(uint64_t mask);
7972
inline void snrt_disable_multicast();
8073

8174
inline void snrt_enable_reduction(uint64_t mask,
82-
snrt_reduction_opcode_t reduction);
75+
snrt_collective_opcode_t reduction);
8376

8477
inline void snrt_disable_reduction();
8578

sw/runtime/impl/snitch_cluster_cfg.h.tpl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@
3434
#define SNRT_SUPPORTS_NARROW_REDUCTION
3535
% endif
3636

37-
#define SNRT_REDUCTION_OPCODE_WIDTH ${cfg['cluster']['reduction_opcode_width']}
38-
#define SNRT_COLLECTIVE_OPCODE_WIDTH ${cfg['cluster']['collective_width'] - cfg['cluster']['reduction_opcode_width']}
37+
#define SNRT_COLLECTIVE_OPCODE_WIDTH ${cfg['cluster']['collective_width']}
3938

4039
// Software configuration
4140
#define SNRT_LOG2_STACK_SIZE 10

sw/runtime/src/dma.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ extern void snrt_dma_set_awuser(uint64_t field);
77
extern void snrt_dma_enable_multicast(uint64_t mask);
88

99
extern void snrt_dma_enable_reduction(uint64_t mask,
10-
snrt_reduction_opcode_t opcode);
10+
snrt_collective_opcode_t opcode);
1111

1212
extern void snrt_dma_disable_multicast();
1313

@@ -32,7 +32,7 @@ extern snrt_dma_txid_t snrt_dma_load_1d_tile_mcast(void *dst, void *src,
3232

3333
extern snrt_dma_txid_t snrt_dma_reduction_load_1d_tile(
3434
void *dst, void *src, size_t tile_idx, size_t tile_size, uint32_t prec,
35-
uint64_t mask, snrt_reduction_opcode_t opcode);
35+
uint64_t mask, snrt_collective_opcode_t opcode);
3636

3737
extern snrt_dma_txid_t snrt_dma_1d_to_2d(volatile void *dst, volatile void *src,
3838
size_t size, size_t row_size,

sw/runtime/src/dma.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ inline void snrt_dma_set_awuser(uint64_t field) {
8585
*/
8686
inline void snrt_dma_enable_multicast(uint64_t mask) {
8787
snrt_collective_op_t op;
88-
// op.f.collective_opcode = SNRT_COLLECTIVE_MULTICAST;
88+
op.f.collective_op = SNRT_COLLECTIVE_MULTICAST;
8989
op.f.mask = mask;
9090
snrt_dma_set_awuser(op.w);
9191
}
@@ -99,10 +99,9 @@ inline void snrt_dma_enable_multicast(uint64_t mask) {
9999
* @param opcode Type of reduction operation
100100
*/
101101
inline void snrt_dma_enable_reduction(uint64_t mask,
102-
snrt_reduction_opcode_t opcode) {
102+
snrt_collective_opcode_t opcode) {
103103
snrt_collective_op_t op;
104-
// op.f.reduction_opcode = opcode;
105-
// op.f.collective_opcode = SNRT_COLLECTIVE_OFFLOAD_REDUCTION;
104+
op.f.collective_op = opcode;
106105
op.f.mask = mask;
107106
snrt_dma_set_awuser(op.w);
108107
}
@@ -129,7 +128,7 @@ inline void snrt_dma_disable_reduction() { snrt_dma_set_awuser(0); }
129128
*/
130129
static inline uint32_t snrt_dma_start_1d_reduction(
131130
uint64_t dst, uint64_t src, size_t size, uint64_t mask,
132-
snrt_reduction_opcode_t opcode, const uint32_t channel = 0) {
131+
snrt_collective_opcode_t opcode, const uint32_t channel = 0) {
133132
snrt_dma_enable_reduction(mask, opcode);
134133
uint32_t txid = snrt_dma_start_1d(dst, src, size, channel);
135134
snrt_dma_disable_reduction();
@@ -162,7 +161,7 @@ static inline uint32_t snrt_dma_start_1d_mcast(uint64_t dst, uint64_t src,
162161
*/
163162
static inline uint32_t snrt_dma_start_1d_reduction(
164163
volatile void *dst, volatile void *src, size_t size, uint64_t mask,
165-
snrt_reduction_opcode_t opcode, const uint32_t channel = 0) {
164+
snrt_collective_opcode_t opcode, const uint32_t channel = 0) {
166165
return snrt_dma_start_1d_reduction((uint64_t)dst, (uint64_t)src, size, mask,
167166
opcode, channel);
168167
}
@@ -423,7 +422,7 @@ inline snrt_dma_txid_t snrt_dma_load_1d_tile_mcast(void *dst, void *src,
423422
*/
424423
inline snrt_dma_txid_t snrt_dma_reduction_load_1d_tile(
425424
void *dst, void *src, size_t tile_idx, size_t tile_size, uint32_t prec,
426-
uint64_t mask, snrt_reduction_opcode_t opcode) {
425+
uint64_t mask, snrt_collective_opcode_t opcode) {
427426
size_t tile_nbytes = tile_size * prec;
428427
return snrt_dma_start_1d_reduction((uintptr_t)dst,
429428
(uintptr_t)src + tile_idx * tile_nbytes,

sw/runtime/src/sync.h

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ inline void snrt_set_awuser(uint64_t field) {
432432
*/
433433
inline void snrt_enable_multicast(uint64_t mask) {
434434
snrt_collective_op_t op;
435-
// op.f.collective_opcode = SNRT_COLLECTIVE_MULTICAST;
435+
op.f.collective_op = SNRT_COLLECTIVE_MULTICAST;
436436
op.f.mask = mask;
437437
snrt_set_awuser(op.w);
438438
}
@@ -454,21 +454,10 @@ inline void snrt_disable_multicast() { snrt_set_awuser(0); }
454454
* @param opcode Type of reduction operation
455455
*/
456456
inline void snrt_enable_reduction(uint64_t mask,
457-
snrt_reduction_opcode_t opcode) {
458-
snrt_collective_opcode_t coll_opcode;
459-
460-
switch (opcode) {
461-
case SNRT_REDUCTION_BARRIER:
462-
coll_opcode = SNRT_COLLECTIVE_PARALLEL_REDUCTION;
463-
break;
464-
default:
465-
coll_opcode = SNRT_COLLECTIVE_OFFLOAD_REDUCTION;
466-
break;
467-
}
457+
snrt_collective_opcode_t opcode) {
468458

469459
snrt_collective_op_t op;
470-
// op.f.reduction_opcode = opcode;
471-
// op.f.collective_opcode = coll_opcode;
460+
op.f.collective_op = opcode;
472461
op.f.mask = mask;
473462
snrt_set_awuser(op.w);
474463
}

0 commit comments

Comments
 (0)