Skip to content

Commit 5b40a31

Browse files
committed
feat: add x/runtime AutoGOMAXPROCs, PrintStackTrace, SetCurrentUser; add singleflight
1 parent 765770d commit 5b40a31

File tree

7 files changed

+477
-0
lines changed

7 files changed

+477
-0
lines changed

x/runtime/maxprocs.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package runtime
2+
3+
import (
4+
"os"
5+
"runtime"
6+
"runtime/debug"
7+
"strconv"
8+
)
9+
10+
const MaxMemory = 512 << 20 // 512MB
11+
12+
func AutoGOMAXPROCS() (int, int) {
13+
// set the max proc = preCore / 2
14+
maxThreads := 1000
15+
if envStr, ok := os.LookupEnv("APP_MAXTHREADS"); ok {
16+
if v, err := strconv.Atoi(envStr); err == nil {
17+
maxThreads = v
18+
}
19+
}
20+
21+
debug.SetMaxThreads(maxThreads)
22+
debug.SetMemoryLimit(MaxMemory)
23+
return runtime.GOMAXPROCS(runtime.NumCPU()), maxThreads
24+
}

x/runtime/maxprocs_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package runtime_test
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/omalloc/contrib/x/runtime"
8+
)
9+
10+
func TestAutoGOMAXPROCS(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
envMaxThreads string
14+
wantMaxThreads int
15+
}{
16+
{
17+
name: "default_value",
18+
envMaxThreads: "",
19+
wantMaxThreads: 1000,
20+
},
21+
{
22+
name: "custom_max_threads",
23+
envMaxThreads: "500",
24+
wantMaxThreads: 500,
25+
},
26+
{
27+
name: "invalid_env_value",
28+
envMaxThreads: "invalid",
29+
wantMaxThreads: 1000,
30+
},
31+
}
32+
33+
for _, tt := range tests {
34+
t.Run(tt.name, func(t *testing.T) {
35+
if tt.envMaxThreads != "" {
36+
os.Setenv("APP_MAXTHREADS", tt.envMaxThreads)
37+
defer os.Unsetenv("APP_MAXTHREADS")
38+
}
39+
40+
_, gotMaxThreads := runtime.AutoGOMAXPROCS()
41+
if gotMaxThreads != tt.wantMaxThreads {
42+
t.Errorf("AutoGOMAXPROCS() maxThreads = %v, want %v", gotMaxThreads, tt.wantMaxThreads)
43+
}
44+
})
45+
}
46+
}

x/runtime/recovered.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package runtime
2+
3+
import (
4+
"fmt"
5+
"runtime"
6+
"strings"
7+
)
8+
9+
func PrintStackTrace(skip int) string {
10+
// Capture the stack trace
11+
pc := make([]uintptr, 10)
12+
n := runtime.Callers(skip, pc)
13+
frames := runtime.CallersFrames(pc[:n])
14+
15+
// Iterate over the frames and print them
16+
sb := strings.Builder{}
17+
for {
18+
frame, more := frames.Next()
19+
if strings.HasPrefix(frame.Function, "runtime.panic") || strings.HasPrefix(frame.Function, "runtime.gopanic") {
20+
sb.WriteString("panic: ")
21+
continue
22+
}
23+
sb.WriteString(fmt.Sprintf("%s\n\t%s:%d\n", frame.Function, frame.File, frame.Line))
24+
if !more {
25+
break
26+
}
27+
}
28+
29+
return sb.String()
30+
}

