Skip to content

Commit fbd441c

Browse files
hexagon : add cumsum op support (#21246)
* hexagon : add cumsum op support
* hexagon : enable dma for cumsum op
* Fix line-ending

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
1 parent c30e012 commit fbd441c

File tree

6 files changed

+347
-0
lines changed

6 files changed

+347
-0
lines changed

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,6 +2231,22 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session *
22312231
return true;
22322232
}
22332233

2234+
static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2235+
const struct ggml_tensor * src0 = op->src[0];
2236+
const struct ggml_tensor * dst = op;
2237+
2238+
if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
2239+
return false;
2240+
}
2241+
2242+
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2243+
return false;
2244+
}
2245+
2246+
GGML_UNUSED(sess);
2247+
return true;
2248+
}
2249+
22342250
enum dspqbuf_type {
22352251
DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
22362252
DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
@@ -2399,6 +2415,16 @@ static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bu
23992415
return n_bufs;
24002416
}
24012417

2418+
// Populate an HTP request for a CUMSUM node: src0 is a CPU-write/DSP-read
// buffer and dst a DSP-write/CPU-read buffer. Returns the number of
// dspqueue buffer descriptors filled into bufs.
static inline size_t init_cumsum_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    req->op = HTP_OP_CUMSUM;

    const size_t n_src = htp_req_buff_init(&req->src0, &bufs[0], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    const size_t n_dst = htp_req_buff_init(&req->dst, &bufs[n_src], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

    return n_src + n_dst;
}
2427+
24022428
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
24032429
req->op = HTP_OP_GET_ROWS;
24042430

@@ -2780,6 +2806,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
27802806
ggml_hexagon_dispatch_op<init_ssm_conv_req>(sess, node, flags);
27812807
break;
27822808

2809+
case GGML_OP_CUMSUM:
2810+
ggml_hexagon_dispatch_op<init_cumsum_req>(sess, node, flags);
2811+
break;
2812+
27832813
default:
27842814
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
27852815
}
@@ -3254,6 +3284,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
32543284
supp = ggml_hexagon_supported_ssm_conv(sess, op);
32553285
break;
32563286

3287+
case GGML_OP_CUMSUM:
3288+
supp = ggml_hexagon_supported_cumsum(sess, op);
3289+
break;
3290+
32573291
default:
32583292
break;
32593293
}

ggml/src/ggml-hexagon/htp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ add_library(${HTP_LIB} SHARED
3333
repeat-ops.c
3434
argsort-ops.c
3535
ssm-conv.c
36+
cumsum-ops.c
3637
)
3738

