@@ -456,6 +456,8 @@ struct io_ring_ctx {
 		struct work_struct		exit_work;
 		struct list_head		tctx_list;
 		struct completion		ref_comp;
+		u32				iowq_limits[2];
+		bool				iowq_limits_set;
 	};
 };
@@ -1368,11 +1370,6 @@ static void io_req_track_inflight(struct io_kiocb *req)
 	}
 }
 
-static inline void io_unprep_linked_timeout(struct io_kiocb *req)
-{
-	req->flags &= ~REQ_F_LINK_TIMEOUT;
-}
-
 static struct io_kiocb *__io_prep_linked_timeout(struct io_kiocb *req)
 {
 	if (WARN_ON_ONCE(!req->link))
@@ -6983,7 +6980,7 @@ static void __io_queue_sqe(struct io_kiocb *req)
 	switch (io_arm_poll_handler(req)) {
 	case IO_APOLL_READY:
 		if (linked_timeout)
-			io_unprep_linked_timeout(req);
+			io_queue_linked_timeout(linked_timeout);
 		goto issue_sqe;
 	case IO_APOLL_ABORTED:
 		/*
@@ -9638,7 +9635,16 @@ static int __io_uring_add_tctx_node(struct io_ring_ctx *ctx)
 		ret = io_uring_alloc_task_context(current, ctx);
 		if (unlikely(ret))
 			return ret;
+
 		tctx = current->io_uring;
+		if (ctx->iowq_limits_set) {
+			unsigned int limits[2] = { ctx->iowq_limits[0],
+						   ctx->iowq_limits[1], };
+
+			ret = io_wq_max_workers(tctx->io_wq, limits);
+			if (ret)
+				return ret;
+		}
 	}
 	if (!xa_load(&tctx->xa, (unsigned long)ctx)) {
 		node = kmalloc(sizeof(*node), GFP_KERNEL);
@@ -10643,7 +10649,9 @@ static int io_unregister_iowq_aff(struct io_ring_ctx *ctx)
 
 static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
 					void __user *arg)
+	__must_hold(&ctx->uring_lock)
 {
+	struct io_tctx_node *node;
 	struct io_uring_task *tctx = NULL;
 	struct io_sq_data *sqd = NULL;
 	__u32 new_count[2];
@@ -10674,13 +10682,19 @@ static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
 		tctx = current->io_uring;
 	}
 
-	ret = -EINVAL;
-	if (!tctx || !tctx->io_wq)
-		goto err;
+	BUILD_BUG_ON(sizeof(new_count) != sizeof(ctx->iowq_limits));
 
-	ret = io_wq_max_workers(tctx->io_wq, new_count);
-	if (ret)
-		goto err;
+	memcpy(ctx->iowq_limits, new_count, sizeof(new_count));
+	ctx->iowq_limits_set = true;
+
+	ret = -EINVAL;
+	if (tctx && tctx->io_wq) {
+		ret = io_wq_max_workers(tctx->io_wq, new_count);
+		if (ret)
+			goto err;
+	} else {
+		memset(new_count, 0, sizeof(new_count));
+	}
 
 	if (sqd) {
 		mutex_unlock(&sqd->lock);
@@ -10690,6 +10704,22 @@ static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
 	if (copy_to_user(arg, new_count, sizeof(new_count)))
 		return -EFAULT;
 
+	/* that's it for SQPOLL, only the SQPOLL task creates requests */
+	if (sqd)
+		return 0;
+
+	/* now propagate the restriction to all registered users */
+	list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
+		struct io_uring_task *tctx = node->task->io_uring;
+
+		if (WARN_ON_ONCE(!tctx->io_wq))
+			continue;
+
+		for (i = 0; i < ARRAY_SIZE(new_count); i++)
+			new_count[i] = ctx->iowq_limits[i];
+		/* ignore errors, it always returns zero anyway */
+		(void)io_wq_max_workers(tctx->io_wq, new_count);
+	}
 	return 0;
 err:
 	if (sqd) {
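
Not part of the diff, just for orientation: a minimal userspace sketch of the register opcode whose semantics change here, IORING_REGISTER_IOWQ_MAX_WORKERS. With this patch the requested values are cached in ctx->iowq_limits and re-applied from __io_uring_add_tctx_node(), so any task that attaches to the ring after the registration inherits the same io-wq worker caps, not only the task that issued the call. The ring_fd, the helper name, and the limit values below are illustrative assumptions.

#include <stdio.h>
#include <unistd.h>
#include <sys/syscall.h>
#include <linux/io_uring.h>

/* Sketch only: ring_fd is assumed to come from an earlier io_uring_setup(). */
int cap_iowq_workers(int ring_fd)
{
	/* [0] = bounded workers (e.g. regular file I/O), [1] = unbounded workers */
	unsigned int counts[2] = { 4, 16 };

	/* The kernel writes the previous limits back into the array. */
	if (syscall(__NR_io_uring_register, ring_fd,
		    IORING_REGISTER_IOWQ_MAX_WORKERS, counts, 2) < 0) {
		perror("IORING_REGISTER_IOWQ_MAX_WORKERS");
		return -1;
	}

	printf("old limits: bounded=%u unbounded=%u\n", counts[0], counts[1]);
	return 0;
}

Passing 0 for an element leaves that limit unchanged, and the previous values are normally copied back to the user array (the copy_to_user() in the hunk above), so the same call doubles as a query.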