Skip to content

Commit ad00cc1

Browse files
authored
feat(gpud): clean up --endpoint flag parsing, persist in systemd service file (#681)
Signed-off-by: Gyuho Lee <[email protected]> Signed-off-by: Gyuho Lee <[email protected]>
1 parent 14b8c5c commit ad00cc1

File tree

10 files changed

+276
-38
lines changed

10 files changed

+276
-38
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,5 @@ go.work.sum
2626
/bin/
2727

2828
.DS_Store
29+
30+
coverage.html

cmd/gpud/command/join.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"fmt"
1111
"io"
1212
"net/http"
13+
"net/url"
1314
"os"
1415
"os/exec"
1516
"path/filepath"
@@ -30,6 +31,7 @@ import (
3031
func cmdJoin(cliContext *cli.Context) (retErr error) {
3132
rootCtx, rootCancel := context.WithTimeout(context.Background(), 3*time.Minute)
3233
defer rootCancel()
34+
3335
endpoint := cliContext.String("endpoint")
3436
clusterName := cliContext.String("cluster-name")
3537
provider := cliContext.String("provider")
@@ -174,7 +176,7 @@ func cmdJoin(cliContext *cli.Context) (retErr error) {
174176
}
175177
}
176178
fmt.Println("Please wait while control plane is initializing basic setup for your machine, this may take up to one minute...")
177-
response, err := http.Post(fmt.Sprintf("%s/api/v1/join", endpoint), "application/json", bytes.NewBuffer(rawPayload))
179+
response, err := http.Post(createJoinURL(endpoint), "application/json", bytes.NewBuffer(rawPayload))
178180
if err != nil {
179181
return err
180182
}
@@ -291,3 +293,13 @@ func runCommand(ctx context.Context, script string, result *string) error {
291293
}
292294
return nil
293295
}
296+
297+
// createJoinURL creates a URL for the join endpoint
298+
func createJoinURL(endpoint string) string {
299+
host := endpoint
300+
url, _ := url.Parse(endpoint)
301+
if url.Host != "" {
302+
host = url.Host
303+
}
304+
return fmt.Sprintf("https://%s/api/v1/join", host)
305+
}

cmd/gpud/command/notify.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"net/url"
1011
"time"
1112

1213
"github.com/urfave/cli"
@@ -60,7 +61,7 @@ func notification(ctx context.Context, endpoint string, req payload) error {
6061
Status string `json:"status"`
6162
}
6263
rawPayload, _ := json.Marshal(&req)
63-
response, err := http.Post(fmt.Sprintf("https://%s/api/v1/notification", endpoint), "application/json", bytes.NewBuffer(rawPayload))
64+
response, err := http.Post(createNotificationURL(endpoint), "application/json", bytes.NewBuffer(rawPayload))
6465
if err != nil {
6566
return err
6667
}
@@ -79,3 +80,13 @@ func notification(ctx context.Context, endpoint string, req payload) error {
7980
}
8081
return nil
8182
}
83+
84+
// createNotificationURL creates a URL for the notification endpoint
85+
func createNotificationURL(endpoint string) string {
86+
host := endpoint
87+
url, _ := url.Parse(endpoint)
88+
if url.Host != "" {
89+
host = url.Host
90+
}
91+
return fmt.Sprintf("https://%s/api/v1/notification", host)
92+
}

cmd/gpud/command/up.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ func cmdUp(cliContext *cli.Context) (retErr error) {
4343
return fmt.Errorf("gpud binary not found at %s (you may run 'cp %s %s' to fix the installation)", systemd.DefaultBinPath, bin, systemd.DefaultBinPath)
4444
}
4545

