bpf: improve the general precision of tnum_mul #5758

Open
wants to merge 1 commit into
base: bpf-next_base
3 changes: 3 additions & 0 deletions include/linux/tnum.h
@@ -54,6 +54,9 @@ struct tnum tnum_mul(struct tnum a, struct tnum b);
 /* Return a tnum representing numbers satisfying both @a and @b */
 struct tnum tnum_intersect(struct tnum a, struct tnum b);
 
+/* Return a tnum representing numbers satisfying either @a or @b */
+struct tnum tnum_union(struct tnum a, struct tnum b);
+
 /* Return @a with all but the lowest @size bytes cleared */
 struct tnum tnum_cast(struct tnum a, u8 size);
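
Review note on tnum_union(): the comment describes the result as the numbers satisfying either operand. Because a tnum can only mark individual bits as known or unknown, the result is in general a superset of that set. The following is a minimal userspace sketch, assuming a stand-in struct tnum and TNUM() macro that mirror the kernel's; tnum_contains() and tnum_union_examples() are hypothetical helpers and not part of this patch.

```c
#include <assert.h>
#include <stdint.h>

typedef uint64_t u64;

/* Userspace stand-in for the kernel's struct tnum (assumption): bits set in
 * mask are unknown, every other bit equals the corresponding bit of value.
 */
struct tnum {
	u64 value;
	u64 mask;
};

#define TNUM(_v, _m) ((struct tnum){ .value = (_v), .mask = (_m) })

/* Same logic as the tnum_union() added by this patch. */
static struct tnum tnum_union(struct tnum a, struct tnum b)
{
	u64 v = a.value & b.value;
	u64 mu = (a.value ^ b.value) | a.mask | b.mask;

	return TNUM(v & ~mu, mu);
}

/* Hypothetical helper: nonzero if the concrete value x is represented by t. */
static int tnum_contains(struct tnum t, u64 x)
{
	return (x & ~t.mask) == t.value;
}

/* Hypothetical self-check that walks through two small cases. */
void tnum_union_examples(void)
{
	struct tnum u;

	/* 0b100 and 0b110 differ only in bit 1, so the union is exact. */
	u = tnum_union(TNUM(4, 0), TNUM(6, 0));
	assert(u.value == 4 && u.mask == 2);
	assert(!tnum_contains(u, 5));

	/* 0b01 and 0b10 differ in two bits. Both bits become unknown, so
	 * the result also admits 0 and 3, which neither operand allows.
	 */
	u = tnum_union(TNUM(1, 0), TNUM(2, 0));
	assert(u.value == 0 && u.mask == 3);
	assert(tnum_contains(u, 0) && tnum_contains(u, 3));
}
```

Calling tnum_union_examples() from a test program exercises both cases. The second case shows why the result should be read as an over-approximation of the union rather than the exact set.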

42 changes: 29 additions & 13 deletions kernel/bpf/tnum.c
@@ -116,31 +116,39 @@ struct tnum tnum_xor(struct tnum a, struct tnum b)
 	return TNUM(v & ~mu, mu);
 }
 
-/* Generate partial products by multiplying each bit in the multiplier (tnum a)
- * with the multiplicand (tnum b), and add the partial products after
- * appropriately bit-shifting them. Instead of directly performing tnum addition
- * on the generated partial products, equivalenty, decompose each partial
- * product into two tnums, consisting of the value-sum (acc_v) and the
- * mask-sum (acc_m) and then perform tnum addition on them. The following paper
- * explains the algorithm in more detail: https://arxiv.org/abs/2105.05398.
+/* Perform long multiplication, iterating through the trits in a.
+ * Inside `else if (a.mask & 1)`, instead of simply multiplying b with LSB(a)'s
+ * uncertainty and accumulating directly, we compute two candidate accumulators
+ * (one for LSB(a) = 0 and another for LSB(a) = 1) and keep their union as the
+ * new accumulator. This addresses an issue pointed out in an open question
+ * ("How can we incorporate correlation in unknown bits across partial
+ * products?") left by Harishankar et al. (https://arxiv.org/abs/2105.05398),
+ * improving the general precision significantly.
  */
 struct tnum tnum_mul(struct tnum a, struct tnum b)
 {
-	u64 acc_v = a.value * b.value;
-	struct tnum acc_m = TNUM(0, 0);
+	struct tnum acc = TNUM(0, 0);
 
 	while (a.value || a.mask) {
 		/* LSB of tnum a is a certain 1 */
 		if (a.value & 1)
-			acc_m = tnum_add(acc_m, TNUM(0, b.mask));
+			acc = tnum_add(acc, b);
 		/* LSB of tnum a is uncertain */
-		else if (a.mask & 1)
-			acc_m = tnum_add(acc_m, TNUM(0, b.value | b.mask));
+		else if (a.mask & 1) {
+			/* acc = tnum_union(acc_0, acc_1), where acc_0 and
+			 * acc_1 are partial accumulators for cases
+			 * LSB(a) = certain 0 and LSB(a) = certain 1.
+			 * acc_0 = acc + 0 * b = acc.
+			 * acc_1 = acc + 1 * b = tnum_add(acc, b).
+			 */
+
+			acc = tnum_union(acc, tnum_add(acc, b));
+		}
 		/* Note: no case for LSB is certain 0 */
 		a = tnum_rshift(a, 1);
 		b = tnum_lshift(b, 1);
 	}
-	return tnum_add(TNUM(acc_v, 0), acc_m);
+	return acc;
 }
 
 /* Note that if a and b disagree - i.e. one has a 'known 1' where the other has
@@ -155,6 +163,14 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b)
 	return TNUM(v & ~mu, mu);
 }
 
+struct tnum tnum_union(struct tnum a, struct tnum b)
+{
+	u64 v = a.value & b.value;
+	u64 mu = (a.value ^ b.value) | a.mask | b.mask;
+
+	return TNUM(v & ~mu, mu);
+}
+
 struct tnum tnum_cast(struct tnum a, u8 size)
 {
 	a.value &= (1ULL << (size * 8)) - 1;
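
Review note on the tnum_mul() change: the precision gain can be checked by running the pre-patch and post-patch loops side by side in userspace. The sketch below is a test harness under stated assumptions, not kernel code. tnum_mul_old() and tnum_mul_new() are hypothetical names for the two loop bodies in this diff, tnum_add() is reproduced from kernel/bpf/tnum.c as an assumption since it is not part of the diff, and tnum_rshift()/tnum_lshift() are inlined as plain shifts of value and mask.

```c
#include <assert.h>
#include <stdint.h>

typedef uint64_t u64;

/* Userspace stand-in for the kernel's struct tnum (assumption). */
struct tnum {
	u64 value;
	u64 mask;
};

#define TNUM(_v, _m) ((struct tnum){ .value = (_v), .mask = (_m) })

/* Assumed to match tnum_add() in kernel/bpf/tnum.c. */
static struct tnum tnum_add(struct tnum a, struct tnum b)
{
	u64 sm = a.mask + b.mask;
	u64 sv = a.value + b.value;
	u64 sigma = sm + sv;
	u64 chi = sigma ^ sv;
	u64 mu = chi | a.mask | b.mask;

	return TNUM(sv & ~mu, mu);
}

/* Same logic as the tnum_union() added by this patch. */
static struct tnum tnum_union(struct tnum a, struct tnum b)
{
	u64 v = a.value & b.value;
	u64 mu = (a.value ^ b.value) | a.mask | b.mask;

	return TNUM(v & ~mu, mu);
}

/* Pre-patch loop: separate value-sum and mask-sum accumulators. */
static struct tnum tnum_mul_old(struct tnum a, struct tnum b)
{
	u64 acc_v = a.value * b.value;
	struct tnum acc_m = TNUM(0, 0);

	while (a.value || a.mask) {
		if (a.value & 1)
			acc_m = tnum_add(acc_m, TNUM(0, b.mask));
		else if (a.mask & 1)
			acc_m = tnum_add(acc_m, TNUM(0, b.value | b.mask));
		a = TNUM(a.value >> 1, a.mask >> 1);
		b = TNUM(b.value << 1, b.mask << 1);
	}
	return tnum_add(TNUM(acc_v, 0), acc_m);
}

/* Post-patch loop: one accumulator, union of both cases on an unknown bit. */
static struct tnum tnum_mul_new(struct tnum a, struct tnum b)
{
	struct tnum acc = TNUM(0, 0);

	while (a.value || a.mask) {
		if (a.value & 1)
			acc = tnum_add(acc, b);
		else if (a.mask & 1)
			acc = tnum_union(acc, tnum_add(acc, b));
		a = TNUM(a.value >> 1, a.mask >> 1);
		b = TNUM(b.value << 1, b.mask << 1);
	}
	return acc;
}

/* Hypothetical self-check: a is "x1" (1 or 3), b is the constant 3. */
void tnum_mul_example(void)
{
	struct tnum a = TNUM(1, 2), b = TNUM(3, 0);
	struct tnum o = tnum_mul_old(a, b);
	struct tnum n = tnum_mul_new(a, b);

	/* Both results must still cover the concrete products 3 and 9. */
	assert((3 & ~o.mask) == o.value && (9 & ~o.mask) == o.value);
	assert((3 & ~n.mask) == n.value && (9 & ~n.mask) == n.value);
	/* The new mask has fewer unknown bits and is a subset of the old. */
	assert((n.mask & ~o.mask) == 0 && n.mask != o.mask);
}
```

For a = x1 and b = 3, the concrete products are 3 and 9. The old loop yields value 1 with mask 14 (xxx1, eight candidates), while the new loop yields value 1 with mask 10 (x0x1, four candidates). A single input pair only illustrates the effect; it says nothing about whether the new result is at least as precise for every input.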