x/runtime/recovered_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package runtime_test
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/omalloc/contrib/x/runtime"
8+
)
9+
10+
func TestPrintStackTrace(t *testing.T) {
11+
// 测试基本堆栈跟踪输出
12+
trace := runtime.PrintStackTrace(1)
13+
14+
// 验证输出包含当前测试函数名
15+
if !strings.Contains(trace, "TestPrintStackTrace") {
16+
t.Errorf("堆栈跟踪中应该包含测试函数名 'TestPrintStackTrace',但实际输出为:\n%s", trace)
17+
}
18+
19+
// 验证输出包含文件名和行号
20+
if !strings.Contains(trace, "recovered_test.go") {
21+
t.Errorf("堆栈跟踪中应该包含测试文件名 'recovered_test.go',但实际输出为:\n%s", trace)
22+
}
23+
24+
// 验证输出格式
25+
lines := strings.Split(strings.TrimSpace(trace), "\n")
26+
if len(lines) < 2 {
27+
t.Errorf("堆栈跟踪应至少包含两行,但实际输出为:\n%s", trace)
28+
}
29+
30+
// 测试不同的 skip 值
31+
trace2 := runtime.PrintStackTrace(2)
32+
if trace == trace2 {
33+
t.Error("不同的 skip 值应产生不同的堆栈跟踪")
34+
}
35+
}
36+
37+
// 用于测试 panic 情况的辅助函数
38+
func causePanic() string {
39+
panic("测试 panic")
40+
}
41+
42+
func TestPrintStackTraceWithPanic(t *testing.T) {
43+
defer func() {
44+
if r := recover(); r != nil {
45+
trace := runtime.PrintStackTrace(1)
46+
47+
// 验证输出包含 panic 前缀
48+
if !strings.Contains(trace, "panic:") {
49+
t.Errorf("panic 堆栈跟踪应包含 'panic:' 前缀,但实际输出为:\n%s", trace)
50+
}
51+
52+
// 验证输出包含触发 panic 的函数名
53+
if !strings.Contains(trace, "causePanic") {
54+
t.Errorf("堆栈跟踪应包含 'causePanic' 函数,但实际输出为:\n%s", trace)
55+
}
56+
}
57+
}()
58+
59+
causePanic()
60+
}

x/runtime/user.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package runtime
2+
3+
import (
4+
"fmt"
5+
"os/user"
6+
"strconv"
7+
"syscall"
8+
)
9+
10+
func SetCurrentUser(username string) error {
11+
if username != "" {
12+
current, err := user.Current()
13+
if err != nil {
14+
return err
15+
}
16+
17+
if current.Username == username {
18+
return nil
19+
}
20+
21+
wantedUser, err := user.Lookup(username)
22+
if err != nil {
23+
return err
24+
}
25+
26+
uid, err := strconv.Atoi(wantedUser.Uid)
27+
if err != nil {
28+
return fmt.Errorf("error converting UID [%s] to int: %s", wantedUser.Uid, err)
29+
}
30+
31+
gid, err := strconv.Atoi(wantedUser.Gid)
32+
if err != nil {
33+
return fmt.Errorf("error converting GID [%s] to int: %s", wantedUser.Gid, err)
34+
}
35+
36+
if err = syscall.Setgid(gid); err != nil {
37+
return fmt.Errorf("setting group id: %s", err)
38+
}
39+
40+
if err = syscall.Setuid(uid); err != nil {
41+
return fmt.Errorf("setting user id: %s", err)
42+
}
43+
}
44+
45+
return nil
46+
}

x/runtime/user_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package runtime_test
2+
3+
import (
4+
"os/user"
5+
"testing"
6+
7+
"github.com/omalloc/contrib/x/runtime"
8+
)
9+
10+
func TestSetCurrentUser(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
username string
14+
wantErr bool
15+
}{
16+
{
17+
name: "空用户名应该直接返回nil",
18+
username: "",
19+
wantErr: false,
20+
},
21+
{
22+
name: "当前用户名相同应该返回nil",
23+
username: getCurrentUsername(t),
24+
wantErr: false,
25+
},
26+
{
27+
name: "不存在的用户名应该返回错误",
28+
username: "nonexistentuser123456789",
29+
wantErr: true,
30+
},
31+
{
32+
name: "切换用户为daemon",
33+
username: "daemon",
34+
wantErr: true,
35+
},
36+
}
37+
38+
for _, tt := range tests {
39+
t.Run(tt.name, func(t *testing.T) {
40+
err := runtime.SetCurrentUser(tt.username)
41+
if (err != nil) != tt.wantErr {
42+
t.Errorf("SetCurrentUser() error = %v, wantErr %v", err, tt.wantErr)
43+
}
44+
})
45+
}
46+
}
47+
48+
// 辅助函数:获取当前用户名
49+
func getCurrentUsername(t *testing.T) string {
50+
current, err := user.Current()
51+
if err != nil {
52+
t.Fatalf("无法获取当前用户: %v", err)
53+
}
54+
return current.Username
55+
}

0 commit comments

Comments
 (0)