Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/abi/linux/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ func MakeSignalSet(sigs ...Signal) SignalSet {
for i, sig := range sigs {
indices[i] = sig.Index()
}
return SignalSet(bits.Mask64(indices...))
return bits.Mask[SignalSet](indices...)
}

// SignalSetOf returns a SignalSet with a single signal set.
func SignalSetOf(sig Signal) SignalSet {
return SignalSet(bits.MaskOf64(sig.Index()))
return bits.MaskOf[SignalSet](sig.Index())
}

// ForEachSignal invokes f for each signal set in the given mask.
Expand Down
35 changes: 3 additions & 32 deletions pkg/bits/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")

package(
default_applicable_licenses = ["//:license"],
Expand All @@ -10,46 +9,18 @@ go_library(
name = "bits",
srcs = [
"bits.go",
"bits32.go",
"bits64.go",
"bits_template.go",
"uint64_arch.go",
"uint64_arch_amd64_asm.s",
"uint64_arch_arm64_asm.s",
"uint64_arch_generic.go",
],
visibility = ["//:sandbox"],
)

go_template(
name = "bits_template",
srcs = ["bits_template.go"],
types = [
"T",
deps = [
"@org_golang_x_exp//constraints:go_default_library",
],
)

go_template_instance(
name = "bits64",
out = "bits64.go",
package = "bits",
suffix = "64",
template = ":bits_template",
types = {
"T": "uint64",
},
)

go_template_instance(
name = "bits32",
out = "bits32.go",
package = "bits",
suffix = "32",
template = ":bits_template",
types = {
"T": "uint32",
},
)

go_test(
name = "bits_test",
size = "small",
Expand Down
17 changes: 8 additions & 9 deletions pkg/bits/bits_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,36 @@

package bits

// Non-atomic bit operations on a template type T.
import "golang.org/x/exp/constraints"

// T is a required type parameter that must be an integral type.
type T uint64
// Non-atomic bit operations on integral types.

// IsOn returns true if *all* bits set in 'bits' are set in 'mask'.
func IsOn(mask, bits T) bool {
func IsOn[T constraints.Integer](mask, bits T) bool {
return mask&bits == bits
}

// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'.
func IsAnyOn(mask, bits T) bool {
func IsAnyOn[T constraints.Integer](mask, bits T) bool {
return mask&bits != 0
}

// Mask returns a T with all of the given bits set.
func Mask(is ...int) T {
func Mask[T constraints.Integer](is ...int) T {
ret := T(0)
for _, i := range is {
ret |= MaskOf(i)
ret |= MaskOf[T](i)
}
return ret
}

// MaskOf is like Mask, but sets only a single bit (more efficiently).
func MaskOf(i int) T {
func MaskOf[T constraints.Integer](i int) T {
return T(1) << T(i)
}

// IsPowerOfTwo returns true if v is power of 2.
func IsPowerOfTwo(v T) bool {
func IsPowerOfTwo[T constraints.Integer](v T) bool {
if v == 0 {
return false
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/bits/uint64_arch.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ func ForEachSetBit64(x uint64, f func(i int)) {
for x != 0 {
i := TrailingZeros64(x)
f(i)
x &^= MaskOf64(i)
x &^= MaskOf[uint64](i)
}
}
11 changes: 8 additions & 3 deletions pkg/bits/uint64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ func TestMostSignificantOne64(t *testing.T) {
}
}

// Mask64 returns a uint64 with all of the given bits set.
func Mask64(is ...int) uint64 {
return Mask[uint64](is...)
}

func TestForEachSetBit64(t *testing.T) {
for _, want := range [][]int{
{},
Expand Down Expand Up @@ -106,10 +111,10 @@ func TestIsOn(t *testing.T) {
{Mask64(1, 63), Mask64(0, 1, 63), true, false},
{Mask64(1, 63), Mask64(0, 62), false, false},
} {
if ok := IsAnyOn64(s.mask, s.bits); ok != s.any {
if ok := IsAnyOn(s.mask, s.bits); ok != s.any {
t.Errorf("IsAnyOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.any)
}
if ok := IsOn64(s.mask, s.bits); ok != s.all {
if ok := IsOn(s.mask, s.bits); ok != s.all {
t.Errorf("IsOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.all)
}
}
Expand All @@ -127,7 +132,7 @@ func TestIsPowerOfTwo(t *testing.T) {
{v: 4, want: true},
{v: 5, want: false},
} {
if got := IsPowerOfTwo64(tc.v); got != tc.want {
if got := IsPowerOfTwo(tc.v); got != tc.want {
t.Errorf("IsPowerOfTwo(%d) = %t, want: %t", tc.v, got, tc.want)
}
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/sentry/kernel/auth/capability_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ var AllCapabilities = CapabilitySetOf(linux.CAP_LAST_CAP+1) - 1
// CapabilitySetOf returns a CapabilitySet containing only the given
// capability.
func CapabilitySetOf(cp linux.Capability) CapabilitySet {
return CapabilitySet(bits.MaskOf64(int(cp)))
return bits.MaskOf[CapabilitySet](int(cp))
}

// CapabilitySetOfMany returns a CapabilitySet containing the given capabilities.
func CapabilitySetOfMany(cps []linux.Capability) CapabilitySet {
var cs uint64
var cs CapabilitySet
for _, cp := range cps {
cs |= bits.MaskOf64(int(cp))
cs |= bits.MaskOf[CapabilitySet](int(cp))
}
return CapabilitySet(cs)
return cs
}

// Add adds the given capability to the CapabilitySet.
Expand Down
6 changes: 3 additions & 3 deletions pkg/sentry/kernel/syscalls.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (e *SyscallFlagsTable) UpdateSecCheck(state *seccheck.State) {
defer e.mu.Unlock()
for sysno := uintptr(0); sysno <= sentry.MaxSyscallNum; sysno++ {
oldFlags := e.enable[sysno].Load()
if !bits.IsOn32(oldFlags, syscallPresent) {
if !bits.IsOn(oldFlags, syscallPresent) {
continue
}
flags := oldFlags
Expand Down Expand Up @@ -226,7 +226,7 @@ func (e *SyscallFlagsTable) Enable(bit uint32, s map[uintptr]bool, missingEnable

for num := range e.enable {
val := e.enable[num].Load()
if !bits.IsOn32(val, syscallPresent) {
if !bits.IsOn(val, syscallPresent) {
// Missing.
e.enable[num].Store(missingVal)
continue
Expand All @@ -252,7 +252,7 @@ func (e *SyscallFlagsTable) EnableAll(bit uint32) {

for num := range e.enable {
val := e.enable[num].Load()
if !bits.IsOn32(val, syscallPresent) {
if !bits.IsOn(val, syscallPresent) {
// Missing.
e.enable[num].Store(missingVal)
continue
Expand Down
16 changes: 8 additions & 8 deletions pkg/sentry/kernel/task_syscall.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
fe := s.FeatureEnable.Word(sysno)

var straceContext any
if bits.IsAnyOn32(fe, StraceEnableBits) {
if bits.IsAnyOn(fe, StraceEnableBits) {
straceContext = s.Stracer.SyscallEnter(t, sysno, args, fe)
}

if bits.IsAnyOn32(fe, SecCheckRawEnter) {
if bits.IsAnyOn(fe, SecCheckRawEnter) {
info := pb.Syscall{
Sysno: uint64(sysno),
Arg1: args[0].Uint64(),
Expand All @@ -110,7 +110,7 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
return c.RawSyscall(t, fields, &info)
})
}
if bits.IsAnyOn32(fe, SecCheckEnter) {
if bits.IsAnyOn(fe, SecCheckEnter) {
fields := seccheck.Global.GetFieldSet(seccheck.GetPointForSyscall(seccheck.SyscallEnter, sysno))
var ctxData *pb.ContextData
if !fields.Context.Empty() {
Expand All @@ -128,7 +128,7 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
})
}

if bits.IsOn32(fe, ExternalBeforeEnable) && (s.ExternalFilterBefore == nil || s.ExternalFilterBefore(t, sysno, args)) {
if bits.IsOn(fe, ExternalBeforeEnable) && (s.ExternalFilterBefore == nil || s.ExternalFilterBefore(t, sysno, args)) {
t.invokeExternal()
// Ensure we check for stops, then invoke the syscall again.
ctrl = ctrlStopAndReinvokeSyscall
Expand All @@ -150,16 +150,16 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
}
}

if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
if bits.IsOn(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
t.invokeExternal()
// Don't reinvoke the unix.
}

if bits.IsAnyOn32(fe, StraceEnableBits) {
if bits.IsAnyOn(fe, StraceEnableBits) {
s.Stracer.SyscallExit(straceContext, t, sysno, rval, err)
}

if bits.IsAnyOn32(fe, SecCheckRawExit) {
if bits.IsAnyOn(fe, SecCheckRawExit) {
info := pb.Syscall{
Sysno: uint64(sysno),
Arg1: args[0].Uint64(),
Expand All @@ -182,7 +182,7 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
return c.RawSyscall(t, fields, &info)
})
}
if bits.IsAnyOn32(fe, SecCheckExit) {
if bits.IsAnyOn(fe, SecCheckExit) {
fields := seccheck.Global.GetFieldSet(seccheck.GetPointForSyscall(seccheck.SyscallExit, sysno))
var ctxData *pb.ContextData
if !fields.Context.Empty() {
Expand Down
12 changes: 6 additions & 6 deletions pkg/sentry/strace/strace.go
Original file line number Diff line number Diff line change
Expand Up @@ -682,10 +682,10 @@ func (s SyscallMap) SyscallEnter(t *kernel.Task, sysno uintptr, args arch.Syscal
}

var output, eventOutput []string
if bits.IsOn32(flags, kernel.StraceEnableLog) {
if bits.IsOn(flags, kernel.StraceEnableLog) {
output = info.printEnter(t, args)
}
if bits.IsOn32(flags, kernel.StraceEnableEvent) {
if bits.IsOn(flags, kernel.StraceEnableEvent) {
eventOutput = info.sendEnter(t, args)
}

Expand All @@ -706,10 +706,10 @@ func (s SyscallMap) SyscallExit(context any, t *kernel.Task, sysno, rval uintptr
c := context.(*syscallContext)

elapsed := time.Since(c.start)
if bits.IsOn32(c.flags, kernel.StraceEnableLog) {
if bits.IsOn(c.flags, kernel.StraceEnableLog) {
c.info.printExit(t, elapsed, c.logOutput, c.args, rval, err, errno)
}
if bits.IsOn32(c.flags, kernel.StraceEnableEvent) {
if bits.IsOn(c.flags, kernel.StraceEnableEvent) {
c.info.sendExit(t, elapsed, c.eventOutput, c.args, rval, err, errno)
}
}
Expand Down Expand Up @@ -786,10 +786,10 @@ const (

func convertToSyscallFlag(sinks SinkType) uint32 {
ret := uint32(0)
if bits.IsOn32(uint32(sinks), uint32(SinkTypeLog)) {
if bits.IsOn(uint32(sinks), uint32(SinkTypeLog)) {
ret |= kernel.StraceEnableLog
}
if bits.IsOn32(uint32(sinks), uint32(SinkTypeEvent)) {
if bits.IsOn(uint32(sinks), uint32(SinkTypeEvent)) {
ret |= kernel.StraceEnableEvent
}
return ret
Expand Down
2 changes: 1 addition & 1 deletion pkg/sentry/syscalls/linux/sys_stat.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func Statx(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr,
}
// Make sure that only one sync type option is set.
syncType := uint32(flags & linux.AT_STATX_SYNC_TYPE)
if syncType != 0 && !bits.IsPowerOfTwo32(syncType) {
if syncType != 0 && !bits.IsPowerOfTwo(syncType) {
return 0, nil, linuxerr.EINVAL
}
if mask&linux.STATX__RESERVED != 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sentry/vfs/propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (vfs *VirtualFilesystem) SetMountPropagationAt(ctx context.Context, creds *
recursive := propFlag&linux.MS_REC != 0
propFlag &= propagationFlags
// Check if flags is a power of 2. If not then more than one flag is set.
if !bits.IsPowerOfTwo32(propFlag) {
if !bits.IsPowerOfTwo(propFlag) {
return linuxerr.EINVAL
}
vd, err := vfs.getMountpoint(ctx, creds, pop)
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ deps_test(

# Other deps.
"@com_github_google_btree//:go_default_library",
"@org_golang_x_exp//constraints:go_default_library",
"@org_golang_x_sys//cpu:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
"@org_golang_x_time//rate:go_default_library",
Expand Down
Loading