Skip to content

Commit f4e9e0e

Browse files
howlettakpm00
authored andcommitted
mm/mempolicy: fix use-after-free of VMA iterator
set_mempolicy_home_node() iterates over a list of VMAs and calls mbind_range() on each VMA, which also iterates over the singular list of the VMA passed in and potentially splits the VMA. Since the VMA iterator is not passed through, set_mempolicy_home_node() may now point to a stale node in the VMA tree. This can result in a UAF as reported by syzbot. Avoid the stale maple tree node by passing the VMA iterator through to the underlying call to split_vma(). mbind_range() is also overly complicated, since there are two calling functions and one already handles iterating over the VMAs. Simplify mbind_range() to only handle merging and splitting of the VMAs. Align the new loop in do_mbind() and existing loop in set_mempolicy_home_node() to use the reduced mbind_range() function. This allows for a single location of the range calculation and avoids constantly looking up the previous VMA (since this is a loop over the VMAs). Link: https://lore.kernel.org/linux-mm/[email protected]/ Fixes: 66850be ("mm/mempolicy: use vma iterator & maple state instead of vma linked list") Signed-off-by: Liam R. Howlett <[email protected]> Reported-by: [email protected] Link: https://lkml.kernel.org/r/[email protected] Tested-by: [email protected] Cc: <[email protected]> Signed-off-by: Andrew Morton <[email protected]>
1 parent 4737edb commit f4e9e0e

File tree

1 file changed

+49
-55
lines changed

1 file changed

+49
-55
lines changed

mm/mempolicy.c

Lines changed: 49 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -790,61 +790,50 @@ static int vma_replace_policy(struct vm_area_struct *vma,
790790
return err;
791791
}
792792

793-
/* Step 2: apply policy to a range and do splits. */
794-
static int mbind_range(struct mm_struct *mm, unsigned long start,
795-
unsigned long end, struct mempolicy *new_pol)
793+
/* Split or merge the VMA (if required) and apply the new policy */
794+
static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
795+
struct vm_area_struct **prev, unsigned long start,
796+
unsigned long end, struct mempolicy *new_pol)
796797
{
797-
VMA_ITERATOR(vmi, mm, start);
798-
struct vm_area_struct *prev;
799-
struct vm_area_struct *vma;
800-
int err = 0;
798+
struct vm_area_struct *merged;
799+
unsigned long vmstart, vmend;
801800
pgoff_t pgoff;
801+
int err;
802802

803-
prev = vma_prev(&vmi);
804-
vma = vma_find(&vmi, end);
805-
if (WARN_ON(!vma))
803+
vmend = min(end, vma->vm_end);
804+
if (start > vma->vm_start) {
805+
*prev = vma;
806+
vmstart = start;
807+
} else {
808+
vmstart = vma->vm_start;
809+
}
810+
811+
if (mpol_equal(vma_policy(vma), new_pol))
806812
return 0;
807813

808-
if (start > vma->vm_start)
809-
prev = vma;
810-
811-
do {
812-
unsigned long vmstart = max(start, vma->vm_start);
813-
unsigned long vmend = min(end, vma->vm_end);
814-
815-
if (mpol_equal(vma_policy(vma), new_pol))
816-
goto next;
817-
818-
pgoff = vma->vm_pgoff +
819-
((vmstart - vma->vm_start) >> PAGE_SHIFT);
820-
prev = vma_merge(&vmi, mm, prev, vmstart, vmend, vma->vm_flags,
821-
vma->anon_vma, vma->vm_file, pgoff,
822-
new_pol, vma->vm_userfaultfd_ctx,
823-
anon_vma_name(vma));
824-
if (prev) {
825-
vma = prev;
826-
goto replace;
827-
}
828-
if (vma->vm_start != vmstart) {
829-
err = split_vma(&vmi, vma, vmstart, 1);
830-
if (err)
831-
goto out;
832-
}
833-
if (vma->vm_end != vmend) {
834-
err = split_vma(&vmi, vma, vmend, 0);
835-
if (err)
836-
goto out;
837-
}
838-
replace:
839-
err = vma_replace_policy(vma, new_pol);
814+
pgoff = vma->vm_pgoff + ((vmstart - vma->vm_start) >> PAGE_SHIFT);
815+
merged = vma_merge(vmi, vma->vm_mm, *prev, vmstart, vmend, vma->vm_flags,
816+
vma->anon_vma, vma->vm_file, pgoff, new_pol,
817+
vma->vm_userfaultfd_ctx, anon_vma_name(vma));
818+
if (merged) {
819+
*prev = merged;
820+
return vma_replace_policy(merged, new_pol);
821+
}
822+
823+
if (vma->vm_start != vmstart) {
824+
err = split_vma(vmi, vma, vmstart, 1);
840825
if (err)
841-
goto out;
842-
next:
843-
prev = vma;
844-
} for_each_vma_range(vmi, vma, end);
826+
return err;
827+
}
845828

846-
out:
847-
return err;
829+
if (vma->vm_end != vmend) {
830+
err = split_vma(vmi, vma, vmend, 0);
831+
if (err)
832+
return err;
833+
}
834+
835+
*prev = vma;
836+
return vma_replace_policy(vma, new_pol);
848837
}
849838

850839
/* Set the process memory policy */
@@ -1259,6 +1248,8 @@ static long do_mbind(unsigned long start, unsigned long len,
12591248
nodemask_t *nmask, unsigned long flags)
12601249
{
12611250
struct mm_struct *mm = current->mm;
1251+
struct vm_area_struct *vma, *prev;
1252+
struct vma_iterator vmi;
12621253
struct mempolicy *new;
12631254
unsigned long end;
12641255
int err;
@@ -1328,7 +1319,13 @@ static long do_mbind(unsigned long start, unsigned long len,
13281319
goto up_out;
13291320
}
13301321

1331-
err = mbind_range(mm, start, end, new);
1322+
vma_iter_init(&vmi, mm, start);
1323+
prev = vma_prev(&vmi);
1324+
for_each_vma_range(vmi, vma, end) {
1325+
err = mbind_range(&vmi, vma, &prev, start, end, new);
1326+
if (err)
1327+
break;
1328+
}
13321329

13331330
if (!err) {
13341331
int nr_failed = 0;
@@ -1489,10 +1486,8 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, le
14891486
unsigned long, home_node, unsigned long, flags)
14901487
{
14911488
struct mm_struct *mm = current->mm;
1492-
struct vm_area_struct *vma;
1489+
struct vm_area_struct *vma, *prev;
14931490
struct mempolicy *new, *old;
1494-
unsigned long vmstart;
1495-
unsigned long vmend;
14961491
unsigned long end;
14971492
int err = -ENOENT;
14981493
VMA_ITERATOR(vmi, mm, start);
@@ -1521,6 +1516,7 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, le
15211516
if (end == start)
15221517
return 0;
15231518
mmap_write_lock(mm);
1519+
prev = vma_prev(&vmi);
15241520
for_each_vma_range(vmi, vma, end) {
15251521
/*
15261522
* If any vma in the range got policy other than MPOL_BIND
@@ -1541,9 +1537,7 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, le
15411537
}
15421538

15431539
new->home_node = home_node;
1544-
vmstart = max(start, vma->vm_start);
1545-
vmend = min(end, vma->vm_end);
1546-
err = mbind_range(mm, vmstart, vmend, new);
1540+
err = mbind_range(&vmi, vma, &prev, start, end, new);
15471541
mpol_put(new);
15481542
if (err)
15491543
break;

0 commit comments

Comments
 (0)