3839
target_compile_definitions(${HTP_LIB} PRIVATE
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#pragma clang diagnostic ignored "-Wunused-variable"
2+
#pragma clang diagnostic ignored "-Wunused-function"
3+
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
4+
5+
#include <HAP_farf.h>
6+
#include <HAP_perf.h>
7+
8+
#define GGML_COMMON_DECL_C
9+
#include "ggml-common.h"
10+
#include "htp-ctx.h"
11+
#include "htp-ops.h"
12+
#include "hvx-types.h"
13+
#include "hvx-utils.h"
14+
#include "hex-dma.h"
15+
16+
// Expands (inside a worker function body) to local aliases for the op's
// src0/dst tensors plus their dimensions (ne*) and byte strides (nb*),
// following the standard ggml naming convention. Requires `octx`
// (struct htp_ops_context *) in scope.
#define htp_cumsum_tensors_preamble \
    struct htp_tensor * restrict src0 = &octx->src0; \
    struct htp_tensor * restrict dst = &octx->dst; \
    \
    const uint32_t ne00 = src0->ne[0]; \
    const uint32_t ne01 = src0->ne[1]; \
    const uint32_t ne02 = src0->ne[2]; \
    const uint32_t ne03 = src0->ne[3]; \
    \
    const uint32_t ne0 = dst->ne[0]; \
    const uint32_t ne1 = dst->ne[1]; \
    const uint32_t ne2 = dst->ne[2]; \
    const uint32_t ne3 = dst->ne[3]; \
    \
    const uint32_t nb00 = src0->nb[0]; \
    const uint32_t nb01 = src0->nb[1]; \
    const uint32_t nb02 = src0->nb[2]; \
    const uint32_t nb03 = src0->nb[3]; \
    \
    const uint32_t nb0 = dst->nb[0]; \
    const uint32_t nb1 = dst->nb[1]; \
    const uint32_t nb2 = dst->nb[2]; \
    const uint32_t nb3 = dst->nb[3];
39+
40+
// Per-dispatch state shared by every cumsum worker thread.
struct htp_cumsum_context {
    struct htp_ops_context * octx;   // parent op context (tensors, scratchpads, DMA queues)
    size_t src_row_size;             // bytes per src0 row (src0->nb[1])
    size_t dst_row_size;             // bytes per dst row (dst->nb[1])
    size_t src_row_size_aligned;     // src row size rounded up to VLEN, for VTCM staging
    size_t dst_row_size_aligned;     // dst row size rounded up to VLEN, for VTCM staging
    uint32_t rows_per_thread;        // ceil(total_rows / n_threads)
    uint32_t total_rows;             // ne1 * ne2 * ne3 of src0
};
49+
50+
// Common worker prologue: unpacks the htp_cumsum_context passed as `data`,
// brings the tensor aliases into scope (via htp_cumsum_tensors_preamble),
// and grabs this thread's DMA queue. Requires `data` and `ith` in scope.
#define htp_cumsum_preamble \
    struct htp_cumsum_context * cctx = (struct htp_cumsum_context *) data; \
    struct htp_ops_context * octx = cctx->octx; \
    htp_cumsum_tensors_preamble; \
    dma_queue * dma_queue = octx->ctx->dma[ith];
55+
56+
// ---------------------------------------------------------------------------
57+
// HVX prefix scan helpers
58+
// ---------------------------------------------------------------------------
59+
60+
// Elementwise f32 vector add used by the scan. Cores newer than v75 have a
// direct IEEE-sf add; older cores must route through the qf32 intermediate
// format and convert back.
#if __HVX_ARCH__ > 75
static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) {
    return Q6_Vsf_vadd_VsfVsf(a, b);
}
#else
static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) {
    // qf32 add, then convert the result back to standard IEEE sf.
    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b));
}
#endif // __HVX_ARCH__ > 75
69+
70+
// Inclusive prefix sum of the 32 f32 lanes of v (Hillis-Steele scan):
// five shift-and-add steps with shifts of 1, 2, 4, 8 and 16 elements
// (4..64 bytes), then carry_in — the running total of the preceding
// vector, broadcast to all lanes — is added to every lane.
static inline HVX_Vector hvx_prefix_scan_f32(HVX_Vector v, HVX_Vector carry_in) {
    const HVX_Vector zero = Q6_V_vsplat_R(0);

    // Each vlalign shifts v up by N bytes toward higher lanes, filling the
    // vacated low bytes from `zero`, so lane i accumulates lanes [0..i].
    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 4));
    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 8));
    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 16));
    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 32));
    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 64));
    v = hvx_cumsum_vadd(v, carry_in);

    return v;
}
82+
83+
// Broadcast the last f32 lane of v to all 32 lanes: vror by 124 bytes
// rotates the final 4-byte element down to lane 0, and hvx_vec_repl4
// (from hvx-utils) replicates that 4-byte element across the vector.
static inline HVX_Vector hvx_splat_last_f32(HVX_Vector v) {
    return hvx_vec_repl4(Q6_V_vror_VR(v, 124));
}
86+
87+
// Inclusive cumulative sum of one row of n f32 values: dst[i] = src[0] + ... + src[i].
// Full 32-lane vectors are scanned with HVX; the running total (carry) is
// propagated between vectors, and a tail of fewer than VLEN_FP32 elements
// is finished with scalar adds.
static inline void hvx_cumsum_row_f32(const float * restrict src, float * restrict dst, uint32_t n) {
    const uint32_t nvec = n / VLEN_FP32;  // number of full vectors
    const uint32_t nloe = n % VLEN_FP32;  // leftover scalar elements

    HVX_Vector carry = Q6_V_vsplat_R(0);

    for (uint32_t i = 0; i < nvec; i++) {
        // Unaligned vector load/store: rows may not be VLEN-aligned.
        HVX_Vector v = *((const HVX_UVector *) (src + i * VLEN_FP32));
        v = hvx_prefix_scan_f32(v, carry);
        hvx_vec_store_u(dst + i * VLEN_FP32, VLEN, v);
        // All carry lanes now hold this vector's final running total.
        carry = hvx_splat_last_f32(v);
    }

    if (nloe) {
        // Extract the scalar running total (every carry lane is identical).
        float acc = hvx_vec_get_f32(carry);
        const float * src_tail = src + nvec * VLEN_FP32;
        float * dst_tail = dst + nvec * VLEN_FP32;
        for (uint32_t i = 0; i < nloe; i++) {
            acc += src_tail[i];
            dst_tail[i] = acc;
        }
    }
}
110+
111+
// ---------------------------------------------------------------------------
112+
// Per thread worker: Double-buffered DMA
113+
// ---------------------------------------------------------------------------
114+
115+
// DMA worker: double-buffered (ping-pong) pipeline through VTCM.
// Each thread owns two src and two dst staging rows in VTCM. Rows are
// prefetched DDR->VTCM, scanned in VTCM with HVX, then written back
// VTCM->DDR, overlapping the DMA with compute.
static void cumsum_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) {
    htp_cumsum_preamble;

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    // Rows [ir0, ir1) assigned to this thread.
    const uint32_t ir0 = cctx->rows_per_thread * ith;
    const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows);

    if (ir0 >= ir1) {
        return;
    }

    const size_t src_row_size = cctx->src_row_size;
    const size_t dst_row_size = cctx->dst_row_size;
    const size_t src_row_size_aligned = cctx->src_row_size_aligned;
    const size_t dst_row_size_aligned = cctx->dst_row_size_aligned;

    const uint8_t * src_data = (const uint8_t *) src0->data;
    uint8_t * dst_data = (uint8_t *) dst->data;

    // This thread's two ping-pong staging rows in VTCM.
    uint8_t * src_spad = octx->src0_spad.data + (ith * src_row_size_aligned * 2);
    uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned * 2);

    // Prime the pipeline: enqueue up to two (writeback, fill) descriptor
    // pairs so the compute loop can pop one pair per row.
    for (uint32_t ir = ir0, spad_idx = 0; ir < ir1 && spad_idx < 2; ir++, spad_idx++) {
        // Dummy dst writeback to establish queue ordering
        // (size 0 transfer; its pop hands us the dst staging row).
        dma_queue_push_vtcm_to_ddr(dma_queue,
                                   dma_make_ptr(dst_data, dst_spad + (spad_idx * dst_row_size_aligned)),
                                   dst_row_size, dst_row_size_aligned, 0);

        // Prefetch row ir from DDR into the matching src staging slot.
        dma_queue_push_ddr_to_vtcm(dma_queue,
                                   dma_make_ptr(src_spad + (spad_idx * src_row_size_aligned),
                                                src_data + (ir * src_row_size)),
                                   src_row_size_aligned, src_row_size, 1);
    }

    for (uint32_t ir = ir0; ir < ir1; ir++) {
        // Pop in push order: first the completed dst writeback (its .src is
        // the now-free dst staging row), then the completed src fill (its
        // .dst is the VTCM row holding the input data).
        float * dst_spad_row = (float *) dma_queue_pop(dma_queue).src;
        float * src_spad_row = (float *) dma_queue_pop(dma_queue).dst;

        // Scan one row entirely within VTCM.
        hvx_cumsum_row_f32(src_spad_row, dst_spad_row, ne00);

        // Write the result row back to DDR.
        dma_queue_push_vtcm_to_ddr(dma_queue,
                                   dma_make_ptr(dst_data + (ir * dst_row_size), (uint8_t *) dst_spad_row),
                                   dst_row_size, dst_row_size_aligned, 1);

        // Prefetch the row two iterations ahead (the other buffer of the
        // ping-pong pair is busy with row ir + 1).
        const uint32_t next_row = ir + 2;
        if (next_row < ir1) {
            dma_queue_push_ddr_to_vtcm(dma_queue,
                                       dma_make_ptr((uint8_t *) src_spad_row, src_data + (next_row * src_row_size)),
                                       src_row_size_aligned, src_row_size, 1);
        }
    }

    // Drain outstanding writebacks before returning.
    dma_queue_flush(dma_queue);
    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "cumsum-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
