diff --git a/include/linux/tnum.h b/include/linux/tnum.h index c52b862dad45b..63987f442b4a9 100644 --- a/include/linux/tnum.h +++ b/include/linux/tnum.h @@ -125,5 +125,6 @@ static inline bool tnum_subreg_is_const(struct tnum a) { return !(tnum_subreg(a)).mask; } - +/* Returns the smallest member of t larger than z. */ +u64 tnum_step(struct tnum t, u64 z); #endif /* _LINUX_TNUM_H */ diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index f8e70e9c3998d..5c12d7d9ba225 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -253,3 +253,55 @@ struct tnum tnum_const_subreg(struct tnum a, u32 value) { return tnum_with_subreg(a, tnum_const(value)); } + +/* Given tnum t, and a number z such that tmin <= z < tmax, where tmin + * is the smallest member of the t (= t.value) and tmax is the largest + * member of t (= t.value | t.mask) returns the smallest member of t + * larger than z. + * + * For example, + * t = x11100x0 + * z = 11110001 (241) + * result = 11110010 (242) + * + * Note: if this function is called with z >= tmax, it just returns + * early with tmax; if this function is called with z < tmin, the + * algorithm already returns tmin. + */ +u64 tnum_step(struct tnum t, u64 z) +{ + u64 tmax, j, p, q, r, s, v, u, w, res; + u8 k; + + tmax = t.value | t.mask; + + /* if z >= largest member of t, return largest member of t */ + if (z >= tmax) + return tmax; + + /* keep t's known bits, and match all unknown bits to z */ + j = t.value | z & t.mask; + + if (j > z) { + p = ~z & t.value & ~t.mask; + k = fls64(p); /* k is the most-significant 0-to-1 flip */ + q = U64_MAX << k; + r = q & z; /* positions > k matched to z */ + s = ~q & t.value; /* positions <= k matched to t.value */ + v = r | s; + res = v; + } else { + p = z & ~t.value & ~t.mask; + k = fls64(p); /* k is the most-significant 1-to-0 flip */ + q = U64_MAX << k; + r = q & t.mask & z; /* unknown positions > k, matched to z */ + s = q & ~t.mask; /* known positions > k, set to 1 */ + v = r | s; + /* add 1 to unknown positions > k to make value greater than z */ + u = v + (1ULL << k); + /* extract bits in unknown positions > k from u, rest from t.value */ + w = u & t.mask | t.value; + res = w; + } + return res; +}