@@ -188,13 +188,12 @@ static void stack_map_get_build_id_offset(struct bpf_stack_build_id *id_offs,
188188}
189189
190190static struct perf_callchain_entry *
191- get_callchain_entry_for_task (struct task_struct * task , u32 max_depth )
191+ get_callchain_entry_for_task (int * rctx , struct task_struct * task , u32 max_depth )
192192{
193193#ifdef CONFIG_STACKTRACE
194194 struct perf_callchain_entry * entry ;
195- int rctx ;
196195
197- entry = get_callchain_entry (& rctx );
196+ entry = get_callchain_entry (rctx );
198197
199198 if (!entry )
200199 return NULL ;
@@ -216,8 +215,6 @@ get_callchain_entry_for_task(struct task_struct *task, u32 max_depth)
216215 to [i ] = (u64 )(from [i ]);
217216 }
218217
219- put_callchain_entry (rctx );
220-
221218 return entry ;
222219#else /* CONFIG_STACKTRACE */
223220 return NULL ;
@@ -297,6 +294,31 @@ static long __bpf_get_stackid(struct bpf_map *map,
297294 return id ;
298295}
299296
297+ static struct perf_callchain_entry *
298+ bpf_get_perf_callchain (int * rctx , struct pt_regs * regs , bool kernel , bool user ,
299+ int max_stack , bool crosstask )
300+ {
301+ struct perf_callchain_entry_ctx ctx ;
302+ struct perf_callchain_entry * entry ;
303+
304+ entry = get_callchain_entry (rctx );
305+ if (unlikely (!entry ))
306+ return NULL ;
307+
308+ __init_perf_callchain_ctx (& ctx , entry , max_stack , false);
309+ if (kernel )
310+ __get_perf_callchain_kernel (& ctx , regs );
311+ if (user && !crosstask )
312+ __get_perf_callchain_user (& ctx , regs );
313+
314+ return entry ;
315+ }
316+
317+ static void bpf_put_callchain_entry (int rctx )
318+ {
319+ put_callchain_entry (rctx );
320+ }
321+
300322BPF_CALL_3 (bpf_get_stackid , struct pt_regs * , regs , struct bpf_map * , map ,
301323 u64 , flags )
302324{
@@ -305,6 +327,7 @@ BPF_CALL_3(bpf_get_stackid, struct pt_regs *, regs, struct bpf_map *, map,
305327 bool user = flags & BPF_F_USER_STACK ;
306328 struct perf_callchain_entry * trace ;
307329 bool kernel = !user ;
330+ int rctx , ret ;
308331
309332 if (unlikely (flags & ~(BPF_F_SKIP_FIELD_MASK | BPF_F_USER_STACK |
310333 BPF_F_FAST_STACK_CMP | BPF_F_REUSE_STACKID )))
@@ -314,14 +337,15 @@ BPF_CALL_3(bpf_get_stackid, struct pt_regs *, regs, struct bpf_map *, map,
314337 if (max_depth > sysctl_perf_event_max_stack )
315338 max_depth = sysctl_perf_event_max_stack ;
316339
317- trace = get_perf_callchain (regs , kernel , user , max_depth ,
318- false);
319-
340+ trace = bpf_get_perf_callchain (& rctx , regs , kernel , user , max_depth , false);
320341 if (unlikely (!trace ))
321342 /* couldn't fetch the stack trace */
322343 return - EFAULT ;
323344
324- return __bpf_get_stackid (map , trace , flags );
345+ ret = __bpf_get_stackid (map , trace , flags );
346+ bpf_put_callchain_entry (rctx );
347+
348+ return ret ;
325349}
326350
327351const struct bpf_func_proto bpf_get_stackid_proto = {
@@ -415,6 +439,7 @@ static long __bpf_get_stack(struct pt_regs *regs, struct task_struct *task,
415439 bool kernel = !user ;
416440 int err = - EINVAL ;
417441 u64 * ips ;
442+ int rctx ;
418443
419444 if (unlikely (flags & ~(BPF_F_SKIP_FIELD_MASK | BPF_F_USER_STACK |
420445 BPF_F_USER_BUILD_ID )))
@@ -449,17 +474,24 @@ static long __bpf_get_stack(struct pt_regs *regs, struct task_struct *task,
449474 if (trace_in )
450475 trace = trace_in ;
451476 else if (kernel && task )
452- trace = get_callchain_entry_for_task (task , max_depth );
477+ trace = get_callchain_entry_for_task (& rctx , task , max_depth );
453478 else
454- trace = get_perf_callchain (regs , kernel , user , max_depth ,
455- crosstask );
479+ trace = bpf_get_perf_callchain (& rctx , regs , kernel , user , max_depth , crosstask );
456480
457- if (unlikely (!trace ) || trace -> nr < skip ) {
481+ if (unlikely (!trace )) {
458482 if (may_fault )
459483 rcu_read_unlock ();
460484 goto err_fault ;
461485 }
462486
487+ if (trace -> nr < skip ) {
488+ if (may_fault )
489+ rcu_read_unlock ();
490+ if (!trace_in )
491+ bpf_put_callchain_entry (rctx );
492+ goto err_fault ;
493+ }
494+
463495 trace_nr = trace -> nr - skip ;
464496 trace_nr = (trace_nr <= num_elem ) ? trace_nr : num_elem ;
465497 copy_len = trace_nr * elem_size ;
@@ -479,6 +511,9 @@ static long __bpf_get_stack(struct pt_regs *regs, struct task_struct *task,
479511 if (may_fault )
480512 rcu_read_unlock ();
481513
514+ if (!trace_in )
515+ bpf_put_callchain_entry (rctx );
516+
482517 if (user_build_id )
483518 stack_map_get_build_id_offset (buf , trace_nr , user , may_fault );
484519
0 commit comments