177+
178+
// ---------------------------------------------------------------------------
179+
// Per thread worker: Direct HVX (no DMA)
180+
// ---------------------------------------------------------------------------
181+
182+
static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) {
183+
htp_cumsum_preamble;
184+
185+
uint64_t t1, t2;
186+
t1 = HAP_perf_get_qtimer_count();
187+
188+
const uint8_t * src_data = (const uint8_t *) src0->data;
189+
uint8_t * dst_data = (uint8_t *) dst->data;
190+
191+
const uint32_t ir0 = cctx->rows_per_thread * ith;
192+
const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows);
193+
194+
for (uint32_t ir = ir0; ir < ir1; ir++) {
195+
const float * restrict src_row = (const float *) (src_data + ir * cctx->src_row_size);
196+
float * restrict dst_row = (float *) (dst_data + ir * cctx->dst_row_size);
197+
hvx_cumsum_row_f32(src_row, dst_row, ne00);
198+
}
199+
200+
t2 = HAP_perf_get_qtimer_count();
201+
202+
FARF(HIGH, "cumsum-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
203+
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
204+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
205+
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
206+
}
207+
208+
// Dispatch the f32 cumsum across worker threads. Lays out per-thread
// ping-pong staging buffers in VTCM and runs the DMA workers when enough
// VTCM is available, otherwise falls back to the direct-DDR workers.
// Returns HTP_STATUS_OK.
int op_cumsum_f32(struct htp_ops_context * octx) {
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * dst = &octx->dst;

    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    // One "row" = ne0 elements; rows are distributed across threads.
    const uint32_t total_rows = src0->ne[1] * src0->ne[2] * src0->ne[3];
    const uint32_t n_threads = MIN(octx->n_threads, total_rows);

    const size_t src_row_size = src0->nb[1];
    const size_t dst_row_size = dst->nb[1];
    // VTCM staging rows are padded to the HVX vector length.
    const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN);
    const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);

    // 2 ping-pong buffers per thread for src and dst
    const size_t spad_per_thread = 2 * (src_row_size_aligned + dst_row_size_aligned);

    // Carve the per-thread staging areas out of the VTCM base: all src
    // buffers first, then all dst buffers. (Only read by the DMA workers;
    // harmless if the direct path is chosen below.)
    octx->src0_spad.size_per_thread = src_row_size_aligned * 2;
    octx->dst_spad.size_per_thread = dst_row_size_aligned * 2;
    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
    octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;

    struct htp_cumsum_context cctx = {
        .octx = octx,
        .src_row_size = src_row_size,
        .dst_row_size = dst_row_size,
        .src_row_size_aligned = src_row_size_aligned,
        .dst_row_size_aligned = dst_row_size_aligned,
        .rows_per_thread = (total_rows + n_threads - 1) / n_threads,  // ceil division
        .total_rows = total_rows,
    };

    // Use DMA only if every thread's staging buffers fit in VTCM.
    if (octx->ctx->vtcm_size < spad_per_thread * n_threads) {
        worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32, &cctx, n_threads);
    } else {
        worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32_dma, &cctx, n_threads);
    }

    return HTP_STATUS_OK;
}
252+
253+
int op_cumsum(struct htp_ops_context * octx) {
254+
int err = HTP_STATUS_OK;
255+
struct htp_tensor * dst = &octx->dst;
256+
257+
switch (dst->type) {
258+
case HTP_TYPE_F32:
259+
err = op_cumsum_f32(octx);
260+
break;
261+
default:
262+
err = HTP_STATUS_NO_SUPPORT;
263+
break;
264+
}
265+
266+
return err;
267+
}

ggml/src/ggml-hexagon/htp/htp-msg.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ enum htp_op {
7575
HTP_OP_SUM_ROWS,
7676
HTP_OP_SSM_CONV,
7777
HTP_OP_REPEAT,
78+
HTP_OP_CUMSUM,
7879
INVALID
7980
};
8081

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,6 @@ int op_cpy(struct htp_ops_context * octx);
6060
int op_repeat(struct htp_ops_context * octx);
6161
int op_argsort(struct htp_ops_context * octx);
6262
int op_ssm_conv(struct htp_ops_context * octx);
63+
int op_cumsum(struct htp_ops_context * octx);
6364

6465
#endif /* HTP_OPS_H */

0 commit comments

Comments
 (0)