46-
if err := systemdInit(); err != nil {
46+
endpoint := cliContext.String("endpoint")
47+
if err := systemdInit(endpoint); err != nil {
4748
fmt.Printf("%s failed to initialize systemd files\n", warningSign)
4849
return err
4950
}
@@ -62,8 +63,8 @@ func cmdUp(cliContext *cli.Context) (retErr error) {
6263
return nil
6364
}
6465

65-
func systemdInit() error {
66-
if err := systemd.CreateDefaultEnvFile(); err != nil {
66+
func systemdInit(endpoint string) error {
67+
if err := systemd.CreateDefaultEnvFile(endpoint); err != nil {
6768
return err
6869
}
6970
systemdUnitFileData := systemd.GPUDService

pkg/gossip/client.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"net/url"
1011
"os"
1112

1213
apiv1 "github.com/leptonai/gpud/api/v1"
@@ -65,5 +66,10 @@ func sendRequest(ctx context.Context, url string, req apiv1.GossipRequest) (*api
6566

6667
// createURL creates a URL for the gossip endpoint
6768
func createURL(endpoint string) string {
68-
return fmt.Sprintf("https://%s/api/v1/gossip", endpoint)
69+
host := endpoint
70+
url, _ := url.Parse(endpoint)
71+
if url.Host != "" {
72+
host = url.Host
73+
}
74+
return fmt.Sprintf("https://%s/api/v1/gossip", host)
6975
}

pkg/gossip/client_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ func TestCreateURL(t *testing.T) {
1919
endpoint string
2020
expected string
2121
}{
22+
{"https://example.com", "https://example.com/api/v1/gossip"},
2223
{"example.com", "https://example.com/api/v1/gossip"},
2324
{"api.leptonai.com", "https://api.leptonai.com/api/v1/gossip"},
2425
}

pkg/gpud-manager/systemd/systemd.go

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,32 @@ func DefaultBinExists() bool {
2727

2828
// CreateDefaultEnvFile creates the default environment file for gpud systemd service.
2929
// Assume systemdctl is already installed, and runs on the linux system.
30-
func CreateDefaultEnvFile() error {
31-
return writeEnvFile(DefaultEnvFile)
30+
func CreateDefaultEnvFile(endpoint string) error {
31+
return writeEnvFile(DefaultEnvFile, endpoint)
3232
}
3333

3434
const defaultEnvFileContent = `# gpud environment variables are set here
3535
FLAGS="--log-level=info --log-file=/var/log/gpud.log"
3636
`
3737

38-
func writeEnvFile(file string) error {
38+
func createDefaultEnvFileContent(endpoint string) string {
39+
if endpoint == "" {
40+
return defaultEnvFileContent
41+
}
42+
return fmt.Sprintf(`# gpud environment variables are set here
43+
FLAGS="--log-level=info --log-file=/var/log/gpud.log --endpoint=%s"
44+
`, endpoint)
45+
}
46+
47+
func writeEnvFile(file string, endpoint string) error {
3948
if _, err := os.Stat(file); err == nil {
40-
return addLogFileFlagIfExists(file)
49+
return updateFlagsFromExistingEnvFile(file, endpoint)
4150
}
42-
return atomicfile.WriteFile(file, []byte(defaultEnvFileContent), 0644)
51+
return atomicfile.WriteFile(file, []byte(createDefaultEnvFileContent(endpoint)), 0644)
4352
}
4453

45-
func addLogFileFlagIfExists(file string) error {
46-
lines, err := processEnvFileLines(file)
54+
func updateFlagsFromExistingEnvFile(file string, endpoint string) error {
55+
lines, err := processEnvFileLines(file, endpoint)
4756
if err != nil {
4857
return err
4958
}
@@ -52,7 +61,7 @@ func addLogFileFlagIfExists(file string) error {
5261

5362
// processEnvFileLines reads all lines from the environment file and processes each line,
5463
// adding the log-file flag to the FLAGS variable if it doesn't already exist.
55-
func processEnvFileLines(file string) ([]string, error) {
64+
func processEnvFileLines(file string, endpoint string) ([]string, error) {
5665
readFile, err := os.OpenFile(file, os.O_RDONLY, 0644)
5766
if err != nil {
5867
return nil, err
@@ -70,15 +79,22 @@ func processEnvFileLines(file string) ([]string, error) {
7079
continue
7180
}
7281

73-
// FLAGS already contains --log-file flag
74-
if strings.Contains(line, "--log-file=") {
82+
// FLAGS already contains --log-file flag and --endpoint flag
83+
if strings.Contains(line, "--log-file=") && (endpoint != "" && strings.Contains(line, "--endpoint=")) {
7584
lines = append(lines, line)
7685
continue
7786
}
7887

7988
// remove the trailing " character
8089
line = strings.TrimSuffix(line, "\"")
81-
line = fmt.Sprintf("%s --log-file=/var/log/gpud.log\"", line)
90+
91+
if !strings.Contains(line, "--log-file=") {
92+
line = fmt.Sprintf("%s --log-file=/var/log/gpud.log\"", line)
93+
}
94+
95+
if endpoint != "" && !strings.Contains(line, "--endpoint=") {
96+
line = fmt.Sprintf("%s --endpoint=%s\"", line, endpoint)
97+
}
8298

8399
lines = append(lines, line)
84100
}

0 commit comments

Comments
 (0)