Skip to content

Commit 297bbdb

Browse files
tamirdgvisor-bot
authored andcommitted
bits: use normal Go generics
FUTURE_COPYBARA_INTEGRATE_REVIEW=#12482 from tamird:bits-generics 23c4f2e PiperOrigin-RevId: 855373016
1 parent e0a2f60 commit 297bbdb

File tree

12 files changed

+46
-70
lines changed

12 files changed

+46
-70
lines changed

pkg/abi/linux/signal.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,12 @@ func MakeSignalSet(sigs ...Signal) SignalSet {
130130
for i, sig := range sigs {
131131
indices[i] = sig.Index()
132132
}
133-
return SignalSet(bits.Mask64(indices...))
133+
return bits.Mask[SignalSet](indices...)
134134
}
135135

136136
// SignalSetOf returns a SignalSet with a single signal set.
137137
func SignalSetOf(sig Signal) SignalSet {
138-
return SignalSet(bits.MaskOf64(sig.Index()))
138+
return bits.MaskOf[SignalSet](sig.Index())
139139
}
140140

141141
// ForEachSignal invokes f for each signal set in the given mask.

pkg/bits/BUILD

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
load("//tools:defs.bzl", "go_library", "go_test")
2-
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
32

43
package(
54
default_applicable_licenses = ["//:license"],
@@ -10,46 +9,18 @@ go_library(
109
name = "bits",
1110
srcs = [
1211
"bits.go",
13-
"bits32.go",
14-
"bits64.go",
12+
"bits_template.go",
1513
"uint64_arch.go",
1614
"uint64_arch_amd64_asm.s",
1715
"uint64_arch_arm64_asm.s",
1816
"uint64_arch_generic.go",
1917
],
2018
visibility = ["//:sandbox"],
21-
)
22-
23-
go_template(
24-
name = "bits_template",
25-
srcs = ["bits_template.go"],
26-
types = [
27-
"T",
19+
deps = [
20+
"@org_golang_x_exp//constraints:go_default_library",
2821
],
2922
)
3023

31-
go_template_instance(
32-
name = "bits64",
33-
out = "bits64.go",
34-
package = "bits",
35-
suffix = "64",
36-
template = ":bits_template",
37-
types = {
38-
"T": "uint64",
39-
},
40-
)
41-
42-
go_template_instance(
43-
name = "bits32",
44-
out = "bits32.go",
45-
package = "bits",
46-
suffix = "32",
47-
template = ":bits_template",
48-
types = {
49-
"T": "uint32",
50-
},
51-
)
52-
5324
go_test(
5425
name = "bits_test",
5526
size = "small",

pkg/bits/bits_template.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,36 @@
1414

1515
package bits
1616

17-
// Non-atomic bit operations on a template type T.
17+
import "golang.org/x/exp/constraints"
1818

19-
// T is a required type parameter that must be an integral type.
20-
type T uint64
19+
// Non-atomic bit operations on integral types.
2120

2221
// IsOn returns true if *all* bits set in 'bits' are set in 'mask'.
23-
func IsOn(mask, bits T) bool {
22+
func IsOn[T constraints.Integer](mask, bits T) bool {
2423
return mask&bits == bits
2524
}
2625

2726
// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'.
28-
func IsAnyOn(mask, bits T) bool {
27+
func IsAnyOn[T constraints.Integer](mask, bits T) bool {
2928
return mask&bits != 0
3029
}
3130

3231
// Mask returns a T with all of the given bits set.
33-
func Mask(is ...int) T {
32+
func Mask[T constraints.Integer](is ...int) T {
3433
ret := T(0)
3534
for _, i := range is {
36-
ret |= MaskOf(i)
35+
ret |= MaskOf[T](i)
3736
}
3837
return ret
3938
}
4039

4140
// MaskOf is like Mask, but sets only a single bit (more efficiently).
42-
func MaskOf(i int) T {
41+
func MaskOf[T constraints.Integer](i int) T {
4342
return T(1) << T(i)
4443
}
4544

4645
// IsPowerOfTwo returns true if v is power of 2.
47-
func IsPowerOfTwo(v T) bool {
46+
func IsPowerOfTwo[T constraints.Integer](v T) bool {
4847
if v == 0 {
4948
return false
5049
}

pkg/bits/uint64_arch.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ func ForEachSetBit64(x uint64, f func(i int)) {
3232
for x != 0 {
3333
i := TrailingZeros64(x)
3434
f(i)
35-
x &^= MaskOf64(i)
35+
x &^= MaskOf[uint64](i)
3636
}
3737
}

pkg/bits/uint64_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ func TestMostSignificantOne64(t *testing.T) {
6565
}
6666
}
6767

68+
// Mask64 returns a uint64 with all of the given bits set.
69+
func Mask64(is ...int) uint64 {
70+
return Mask[uint64](is...)
71+
}
72+
6873
func TestForEachSetBit64(t *testing.T) {
6974
for _, want := range [][]int{
7075
{},
@@ -106,10 +111,10 @@ func TestIsOn(t *testing.T) {
106111
{Mask64(1, 63), Mask64(0, 1, 63), true, false},
107112
{Mask64(1, 63), Mask64(0, 62), false, false},
108113
} {
109-
if ok := IsAnyOn64(s.mask, s.bits); ok != s.any {
114+
if ok := IsAnyOn(s.mask, s.bits); ok != s.any {
110115
t.Errorf("IsAnyOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.any)
111116
}
112-
if ok := IsOn64(s.mask, s.bits); ok != s.all {
117+
if ok := IsOn(s.mask, s.bits); ok != s.all {
113118
t.Errorf("IsOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.all)
114119
}
115120
}
@@ -127,7 +132,7 @@ func TestIsPowerOfTwo(t *testing.T) {
127132
{v: 4, want: true},
128133
{v: 5, want: false},
129134
} {
130-
if got := IsPowerOfTwo64(tc.v); got != tc.want {
135+
if got := IsPowerOfTwo(tc.v); got != tc.want {
131136
t.Errorf("IsPowerOfTwo(%d) = %t, want: %t", tc.v, got, tc.want)
132137
}
133138
}

pkg/sentry/kernel/auth/capability_set.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ var AllCapabilities = CapabilitySetOf(linux.CAP_LAST_CAP+1) - 1
3131
// CapabilitySetOf returns a CapabilitySet containing only the given
3232
// capability.
3333
func CapabilitySetOf(cp linux.Capability) CapabilitySet {
34-
return CapabilitySet(bits.MaskOf64(int(cp)))
34+
return bits.MaskOf[CapabilitySet](int(cp))
3535
}
3636

3737
// CapabilitySetOfMany returns a CapabilitySet containing the given capabilities.
3838
func CapabilitySetOfMany(cps []linux.Capability) CapabilitySet {
39-
var cs uint64
39+
var cs CapabilitySet
4040
for _, cp := range cps {
41-
cs |= bits.MaskOf64(int(cp))
41+
cs |= bits.MaskOf[CapabilitySet](int(cp))
4242
}
43-
return CapabilitySet(cs)
43+
return cs
4444
}
4545

4646
// Add adds the given capability to the CapabilitySet.

pkg/sentry/kernel/syscalls.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func (e *SyscallFlagsTable) UpdateSecCheck(state *seccheck.State) {
165165
defer e.mu.Unlock()
166166
for sysno := uintptr(0); sysno <= sentry.MaxSyscallNum; sysno++ {
167167
oldFlags := e.enable[sysno].Load()
168-
if !bits.IsOn32(oldFlags, syscallPresent) {
168+
if !bits.IsOn(oldFlags, syscallPresent) {
169169
continue
170170
}
171171
flags := oldFlags
@@ -226,7 +226,7 @@ func (e *SyscallFlagsTable) Enable(bit uint32, s map[uintptr]bool, missingEnable
226226

227227
for num := range e.enable {
228228
val := e.enable[num].Load()
229-
if !bits.IsOn32(val, syscallPresent) {
229+
if !bits.IsOn(val, syscallPresent) {
230230
// Missing.
231231
e.enable[num].Store(missingVal)
232232
continue
@@ -252,7 +252,7 @@ func (e *SyscallFlagsTable) EnableAll(bit uint32) {
252252

253253
for num := range e.enable {
254254
val := e.enable[num].Load()
255-
if !bits.IsOn32(val, syscallPresent) {
255+
if !bits.IsOn(val, syscallPresent) {
256256
// Missing.
257257
e.enable[num].Store(missingVal)
258258
continue

pkg/sentry/kernel/task_syscall.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
8787
fe := s.FeatureEnable.Word(sysno)
8888

8989
var straceContext any
90-
if bits.IsAnyOn32(fe, StraceEnableBits) {
90+
if bits.IsAnyOn(fe, StraceEnableBits) {
9191
straceContext = s.Stracer.SyscallEnter(t, sysno, args, fe)
9292
}
9393

94-
if bits.IsAnyOn32(fe, SecCheckRawEnter) {
94+
if bits.IsAnyOn(fe, SecCheckRawEnter) {
9595
info := pb.Syscall{
9696
Sysno: uint64(sysno),
9797
Arg1: args[0].Uint64(),
@@ -110,7 +110,7 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
110110
return c.RawSyscall(t, fields, &info)
111111
})
112112
}
113-
if bits.IsAnyOn32(fe, SecCheckEnter) {
113+
if bits.IsAnyOn(fe, SecCheckEnter) {
114114
fields := seccheck.Global.GetFieldSet(seccheck.GetPointForSyscall(seccheck.SyscallEnter, sysno))
115115
var ctxData *pb.ContextData
116116
if !fields.Context.Empty() {
@@ -128,7 +128,7 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
128128
})
129129
}
130130

131-
if bits.IsOn32(fe, ExternalBeforeEnable) && (s.ExternalFilterBefore == nil || s.ExternalFilterBefore(t, sysno, args)) {
131+
if bits.IsOn(fe, ExternalBeforeEnable) && (s.ExternalFilterBefore == nil || s.ExternalFilterBefore(t, sysno, args)) {
132132
t.invokeExternal()
133133
// Ensure we check for stops, then invoke the syscall again.
134134
ctrl = ctrlStopAndReinvokeSyscall
@@ -150,16 +150,16 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
150150
}
151151
}
152152

153-
if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
153+
if bits.IsOn(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
154154
t.invokeExternal()
155155
// Don't reinvoke the unix.
156156
}
157157

158-
if bits.IsAnyOn32(fe, StraceEnableBits) {
158+
if bits.IsAnyOn(fe, StraceEnableBits) {
159159
s.Stracer.SyscallExit(straceContext, t, sysno, rval, err)
160160
}
161161

162-
if bits.IsAnyOn32(fe, SecCheckRawExit) {
162+
if bits.IsAnyOn(fe, SecCheckRawExit) {
163163
info := pb.Syscall{
164164
Sysno: uint64(sysno),
165165
Arg1: args[0].Uint64(),
@@ -182,7 +182,7 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
182182
return c.RawSyscall(t, fields, &info)
183183
})
184184
}
185-
if bits.IsAnyOn32(fe, SecCheckExit) {
185+
if bits.IsAnyOn(fe, SecCheckExit) {
186186
fields := seccheck.Global.GetFieldSet(seccheck.GetPointForSyscall(seccheck.SyscallExit, sysno))
187187
var ctxData *pb.ContextData
188188
if !fields.Context.Empty() {

pkg/sentry/strace/strace.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -682,10 +682,10 @@ func (s SyscallMap) SyscallEnter(t *kernel.Task, sysno uintptr, args arch.Syscal
682682
}
683683

684684
var output, eventOutput []string
685-
if bits.IsOn32(flags, kernel.StraceEnableLog) {
685+
if bits.IsOn(flags, kernel.StraceEnableLog) {
686686
output = info.printEnter(t, args)
687687
}
688-
if bits.IsOn32(flags, kernel.StraceEnableEvent) {
688+
if bits.IsOn(flags, kernel.StraceEnableEvent) {
689689
eventOutput = info.sendEnter(t, args)
690690
}
691691

@@ -706,10 +706,10 @@ func (s SyscallMap) SyscallExit(context any, t *kernel.Task, sysno, rval uintptr
706706
c := context.(*syscallContext)
707707

708708
elapsed := time.Since(c.start)
709-
if bits.IsOn32(c.flags, kernel.StraceEnableLog) {
709+
if bits.IsOn(c.flags, kernel.StraceEnableLog) {
710710
c.info.printExit(t, elapsed, c.logOutput, c.args, rval, err, errno)
711711
}
712-
if bits.IsOn32(c.flags, kernel.StraceEnableEvent) {
712+
if bits.IsOn(c.flags, kernel.StraceEnableEvent) {
713713
c.info.sendExit(t, elapsed, c.eventOutput, c.args, rval, err, errno)
714714
}
715715
}
@@ -786,10 +786,10 @@ const (
786786

787787
func convertToSyscallFlag(sinks SinkType) uint32 {
788788
ret := uint32(0)
789-
if bits.IsOn32(uint32(sinks), uint32(SinkTypeLog)) {
789+
if bits.IsOn(uint32(sinks), uint32(SinkTypeLog)) {
790790
ret |= kernel.StraceEnableLog
791791
}
792-
if bits.IsOn32(uint32(sinks), uint32(SinkTypeEvent)) {
792+
if bits.IsOn(uint32(sinks), uint32(SinkTypeEvent)) {
793793
ret |= kernel.StraceEnableEvent
794794
}
795795
return ret

pkg/sentry/syscalls/linux/sys_stat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func Statx(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr,
165165
}
166166
// Make sure that only one sync type option is set.
167167
syncType := uint32(flags & linux.AT_STATX_SYNC_TYPE)
168-
if syncType != 0 && !bits.IsPowerOfTwo32(syncType) {
168+
if syncType != 0 && !bits.IsPowerOfTwo(syncType) {
169169
return 0, nil, linuxerr.EINVAL
170170
}
171171
if mask&linux.STATX__RESERVED != 0 {

0 commit comments

Comments
 (0)