diff --git a/main.go b/main.go index 06b09dd..efc1c41 100644 --- a/main.go +++ b/main.go @@ -25,7 +25,6 @@ import ( "time" humanize "github.com/dustin/go-humanize" - psutil "github.com/shirou/gopsutil/mem" ) var ( @@ -100,7 +99,7 @@ func main() { for { workQueue <- struct{}{} go func() { - cmd := exec.Command("memStress", "--size", memSize, "--workers", fmt.Sprintf("%d", workers), + cmd := exec.Command("./memStress", "--size", memSize, "--workers", fmt.Sprintf("%d", workers), "--time", growthTime, "--client", "1") cmd.SysProcAttr = &syscall.SysProcAttr{ Pdeathsig: syscall.SIGTERM, @@ -114,7 +113,11 @@ func main() { time.Sleep(time.Second) } } else { - memInfo, _ := psutil.VirtualMemory() + totalMem, err := getTotalMemory() + if err != nil { + fmt.Println("get total memory err:", err) + return + } var length uint64 if memSize[len(memSize)-1] != '%' { @@ -129,7 +132,7 @@ func main() { if err != nil { fmt.Println(err) } - length = uint64(float64(memInfo.Total) / 100.0 * percentage) + length = uint64(float64(totalMem) / 100.0 * percentage) } timeLine, err := time.ParseDuration(growthTime) diff --git a/util.go b/util.go new file mode 100644 index 0000000..c011770 --- /dev/null +++ b/util.go @@ -0,0 +1,71 @@ +package main + +import ( + "fmt" + "io/ioutil" + "strconv" + "strings" + + psutil "github.com/shirou/gopsutil/mem" +) + +const ( + cgroupV2Path = "/sys/fs/cgroup/memory.max" + cgroupV1Path = "/sys/fs/cgroup/memory/memory.limit_in_bytes" + cgroupNoLimitV1 = 0x7FFFFFFFFFFFF000 +) + +// Read file content and parse as uint64 +func readUintFromFile(path string) (uint64, error) { + data, err := ioutil.ReadFile(path) + if err != nil { + return 0, err + } + content := strings.TrimSpace(string(data)) + return strconv.ParseUint(content, 10, 64) +} + +// Check cgroup v2 memory limit +func getCgroupV2Limit() (uint64, error) { + data, err := ioutil.ReadFile(cgroupV2Path) + if err != nil { + return 0, err + } + content := strings.TrimSpace(string(data)) + if content == "max" { + return 0, fmt.Errorf("cgroup v2: no memory limit set") + } + limit, err := strconv.ParseUint(content, 10, 64) + if err != nil || limit == 0 { + return 0, fmt.Errorf("cgroup v2: invalid memory limit") + } + return limit, nil +} + +// Check cgroup v1 memory limit +func getCgroupV1Limit() (uint64, error) { + limit, err := readUintFromFile(cgroupV1Path) + if err != nil { + return 0, err + } + // 0 or cgroup's "infinity" value means no limit + if limit == 0 || limit >= cgroupNoLimitV1 { + return 0, fmt.Errorf("cgroup v1: no memory limit set") + } + return limit, nil +} + +// Get total memory, prefer cgroup v2 -> cgroup v1 -> host +func getTotalMemory() (uint64, error) { + if limit, err := getCgroupV2Limit(); err == nil { + return limit, nil + } + if limit, err := getCgroupV1Limit(); err == nil { + return limit, nil + } + mem, err := psutil.VirtualMemory() + if err != nil { + return 0, fmt.Errorf("failed to get system memory: %v", err) + } + return mem.Total, nil +}