-
Notifications
You must be signed in to change notification settings - Fork 2
Description
AI feedback from #96 on the rpmsg library code to look into later:
Length Handling
In pru_rpmsg_receive, the code copies msg->len bytes into data without validating that msg->len fits the caller-provided buffer (*len is only set after the copy). This risks buffer overflow if the caller passes a smaller buffer. Consider passing the max buffer length into the function and clamping/validating before memcpy, or returning an error if too large.
source/rpmsg/pru_rpmsg.c [109-113]
+/* Verify descriptor buffer is large enough for header + payload */
+if (msg_len < (sizeof(struct pru_rpmsg_hdr) + len)) {
+ /* Return the buffer back to used ring to keep vring consistent */
+ pru_virtqueue_add_used_buf(virtqueue, head, msg_len);
+ return PRU_RPMSG_BUF_TOO_SMALL;
+}
+
/* Copy the message payload to the local data buffer provided */
memcpy(data, msg->data, msg->len);
*src = msg->src;
*dst = msg->dst;
*len = msg->len;
OR
In pru_rpmsg_receive, check that the received message length does not exceed the
size of the destination buffer to prevent a buffer overflow.
source/rpmsg/pru_rpmsg.c [149-153]
+/* Ensure caller buffer is large enough for payload */
+if (msg->len > *len) {
+ /* Recycle buffer before returning error */
+ pru_virtqueue_add_used_buf(virtqueue, head, msg_len);
+ pru_virtqueue_kick(virtqueue);
+ return PRU_RPMSG_BUF_TOO_SMALL;
+}
+
/* Copy the message payload to the local data buffer provided */
memcpy(data, msg->data, msg->len);
*src = msg->src;
*dst = msg->dst;
*len = msg->len;
OR
Validate received buffer sizes
In pru_rpmsg_receive, add validation to ensure the received message length
msg->len is consistent with the available buffer size msg_len and does not
exceed the destination buffer's capacity *len, to prevent buffer overflows.
source/rpmsg/pru_rpmsg.c [135-153]
int16_t head;
struct pru_rpmsg_hdr *msg;
uint32_t msg_len;
struct pru_virtqueue *virtqueue;
virtqueue = &transport->virtqueue1;
/* Get an available buffer */
head = pru_virtqueue_get_avail_buf(virtqueue, (void **)&msg, &msg_len);
if (head < 0)
return PRU_RPMSG_NO_BUF_AVAILABLE;
+/* Validate descriptor and payload lengths */
+if (msg_len < sizeof(struct pru_rpmsg_hdr) ||
+ msg->len > (msg_len - sizeof(struct pru_rpmsg_hdr)) ||
+ msg->len > *len) {
+ /* Return buffer and report size error */
+ pru_virtqueue_add_used_buf(virtqueue, head, msg_len);
+ pru_virtqueue_kick(virtqueue);
+ return PRU_RPMSG_BUF_TOO_SMALL;
+}
/* Copy the message payload to the local data buffer provided */
memcpy(data, msg->data, msg->len);
*src = msg->src;
*dst = msg->dst;
*len = msg->len;
+/* Add the used buffer */
+if (pru_virtqueue_add_used_buf(virtqueue, head, msg_len) < 0)
+ return PRU_RPMSG_INVALID_HEAD;
+
+/* Kick the ARM host */
+pru_virtqueue_kick(virtqueue);
+
+return PRU_RPMSG_SUCCESS;
+
In pru_rpmsg_send, validate that the available buffer length msg_len is
sufficient for the message header and payload before performing memcpy to
prevent a potential buffer overflow.
source/rpmsg/pru_rpmsg.c [98-111]
if (len > (RPMSG_BUF_SIZE - sizeof(struct pru_rpmsg_hdr)))
return PRU_RPMSG_BUF_TOO_SMALL;
virtqueue = &transport->virtqueue0;
/* Get an available buffer */
head = pru_virtqueue_get_avail_buf(virtqueue, (void **)&msg, &msg_len);
if (head < 0)
return PRU_RPMSG_NO_BUF_AVAILABLE;
+/* Ensure descriptor has enough space for header + payload */
+if (msg_len < (sizeof(struct pru_rpmsg_hdr) + len)) {
+ /* Return the buffer back as used with zero length to avoid stalling */
+ pru_virtqueue_add_used_buf(virtqueue, head, 0);
+ pru_virtqueue_kick(virtqueue);
+ return PRU_RPMSG_BUF_TOO_SMALL;
+}
+
/* Copy local data buffer to the descriptor buffer address */
memcpy(msg->data, data, len);
msg->len = len;
+msg->dst = dst;
+msg->src = src;
+msg->flags = 0;
+msg->reserved = 0;
+/* Add the used buffer with actual message size */
+if (pru_virtqueue_add_used_buf(virtqueue, head, sizeof(struct pru_rpmsg_hdr) + len) < 0)
+ return PRU_RPMSG_INVALID_HEAD;
+
+/* Kick the ARM host */
+pru_virtqueue_kick(virtqueue);
+
+return PRU_RPMSG_SUCCESS;
+
Head Index Validation
pru_virtqueue_add_used_buf checks if (head > num), but valid indices are [0, num-1], so head == num is wrongly accepted. The condition should likely be head >= num, and a negative head should also be rejected, so that an out-of-range descriptor id is never written into the used ring.
if (head > num)
return PRU_VIRTQUEUE_INVALID_HEAD;
/*
* The virtqueue's vring contains a ring of used buffers. Get a pointer to
* the next entry in that used ring.
*/
used_elem = &used->ring[used->idx++ & (num - 1)];
used_elem->id = head;
used_elem->len = len;
Validate vring head index
In pru_virtqueue_get_avail_buf, validate the descriptor head index to ensure it
is within the valid range before using it to access the descriptor array,
preventing a potential out-of-bounds read.
source/rpmsg/pru_virtqueue.c [87-94]
/*
* Grab the next descriptor number the ARM host is advertising, and
* increment the last available index we've seen.
*/
head = avail->ring[vq->last_avail_idx++ & (vq->vring.num - 1)];
+/* Validate descriptor head */
+if (head < 0 || head >= vq->vring.num) {
+ return PRU_VIRTQUEUE_INVALID_HEAD;
+}
+
desc = vq->vring.desc[head];
*buf = (void *)(uint32_t)desc.addr;
*len = desc.len;
src & dst are uint32_t. However, source/rpmsg/pru_rpmsg.c declares the src and dst parameters of pru_rpmsg_receive as uint16_t pointers. Should those be updated to uint32_t pointers? Or perhaps the main.c code should use different variable names to indicate that they are pointers?
int16_t pru_rpmsg_receive(
struct pru_rpmsg_transport *transport,
uint16_t *src,
uint16_t *dst,
void *data,
uint16_t *len
)
Add memory barrier on used index
In pru_virtqueue_add_used_buf, insert a memory barrier before incrementing
used->idx to ensure the used ring entry is fully written, preventing a race
condition and potential data corruption on the host side.
source/rpmsg/pru_virtqueue.c [119]
-used_elem = &used->ring[used->idx++ & (num - 1)];
+used_elem = &used->ring[used->idx & (num - 1)];
+used_elem->id = head;
+used_elem->len = len;
+/* Ensure used entry is visible before updating idx */
+__sync_synchronize();
+used->idx++;
+