Skip to content

Commit 86149b4

Browse files
mmhalMartin KaFai Lau
authored andcommitted
selftests/bpf: Introduce __attribute__((cleanup)) in create_pair()
Rewrite function to have (unneeded) socket descriptors automatically close()d when leaving the scope. Make sure the "ownership" of fds is correctly passed via take_fd(); i.e. descriptor returned to caller will remain valid. Reviewed-by: Jakub Sitnicki <[email protected]> Tested-by: Jakub Sitnicki <[email protected]> Suggested-by: Jakub Sitnicki <[email protected]> Signed-off-by: Michal Luczaj <[email protected]> Link: https://lore.kernel.org/r/[email protected] Signed-off-by: Martin KaFai Lau <[email protected]>
1 parent c9c70b2 commit 86149b4

File tree

1 file changed

+36
-25
lines changed

1 file changed

+36
-25
lines changed

tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@
1717

1818
#define __always_unused __attribute__((__unused__))
1919

20+
/* include/linux/cleanup.h */
21+
#define __get_and_null(p, nullvalue) \
22+
({ \
23+
__auto_type __ptr = &(p); \
24+
__auto_type __val = *__ptr; \
25+
*__ptr = nullvalue; \
26+
__val; \
27+
})
28+
29+
#define take_fd(fd) __get_and_null(fd, -EBADF)
30+
2031
#define _FAIL(errnum, fmt...) \
2132
({ \
2233
error_at_line(0, (errnum), __func__, __LINE__, fmt); \
@@ -182,6 +193,14 @@
182193
__ret; \
183194
})
184195

196+
static inline void close_fd(int *fd)
197+
{
198+
if (*fd >= 0)
199+
xclose(*fd);
200+
}
201+
202+
#define __close_fd __attribute__((cleanup(close_fd)))
203+
185204
static inline int poll_connect(int fd, unsigned int timeout_sec)
186205
{
187206
struct timeval timeout = { .tv_sec = timeout_sec };
@@ -369,72 +388,64 @@ static inline int socket_loopback(int family, int sotype)
369388

370389
static inline int create_pair(int family, int sotype, int *p0, int *p1)
371390
{
391+
__close_fd int s, c = -1, p = -1;
372392
struct sockaddr_storage addr;
373393
socklen_t len = sizeof(addr);
374-
int s, c, p, err;
394+
int err;
375395

376396
s = socket_loopback(family, sotype);
377397
if (s < 0)
378398
return s;
379399

380400
err = xgetsockname(s, sockaddr(&addr), &len);
381401
if (err)
382-
goto close_s;
402+
return err;
383403

384404
c = xsocket(family, sotype, 0);
385-
if (c < 0) {
386-
err = c;
387-
goto close_s;
388-
}
405+
if (c < 0)
406+
return c;
389407

390408
err = connect(c, sockaddr(&addr), len);
391409
if (err) {
392410
if (errno != EINPROGRESS) {
393411
FAIL_ERRNO("connect");
394-
goto close_c;
412+
return err;
395413
}
396414

397415
err = poll_connect(c, IO_TIMEOUT_SEC);
398416
if (err) {
399417
FAIL_ERRNO("poll_connect");
400-
goto close_c;
418+
return err;
401419
}
402420
}
403421

404422
switch (sotype & SOCK_TYPE_MASK) {
405423
case SOCK_DGRAM:
406424
err = xgetsockname(c, sockaddr(&addr), &len);
407425
if (err)
408-
goto close_c;
426+
return err;
409427

410428
err = xconnect(s, sockaddr(&addr), len);
411-
if (!err) {
412-
*p0 = s;
413-
*p1 = c;
429+
if (err)
414430
return err;
415-
}
431+
432+
*p0 = take_fd(s);
416433
break;
417434
case SOCK_STREAM:
418435
case SOCK_SEQPACKET:
419436
p = xaccept_nonblock(s, NULL, NULL);
420-
if (p >= 0) {
421-
*p0 = p;
422-
*p1 = c;
423-
goto close_s;
424-
}
437+
if (p < 0)
438+
return p;
425439

426-
err = p;
440+
*p0 = take_fd(p);
427441
break;
428442
default:
429443
FAIL("Unsupported socket type %#x", sotype);
430-
err = -EOPNOTSUPP;
444+
return -EOPNOTSUPP;
431445
}
432446

433-
close_c:
434-
close(c);
435-
close_s:
436-
close(s);
437-
return err;
447+
*p1 = take_fd(c);
448+
return 0;
438449
}
439450

440451
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,

0 commit comments

Comments
 (0)