Skip to content

Commit 8585e54

Browse files
committed
refactor: interpreter to use di and add coverage
1 parent c66c538 commit 8585e54

15 files changed

+641
-381
lines changed

cmd/agent_smith/main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"strings"
77

88
"github.com/RewstApp/agent-smith-go/internal/agent"
9+
"github.com/RewstApp/agent-smith-go/internal/interpreter"
910
"github.com/RewstApp/agent-smith-go/internal/utils"
1011
)
1112

@@ -36,6 +37,7 @@ func main() {
3637
// Create providers
3738
sys := agent.NewSystemInfoProvider()
3839
domain := agent.NewDomainInfoProvider()
40+
executor := interpreter.NewExecutor()
3941

4042
uninstallContext, err := newUninstallContext(os.Args[1:])
4143
if err == nil {
@@ -51,7 +53,7 @@ func main() {
5153
return
5254
}
5355

54-
serviceContext, err := newServiceContext(os.Args[1:], sys, domain)
56+
serviceContext, err := newServiceContext(os.Args[1:], sys, domain, executor)
5557
if err == nil {
5658
// Run service routine
5759
runService(serviceContext)

cmd/agent_smith/service.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ func (svc *serviceContext) Execute(stop <-chan struct{}, running chan<- struct{}
198198
notifier.Notify("AgentReceivedMessage:" + string(msg.Payload()))
199199

200200
// Execute the message
201-
resultBytes := message.Execute(ctx, device, logger, svc.Sys, svc.Domain)
201+
resultBytes := message.Execute(svc.Executor, ctx, device, logger, svc.Sys, svc.Domain)
202202

203203
// Skip if there is no post_id specified
204204
if message.PostId == "" {

cmd/agent_smith/service_context.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77

88
"github.com/RewstApp/agent-smith-go/internal/agent"
9+
"github.com/RewstApp/agent-smith-go/internal/interpreter"
910
)
1011

1112
type serviceContext struct {
@@ -15,9 +16,11 @@ type serviceContext struct {
1516

1617
Sys agent.SystemInfoProvider
1718
Domain agent.DomainInfoProvider
19+
20+
Executor interpreter.Executor
1821
}
1922

20-
func newServiceContext(args []string, sys agent.SystemInfoProvider, domain agent.DomainInfoProvider) (*serviceContext, error) {
23+
func newServiceContext(args []string, sys agent.SystemInfoProvider, domain agent.DomainInfoProvider, executor interpreter.Executor) (*serviceContext, error) {
2124
var params serviceContext
2225

2326
fs := flag.NewFlagSet("config", flag.ContinueOnError)
@@ -45,6 +48,7 @@ func newServiceContext(args []string, sys agent.SystemInfoProvider, domain agent
4548

4649
params.Sys = sys
4750
params.Domain = domain
51+
params.Executor = executor
4852

4953
return &params, nil
5054
}

cmd/agent_smith/service_context_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ func TestNewServiceContext(t *testing.T) {
1010
configFile := "/file/config"
1111
logFile := "/file/log"
1212

13-
result, _ := newServiceContext([]string{"--org-id", orgId, "--config-file", configFile, "--log-file", logFile}, nil, nil)
13+
result, _ := newServiceContext([]string{"--org-id", orgId, "--config-file", configFile, "--log-file", logFile}, nil, nil, nil)
1414

1515
if result.OrgId != orgId {
1616
t.Errorf("expected %v, got %v", orgId, result.OrgId)
@@ -43,7 +43,7 @@ func TestNewServiceContext(t *testing.T) {
4343
}
4444

4545
for _, errorTest := range errorTests {
46-
_, err := newServiceContext(errorTest.args, nil, nil)
46+
_, err := newServiceContext(errorTest.args, nil, nil, nil)
4747

4848
if err == nil || !strings.Contains(err.Error(), errorTest.message) {
4949
t.Errorf("expected error %s, got %v", errorTest.message, err.Error())
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,132 @@
1-
package interpreter
2-
3-
import (
4-
"bytes"
5-
"context"
6-
"encoding/base64"
7-
"fmt"
8-
"os"
9-
"os/exec"
10-
"strings"
11-
12-
"github.com/RewstApp/agent-smith-go/internal/agent"
13-
"github.com/RewstApp/agent-smith-go/internal/utils"
14-
"github.com/RewstApp/agent-smith-go/internal/version"
15-
"github.com/hashicorp/go-hclog"
16-
"golang.org/x/text/encoding/unicode"
17-
"golang.org/x/text/transform"
18-
)
19-
20-
const powershellVersionCheckCommand = "\"$($PSVersionTable.PSVersion.Major).$($PSVersionTable.PSVersion.Minor)\""
21-
22-
var utf8BOM = []byte{0xEF, 0xBB, 0xBF}
23-
24-
func executeUsingPowershell(ctx context.Context, message *Message, device agent.Device, logger hclog.Logger, usePwsh bool) []byte {
25-
// Parse the commands
26-
commandBytes, err := base64.StdEncoding.DecodeString(message.Commands)
27-
if err != nil {
28-
return errorResultBytes(err)
29-
}
30-
31-
// Decode using UTF16LE
32-
decoder := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder()
33-
commands, _, err := transform.String(decoder, string(commandBytes))
34-
if err != nil {
35-
return errorResultBytes(err)
36-
}
37-
38-
// Run the command in the system using powershell
39-
shell := "powershell"
40-
if usePwsh {
41-
shell = "pwsh"
42-
}
43-
44-
if logger.IsDebug() {
45-
cmd := exec.CommandContext(ctx, shell, "-Command", powershellVersionCheckCommand)
46-
combinedOutputBytes, err := cmd.CombinedOutput()
47-
combinedOutput := string(combinedOutputBytes)
48-
if err != nil {
49-
logger.Error("Shell version check failed", "error", err, "combined_output", combinedOutput)
50-
}
51-
52-
version := strings.TrimSpace(combinedOutput)
53-
54-
logger.Debug("Shell version", "shell", shell, "version", version)
55-
logger.Debug("Commands to execute", "commands", commands)
56-
}
57-
58-
if logger.IsDebug() {
59-
cmd := exec.CommandContext(ctx, "whoami")
60-
combinedOutputBytes, err := cmd.CombinedOutput()
61-
combinedOutput := string(combinedOutputBytes)
62-
if err != nil {
63-
logger.Error("Whoami check failed", "error", err, "combined_output", combinedOutput)
64-
}
65-
66-
logger.Debug("Whomai", "user", combinedOutput)
67-
}
68-
69-
// Save commands to temporary file
70-
scriptsDir := agent.GetScriptsDirectory(device.RewstOrgId)
71-
err = utils.CreateFolderIfMissing(scriptsDir)
72-
if err != nil {
73-
return errorResultBytes(err)
74-
}
75-
76-
tempfile, err := os.CreateTemp(scriptsDir, "exec-*.ps1")
77-
if err != nil {
78-
return errorResultBytes(err)
79-
}
80-
81-
_, err = tempfile.Write(utf8BOM)
82-
if err != nil {
83-
logger.Error("Failed to write BOM", "error", err)
84-
return errorResultBytes(err)
85-
}
86-
87-
_, err = tempfile.WriteString(commands)
88-
if err != nil {
89-
logger.Error("Failed to write command file", "error", err)
90-
return errorResultBytes(err)
91-
}
92-
93-
logger.Info("Command saved to", "message_id", message.PostId, "path", tempfile.Name())
94-
95-
// Close the temporary file
96-
tempfile.Close()
97-
98-
var stdoutBuf, stderrBuf bytes.Buffer
99-
cmd := exec.CommandContext(ctx, shell, "-File", tempfile.Name())
100-
cmd.Stdout = &stdoutBuf
101-
cmd.Stderr = &stderrBuf
102-
cmd.Env = os.Environ()
103-
cmd.Env = append(cmd.Env, fmt.Sprintf("AGENT_SMITH_VERSION=%s", version.Version[1:]))
104-
105-
err = cmd.Run()
106-
if err != nil {
107-
logger.Error("Command failed", "error", err)
108-
logger.Debug("Command completed with outputs", "error", stderrBuf.String(), "info", stdoutBuf.String())
109-
return resultBytes(&result{Error: stderrBuf.String(), Output: stdoutBuf.String()})
110-
}
111-
112-
// Remove successfully executed temporary filename
113-
defer os.Remove(tempfile.Name())
114-
115-
logger.Info("Command completed", "message_id", message.PostId, "exit_code", cmd.ProcessState.ExitCode())
116-
logger.Debug("Command completed with outputs", "error", stderrBuf.String(), "info", stdoutBuf.String())
117-
118-
return resultBytes(&result{Error: stderrBuf.String(), Output: stdoutBuf.String()})
119-
}
1+
package interpreter
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/base64"
7+
"fmt"
8+
"os"
9+
"os/exec"
10+
"strings"
11+
12+
"github.com/RewstApp/agent-smith-go/internal/agent"
13+
"github.com/RewstApp/agent-smith-go/internal/utils"
14+
"github.com/RewstApp/agent-smith-go/internal/version"
15+
"github.com/hashicorp/go-hclog"
16+
"golang.org/x/text/encoding/unicode"
17+
"golang.org/x/text/transform"
18+
)
19+
20+
var utf8BOM = []byte{0xEF, 0xBB, 0xBF}
21+
22+
type baseExecutor struct {
23+
Shell string
24+
ShellVersionCheckCommand string
25+
WriteUtf8BOM bool
26+
BuildExecuteCommandArgs BuildExecuteCommandArgsFunc
27+
BuildExecuteFileArgs BuildExecuteFileArgsFunc
28+
}
29+
30+
func (e *baseExecutor) Execute(ctx context.Context, message *Message, device agent.Device, logger hclog.Logger, sys agent.SystemInfoProvider, domain agent.DomainInfoProvider) []byte {
31+
// Parse the commands
32+
commandBytes, err := base64.StdEncoding.DecodeString(message.Commands)
33+
if err != nil {
34+
return errorResultBytes(err)
35+
}
36+
37+
// Decode using UTF16LE
38+
decoder := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder()
39+
commands, _, err := transform.String(decoder, string(commandBytes))
40+
if err != nil {
41+
return errorResultBytes(err)
42+
}
43+
44+
// Run the command in the system using powershell
45+
if logger.IsDebug() {
46+
cmd := exec.CommandContext(ctx, e.Shell, e.BuildExecuteCommandArgs(e.ShellVersionCheckCommand)...)
47+
combinedOutputBytes, err := cmd.CombinedOutput()
48+
combinedOutput := string(combinedOutputBytes)
49+
if err != nil {
50+
logger.Error("Shell version check failed", "error", err, "combined_output", combinedOutput)
51+
}
52+
53+
version := strings.TrimSpace(combinedOutput)
54+
55+
logger.Debug("Shell version", "shell", e.Shell, "version", version)
56+
logger.Debug("Commands to execute", "commands", commands)
57+
}
58+
59+
if logger.IsDebug() {
60+
cmd := exec.CommandContext(ctx, e.Shell, e.BuildExecuteCommandArgs("whoami")...)
61+
combinedOutputBytes, err := cmd.CombinedOutput()
62+
combinedOutput := string(combinedOutputBytes)
63+
if err != nil {
64+
logger.Error("Whoami check failed", "error", err, "combined_output", combinedOutput)
65+
}
66+
67+
logger.Debug("Whomai", "user", combinedOutput)
68+
}
69+
70+
// Save commands to temporary file
71+
scriptsDir := agent.GetScriptsDirectory(device.RewstOrgId)
72+
err = utils.CreateFolderIfMissing(scriptsDir)
73+
if err != nil {
74+
return errorResultBytes(err)
75+
}
76+
77+
tempfile, err := os.CreateTemp(scriptsDir, "exec-*.ps1")
78+
if err != nil {
79+
return errorResultBytes(err)
80+
}
81+
82+
if e.WriteUtf8BOM {
83+
_, err = tempfile.Write(utf8BOM)
84+
if err != nil {
85+
logger.Error("Failed to write BOM", "error", err)
86+
return errorResultBytes(err)
87+
}
88+
}
89+
90+
_, err = tempfile.WriteString(commands)
91+
if err != nil {
92+
logger.Error("Failed to write command file", "error", err)
93+
return errorResultBytes(err)
94+
}
95+
96+
logger.Info("Command saved to", "message_id", message.PostId, "path", tempfile.Name())
97+
98+
// Close the temporary file
99+
tempfile.Close()
100+
101+
var stdoutBuf, stderrBuf bytes.Buffer
102+
cmd := exec.CommandContext(ctx, e.Shell, e.BuildExecuteFileArgs(tempfile.Name())...)
103+
cmd.Stdout = &stdoutBuf
104+
cmd.Stderr = &stderrBuf
105+
cmd.Env = os.Environ()
106+
cmd.Env = append(cmd.Env, fmt.Sprintf("AGENT_SMITH_VERSION=%s", version.Version[1:]))
107+
108+
err = cmd.Run()
109+
if err != nil {
110+
logger.Error("Command failed", "error", err)
111+
logger.Debug("Command completed with outputs", "error", stderrBuf.String(), "info", stdoutBuf.String())
112+
return resultBytes(&result{Error: stderrBuf.String(), Output: stdoutBuf.String()})
113+
}
114+
115+
// Remove successfully executed temporary filename
116+
defer os.Remove(tempfile.Name())
117+
118+
logger.Info("Command completed", "message_id", message.PostId, "exit_code", cmd.ProcessState.ExitCode())
119+
logger.Debug("Command completed with outputs", "error", stderrBuf.String(), "info", stdoutBuf.String())
120+
121+
return resultBytes(&result{Error: stderrBuf.String(), Output: stdoutBuf.String()})
122+
}
123+
124+
func NewBaseExecutor(shell string, shellVersionCheckCommand string, writeUtf8BOM bool, buildExecuteCommandArgs BuildExecuteCommandArgsFunc, buildExecuteFileArgs BuildExecuteFileArgsFunc) Executor {
125+
return &baseExecutor{
126+
Shell: shell,
127+
ShellVersionCheckCommand: shellVersionCheckCommand,
128+
WriteUtf8BOM: writeUtf8BOM,
129+
BuildExecuteCommandArgs: buildExecuteCommandArgs,
130+
BuildExecuteFileArgs: buildExecuteFileArgs,
131+
}
132+
}

0 commit comments

Comments
 (0)