diff --git a/subsys/net/ip/tcp.c b/subsys/net/ip/tcp.c index 270675fe27f..f86e782ac9f 100644 --- a/subsys/net/ip/tcp.c +++ b/subsys/net/ip/tcp.c @@ -2719,7 +2719,7 @@ static void tcp_queue_recv_data(struct tcp *conn, struct net_pkt *pkt, } static enum net_verdict tcp_data_received(struct tcp *conn, struct net_pkt *pkt, - size_t *len, bool psh) + size_t *len, bool psh, bool fin) { enum net_verdict ret; @@ -2732,6 +2732,13 @@ static enum net_verdict tcp_data_received(struct tcp *conn, struct net_pkt *pkt, net_stats_update_tcp_seg_recv(conn->iface); conn_ack(conn, *len); + /* In case FIN was received, don't send ACK just yet, FIN,ACK will be + * sent instead. + */ + if (fin) { + return ret; + } + /* Delay ACK response in case of small window or missing PSH, * as described in RFC 813. */ @@ -3143,37 +3150,8 @@ static enum net_verdict tcp_in(struct tcp *conn, struct net_pkt *pkt) } break; - case TCP_ESTABLISHED: - /* full-close */ - if (FL(&fl, &, FIN, th_seq(th) == conn->ack)) { - if (len) { - verdict = tcp_data_get(conn, pkt, &len); - if (verdict == NET_OK) { - /* net_pkt owned by the recv fifo now */ - pkt = NULL; - } - } else { - verdict = NET_OK; - } - - conn_ack(conn, + len + 1); - keep_alive_timer_stop(conn); - - if (net_tcp_seq_cmp(th_ack(th), conn->seq) > 0) { - uint32_t len_acked = th_ack(th) - conn->seq; - - conn_seq(conn, + len_acked); - } - - tcp_out(conn, FIN | ACK); - conn_seq(conn, + 1); - tcp_setup_retransmission(conn); - - tcp_setup_last_ack_timer(conn); - next = TCP_LAST_ACK; - - break; - } + case TCP_ESTABLISHED: { + bool fin = FL(&fl, &, FIN, th_seq(th) == conn->ack); /* Whatever we've received, we know that peer is alive, so reset * the keepalive timer. @@ -3279,11 +3257,21 @@ static enum net_verdict tcp_in(struct tcp *conn, struct net_pkt *pkt) /* We are closing the connection, send a FIN to peer */ if (conn->in_close && conn->send_data_total == 0) { - next = TCP_FIN_WAIT_1; - - k_work_reschedule_for_queue(&tcp_work_q, - &conn->fin_timer, - FIN_TIMEOUT); + if (fin) { + /* If FIN was also present in the processed + * packet, acknowledge that and jump directly + * to TCP_LAST_ACK. + */ + conn_ack(conn, + 1); + next = TCP_LAST_ACK; + tcp_setup_last_ack_timer(conn); + } else { + /* Otherwise, wait for FIN in TCP_FIN_WAIT_1 */ + next = TCP_FIN_WAIT_1; + k_work_reschedule_for_queue(&tcp_work_q, + &conn->fin_timer, + FIN_TIMEOUT); + } tcp_out(conn, FIN | ACK); conn_seq(conn, + 1); @@ -3314,7 +3302,7 @@ static enum net_verdict tcp_in(struct tcp *conn, struct net_pkt *pkt) data_recv: psh = FL(&fl, &, PSH); - verdict = tcp_data_received(conn, pkt, &len, psh); + verdict = tcp_data_received(conn, pkt, &len, psh, fin); if (verdict == NET_OK) { /* net_pkt owned by the recv fifo now */ pkt = NULL; @@ -3358,7 +3346,19 @@ static enum net_verdict tcp_in(struct tcp *conn, struct net_pkt *pkt) k_sem_give(&conn->tx_sem); } + /* Finally, after all Data/ACK processing, check for FIN flag. */ + if (fin) { + keep_alive_timer_stop(conn); + conn_ack(conn, + 1); + tcp_out(conn, FIN | ACK); + conn_seq(conn, + 1); + tcp_setup_retransmission(conn); + tcp_setup_last_ack_timer(conn); + next = TCP_LAST_ACK; + } + break; + } case TCP_CLOSE_WAIT: /* Half-close is not supported, so do nothing here */ break; diff --git a/subsys/net/lib/dns/llmnr_responder.c b/subsys/net/lib/dns/llmnr_responder.c index c8b1500b1b0..7dfa682f53d 100644 --- a/subsys/net/lib/dns/llmnr_responder.c +++ b/subsys/net/lib/dns/llmnr_responder.c @@ -192,23 +192,19 @@ static void setup_dns_hdr(uint8_t *buf, uint16_t answers, uint16_t dns_id) static void add_question(struct net_buf *query, enum dns_rr_type qtype) { - char *dot = query->data + DNS_MSG_HEADER_SIZE; - char *prev = NULL; + char *dot = query->data + DNS_MSG_HEADER_SIZE + 1; + char *prev = query->data + DNS_MSG_HEADER_SIZE; uint16_t offset; - while ((dot = strchr(dot, '.'))) { - if (!prev) { - prev = dot++; - continue; - } + /* For the length of the first label. */ + query->len += 1; + while ((dot = strchr(dot, '.')) != NULL) { *prev = dot - prev - 1; prev = dot++; } - if (prev) { - *prev = strlen(prev) - 1; - } + *prev = strlen(prev + 1); offset = DNS_MSG_HEADER_SIZE + query->len + 1; UNALIGNED_PUT(htons(qtype), (uint16_t *)(query->data+offset)); @@ -245,14 +241,15 @@ static int create_answer(enum dns_rr_type qtype, /* Prepare the response into the query buffer: move the name * query buffer has to get enough free space: dns_hdr + query + answer */ - if ((net_buf_max_len(query) - query->len) < (DNS_MSG_HEADER_SIZE + + if ((net_buf_max_len(query) - query->len) < (DNS_MSG_HEADER_SIZE + 1 + (DNS_QTYPE_LEN + DNS_QCLASS_LEN) * 2 + DNS_TTL_LEN + DNS_RDLENGTH_LEN + addr_len + query->len)) { return -ENOBUFS; } - memmove(query->data + DNS_MSG_HEADER_SIZE, query->data, query->len); + /* +1 for the initial label length */ + memmove(query->data + DNS_MSG_HEADER_SIZE + 1, query->data, query->len); setup_dns_hdr(query->data, 1, dns_id); @@ -488,8 +485,8 @@ static int dns_read(int sock, result->data, ret); /* If the query matches to our hostname, then send reply */ - if (!strncasecmp(hostname, result->data + 1, hostname_len) && - (result->len - 1) >= hostname_len) { + if (!strncasecmp(hostname, result->data, hostname_len) && + (result->len) >= hostname_len) { NET_DBG("%s query to our hostname %s", "LLMNR", hostname); ret = send_response(sock, src_addr, addrlen, result, qtype, diff --git a/subsys/net/lib/dns/mdns_responder.c b/subsys/net/lib/dns/mdns_responder.c index a867cd079a1..e60ab9b7002 100644 --- a/subsys/net/lib/dns/mdns_responder.c +++ b/subsys/net/lib/dns/mdns_responder.c @@ -274,23 +274,19 @@ static void setup_dns_hdr(uint8_t *buf, uint16_t answers) static void add_answer(struct net_buf *query, enum dns_rr_type qtype, uint32_t ttl, uint16_t addr_len, uint8_t *addr) { - char *dot = query->data + DNS_MSG_HEADER_SIZE; - char *prev = NULL; + char *dot = query->data + DNS_MSG_HEADER_SIZE + 1; + char *prev = query->data + DNS_MSG_HEADER_SIZE; uint16_t offset; - while ((dot = strchr(dot, '.'))) { - if (!prev) { - prev = dot++; - continue; - } + /* For the length of the first label. */ + query->len += 1; + while ((dot = strchr(dot, '.')) != NULL) { *prev = dot - prev - 1; prev = dot++; } - if (prev) { - *prev = strlen(prev) - 1; - } + *prev = strlen(prev + 1); /* terminator byte (0x00) */ query->len += 1; @@ -322,14 +318,15 @@ static int create_answer(int sock, /* Prepare the response into the query buffer: move the name * query buffer has to get enough free space: dns_hdr + answer */ - if ((net_buf_max_len(query) - query->len) < (DNS_MSG_HEADER_SIZE + + if ((net_buf_max_len(query) - query->len) < (DNS_MSG_HEADER_SIZE + 1 + DNS_QTYPE_LEN + DNS_QCLASS_LEN + DNS_TTL_LEN + DNS_RDLENGTH_LEN + addr_len)) { return -ENOBUFS; } - memmove(query->data + DNS_MSG_HEADER_SIZE, query->data, query->len); + /* +1 for the initial label length */ + memmove(query->data + DNS_MSG_HEADER_SIZE + 1, query->data, query->len); setup_dns_hdr(query->data, 1); @@ -641,7 +638,7 @@ static int dns_read(int sock, } /* Handle only .local queries */ - lquery = strrchr(result->data + 1, '.'); + lquery = strrchr(result->data, '.'); if (!lquery || memcmp(lquery, (const void *){ ".local" }, 7)) { continue; } @@ -654,9 +651,9 @@ static int dns_read(int sock, * We skip the first dot, and make sure there is dot after * matching hostname. */ - if (!strncasecmp(hostname, result->data + 1, hostname_len) && - (result->len - 1) >= hostname_len && - &(result->data + 1)[hostname_len] == lquery) { + if (!strncasecmp(hostname, result->data, hostname_len) && + (result->len) >= hostname_len && + &result->data[hostname_len] == lquery) { NET_DBG("%s %s %s to our hostname %s%s", "mDNS", family == AF_INET ? "IPv4" : "IPv6", "query", hostname, ".local"); diff --git a/tests/net/socket/getaddrinfo/src/main.c b/tests/net/socket/getaddrinfo/src/main.c index c412cce87f4..25be0205ea8 100644 --- a/tests/net/socket/getaddrinfo/src/main.c +++ b/tests/net/socket/getaddrinfo/src/main.c @@ -96,7 +96,7 @@ static bool check_dns_query(uint8_t *buf, int buf_len) /* In this test we are just checking if the query came to us in correct * form, we are not creating a DNS server implementation here. */ - if (strncmp(result->data + 1, QUERY_HOST, + if (strncmp(result->data, QUERY_HOST, sizeof(QUERY_HOST) - 1)) { net_buf_unref(result); return false; diff --git a/tests/net/tcp/src/main.c b/tests/net/tcp/src/main.c index 2c0cfcd9ef2..db57b70a40e 100644 --- a/tests/net/tcp/src/main.c +++ b/tests/net/tcp/src/main.c @@ -116,6 +116,7 @@ static enum test_case_no { TEST_CLIENT_FIN_ACK_WITH_DATA = 18, TEST_CLIENT_SEQ_VALIDATION = 19, TEST_SERVER_ACK_VALIDATION = 20, + TEST_SERVER_FIN_ACK_AFTER_DATA = 21, } test_case_no; static enum test_state t_state; @@ -142,6 +143,7 @@ static void handle_syn_invalid_ack(sa_family_t af, struct tcphdr *th); static void handle_client_fin_ack_with_data_test(sa_family_t af, struct tcphdr *th); static void handle_client_seq_validation_test(sa_family_t af, struct tcphdr *th); static void handle_server_ack_validation_test(struct net_pkt *pkt); +static void handle_server_fin_ack_after_data_test(sa_family_t af, struct tcphdr *th); static void verify_flags(struct tcphdr *th, uint8_t flags, const char *fun, int line) @@ -494,6 +496,9 @@ static int tester_send(const struct device *dev, struct net_pkt *pkt) case TEST_SERVER_ACK_VALIDATION: handle_server_ack_validation_test(pkt); break; + case TEST_SERVER_FIN_ACK_AFTER_DATA: + handle_server_fin_ack_after_data_test(net_pkt_family(pkt), &th); + break; default: zassert_true(false, "Undefined test case"); } @@ -3002,4 +3007,204 @@ ZTEST(net_tcp, test_server_ack_validation) net_context_put(accepted_ctx); } +#define TEST_FIN_ACK_AFTER_DATA_REQ "request" +#define TEST_FIN_ACK_AFTER_DATA_RSP "test data response" + +/* In this test we check that FIN,ACK packet acknowledging latest data is + * handled correctly by the TCP stack. + */ +static void handle_server_fin_ack_after_data_test(sa_family_t af, struct tcphdr *th) +{ + struct net_pkt *reply = NULL; + + zassert_false(th == NULL && t_state != T_SYN, + "NULL pkt only expected in T_SYN state"); + + switch (t_state) { + case T_SYN: + reply = prepare_syn_packet(af, htons(MY_PORT), htons(PEER_PORT)); + seq++; + t_state = T_SYN_ACK; + break; + case T_SYN_ACK: + test_verify_flags(th, SYN | ACK); + zassert_equal(ntohl(th->th_ack), seq, + "Unexpected ACK in T_SYN_ACK, got %d, expected %d", + ntohl(th->th_ack), seq); + device_initial_seq = ntohl(th->th_seq); + ack = ntohl(th->th_seq) + 1U; + t_state = T_DATA_ACK; + + /* Dummy "request" packet */ + reply = prepare_data_packet(af, htons(MY_PORT), htons(PEER_PORT), + TEST_FIN_ACK_AFTER_DATA_REQ, + sizeof(TEST_FIN_ACK_AFTER_DATA_REQ) - 1); + seq += sizeof(TEST_FIN_ACK_AFTER_DATA_REQ) - 1; + break; + case T_DATA_ACK: + test_verify_flags(th, ACK); + t_state = T_DATA; + zassert_equal(ntohl(th->th_seq), ack, + "Unexpected SEQ in T_DATA_ACK, got %d, expected %d", + get_rel_seq(th), ack); + zassert_equal(ntohl(th->th_ack), seq, + "Unexpected ACK in T_DATA_ACK, got %d, expected %d", + ntohl(th->th_ack), seq); + break; + case T_DATA: + test_verify_flags(th, PSH | ACK); + zassert_equal(ntohl(th->th_seq), ack, + "Unexpected SEQ in T_DATA, got %d, expected %d", + get_rel_seq(th), ack); + zassert_equal(ntohl(th->th_ack), seq, + "Unexpected ACK in T_DATA, got %d, expected %d", + ntohl(th->th_ack), seq); + ack += sizeof(TEST_FIN_ACK_AFTER_DATA_RSP) - 1; + t_state = T_FIN_ACK; + + reply = prepare_fin_ack_packet(af, htons(MY_PORT), htons(PEER_PORT)); + seq++; + break; + case T_FIN_ACK: + test_verify_flags(th, FIN | ACK); + zassert_equal(ntohl(th->th_seq), ack, + "Unexpected SEQ in T_FIN_ACK, got %d, expected %d", + get_rel_seq(th), ack); + zassert_equal(ntohl(th->th_ack), seq, + "Unexpected ACK in T_FIN_ACK, got %d, expected %d", + ntohl(th->th_ack), seq); + + ack++; + t_state = T_CLOSING; + + reply = prepare_ack_packet(af, htons(MY_PORT), htons(PEER_PORT)); + seq++; + break; + case T_CLOSING: + zassert_true(false, "Should not receive anything after final ACK"); + break; + default: + zassert_true(false, "%s unexpected state", __func__); + return; + } + + if (reply != NULL) { + zassert_ok(net_recv_data(net_iface, reply), "%s failed", __func__); + } +} + +/* Receive callback to be installed in the accept handler */ +static void test_fin_ack_after_data_recv_cb(struct net_context *context, + struct net_pkt *pkt, + union net_ip_header *ip_hdr, + union net_proto_header *proto_hdr, + int status, + void *user_data) +{ + zassert_ok(status, "failed to recv the data"); + + if (pkt != NULL) { + uint8_t buf[sizeof(TEST_FIN_ACK_AFTER_DATA_REQ)] = { 0 }; + int data_len = net_pkt_remaining_data(pkt); + + zassert_equal(data_len, sizeof(TEST_FIN_ACK_AFTER_DATA_REQ) - 1, + "Invalid packet length, %d", data_len); + zassert_ok(net_pkt_read(pkt, buf, data_len)); + zassert_mem_equal(buf, TEST_FIN_ACK_AFTER_DATA_REQ, data_len); + + net_pkt_unref(pkt); + } + + test_sem_give(); +} + +static void test_fin_ack_after_data_accept_cb(struct net_context *ctx, + struct sockaddr *addr, + socklen_t addrlen, + int status, + void *user_data) +{ + int ret; + + zassert_ok(status, "failed to accept the conn"); + + /* set callback on newly created context */ + accepted_ctx = ctx; + ret = net_context_recv(ctx, test_fin_ack_after_data_recv_cb, + K_NO_WAIT, NULL); + zassert_ok(ret, "Failed to recv data from peer"); + + /* Ref the context on the app behalf. */ + net_context_ref(ctx); +} + +/* Verify that the TCP stack replies with a valid FIN,ACK after the peer + * acknowledges the latest data in the FIN packet. + * Test case scenario IPv4 + * send SYN, + * expect SYN ACK, + * send ACK with Data, + * expect ACK, + * expect Data, + * send FIN,ACK + * expect FIN,ACK + * send ACK + * any failures cause test case to fail. + */ +ZTEST(net_tcp, test_server_fin_ack_after_data) +{ + struct net_context *ctx; + int ret; + + test_case_no = TEST_SERVER_FIN_ACK_AFTER_DATA; + + t_state = T_SYN; + seq = ack = 0; + + ret = net_context_get(AF_INET, SOCK_STREAM, IPPROTO_TCP, &ctx); + zassert_ok(ret, "Failed to get net_context"); + + net_context_ref(ctx); + + ret = net_context_bind(ctx, (struct sockaddr *)&my_addr_s, + sizeof(struct sockaddr_in)); + zassert_ok(ret, "Failed to bind net_context"); + + /* Put context into listening mode and install accept cb */ + ret = net_context_listen(ctx, 1); + zassert_ok(ret, "Failed to listen on net_context"); + + ret = net_context_accept(ctx, test_fin_ack_after_data_accept_cb, + K_NO_WAIT, NULL); + zassert_ok(ret, "Failed to set accept on net_context"); + + /* Trigger the peer to send SYN */ + handle_server_fin_ack_after_data_test(AF_INET, NULL); + + /* test_fin_ack_after_data_recv_cb will release the semaphore after + * dummy request is read. + */ + test_sem_take(K_MSEC(100), __LINE__); + + /* Send dummy "response" */ + ret = net_context_send(accepted_ctx, TEST_FIN_ACK_AFTER_DATA_RSP, + sizeof(TEST_FIN_ACK_AFTER_DATA_RSP) - 1, NULL, + K_NO_WAIT, NULL); + zassert_equal(ret, sizeof(TEST_FIN_ACK_AFTER_DATA_RSP) - 1, + "Failed to send data to peer %d", ret); + + /* test_fin_ack_after_data_recv_cb will release the semaphore after + * the connection is marked closed. + */ + test_sem_take(K_MSEC(100), __LINE__); + + net_context_put(ctx); + net_context_put(accepted_ctx); + + /* Connection is in TIME_WAIT state, context will be released + * after K_MSEC(CONFIG_NET_TCP_TIME_WAIT_DELAY), so wait for it. + */ + k_sleep(K_MSEC(CONFIG_NET_TCP_TIME_WAIT_DELAY)); +} + ZTEST_SUITE(net_tcp, NULL, presetup, NULL, NULL, NULL);