Skip to content

Commit ca60678

Browse files
committed
refactor round 2
1 parent f8f4a68 commit ca60678

File tree

2 files changed

+152
-104
lines changed

2 files changed

+152
-104
lines changed

jailfile.go

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -50,39 +50,25 @@ func parseJailFile(conf Config, f io.Reader) (JailFile, error) {
5050

5151
var jf JailFile
5252
for i := 1; scanner.Scan(); i++ {
53-
line := scanner.Text()
54-
line = strings.TrimSpace(line)
55-
if line == "" || strings.HasPrefix(line, "#") {
56-
continue
57-
}
58-
59-
if !strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "+") {
60-
return jf, NewJailFileParserErr(conf, i, line, rsnMissingPlusOrMinus)
61-
}
53+
originalLine := scanner.Text()
54+
trimmedLine := strings.TrimSpace(originalLine)
6255

63-
rule := strings.TrimSpace(line[1:])
64-
if rule == "" {
65-
return jf, NewJailFileParserErr(conf, i, line, rsnNoMatcher)
56+
if trimmedLine == "" || strings.HasPrefix(trimmedLine, "#") {
57+
continue // Skip empty lines and comments
6658
}
6759

68-
var err error
69-
var m Matcher
70-
if strings.HasPrefix(rule, "'") {
71-
m = NewLiteralMatcher(newMatcher(line, conf.JailFile, i), strings.TrimPrefix(rule, "'"))
72-
} else if strings.HasPrefix(rule, "r'") {
73-
m, err = NewRegexMatcher(newMatcher(line, conf.JailFile, i), strings.TrimPrefix(rule, "r'"))
74-
if err != nil {
75-
return jf, NewJailFileParserErr(conf, i, line, err.Error())
76-
}
77-
} else {
78-
m = NewCmdMatcher(newMatcher(line, conf.JailFile, i), rule, conf.ShellCmd)
60+
matcher, ruleType, err := parseRuleLine(trimmedLine, i, conf)
61+
if err != nil {
62+
// Pass originalLine for more accurate error reporting if needed,
63+
// though NewJailFileParserErr uses the (trimmed) line it receives.
64+
return JailFile{}, err
7965
}
8066

81-
switch line[0] {
67+
switch ruleType {
8268
case '+':
83-
jf.Allow = append(jf.Allow, m)
69+
jf.Allow = append(jf.Allow, matcher)
8470
case '-':
85-
jf.Deny = append(jf.Deny, m)
71+
jf.Deny = append(jf.Deny, matcher)
8672
}
8773
}
8874

@@ -97,6 +83,38 @@ func parseJailFile(conf Config, f io.Reader) (JailFile, error) {
9783
return jf, nil
9884
}
9985

86+
// parseRuleLine processes a single non-empty, non-comment line from the jail file.
87+
// It returns the parsed Matcher, the rule type ('+' or '-'), and any error.
88+
func parseRuleLine(line string, lineNumber int, conf Config) (Matcher, byte, error) {
89+
if !strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "+") {
90+
return nil, 0, NewJailFileParserErr(conf, lineNumber, line, rsnMissingPlusOrMinus)
91+
}
92+
93+
ruleType := line[0]
94+
ruleDefinition := strings.TrimSpace(line[1:])
95+
96+
if ruleDefinition == "" {
97+
return nil, 0, NewJailFileParserErr(conf, lineNumber, line, rsnNoMatcher)
98+
}
99+
100+
var m Matcher
101+
var err error
102+
baseMatcher := newMatcher(line, conf.JailFile, lineNumber) // Pass the full original line for Raw()
103+
104+
if strings.HasPrefix(ruleDefinition, "'") {
105+
m = NewLiteralMatcher(baseMatcher, strings.TrimPrefix(ruleDefinition, "'"))
106+
} else if strings.HasPrefix(ruleDefinition, "r'") {
107+
m, err = NewRegexMatcher(baseMatcher, strings.TrimPrefix(ruleDefinition, "r'"))
108+
if err != nil {
109+
return nil, 0, NewJailFileParserErr(conf, lineNumber, line, err.Error())
110+
}
111+
} else {
112+
m = NewCmdMatcher(baseMatcher, ruleDefinition, conf.ShellCmd)
113+
}
114+
115+
return m, ruleType, nil
116+
}
117+
100118
func errIsAny(target error, errs ...error) bool {
101119
for _, err := range errs {
102120
if errors.Is(target, err) {

main.go

Lines changed: 108 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -74,51 +74,63 @@ func runShell(conf Config, jailFile JailFile) int {
7474
scanner := bufio.NewScanner(os.Stdin)
7575
isRecordMode := conf.RecordFile != ""
7676

77-
fmt.Print("cmdjail> ")
78-
for scanner.Scan() {
77+
for {
78+
fmt.Print("cmdjail> ")
79+
if !scanner.Scan() {
80+
break // End of input (Ctrl+D) or scanner error
81+
}
82+
7983
line := strings.TrimSpace(scanner.Text())
8084
if line == "" {
81-
fmt.Print("cmdjail> ")
8285
continue
8386
}
8487

85-
if err := checkCmdSafety(line, conf.Log); err != nil {
86-
printLogErr(os.Stderr, "%s", err.Error())
87-
fmt.Print("cmdjail> ")
88-
continue
88+
shouldExit, exitCode := processShellCommand(line, conf, jailFile, isRecordMode)
89+
if shouldExit {
90+
return exitCode
8991
}
92+
}
9093

91-
if isRecordMode {
92-
if err := appendRuleToFile(conf.RecordFile, line); err != nil {
93-
printLogErr(os.Stderr, "appending to record file %s: %s", conf.RecordFile, err.Error())
94-
} else {
95-
printLogDebug(os.Stdout, "appended rule to %s: + '%s'", conf.RecordFile, line)
96-
}
94+
if err := scanner.Err(); err != nil {
95+
printLogErr(os.Stderr, "reading from stdin: %v", err)
96+
return 1
97+
}
9798

98-
if line == "exit" || line == "quit" {
99-
return 0
100-
} else {
101-
runCmd(conf.ShellCmd, line)
102-
}
99+
fmt.Println() // Print a newline on exit (e.g., Ctrl+D)
100+
return 0
101+
}
103102

104-
} else {
105-
cmdWasAllowed, _ := evaluateAndRun(line, jailFile, conf.ShellCmd)
103+
// processShellCommand handles a single line of input from the interactive shell.
104+
// It returns whether the shell should exit and the corresponding exit code.
105+
func processShellCommand(line string, conf Config, jailFile JailFile, isRecordMode bool) (shouldExit bool, exitCode int) {
106+
if err := checkCmdSafety(line, conf.Log); err != nil {
107+
printLogErr(os.Stderr, "%s", err.Error())
108+
return false, 0 // Continue shell, error already printed
109+
}
106110

107-
if cmdWasAllowed && (line == "exit" || line == "quit") {
108-
return 0
109-
}
111+
isExitCmd := (line == "exit" || line == "quit")
112+
113+
if isRecordMode {
114+
if err := appendRuleToFile(conf.RecordFile, line); err != nil {
115+
printLogErr(os.Stderr, "appending to record file %s: %s", conf.RecordFile, err.Error())
116+
} else {
117+
printLogDebug(os.Stdout, "appended rule to %s: + '%s'", conf.RecordFile, line)
110118
}
111119

112-
fmt.Print("cmdjail> ")
120+
if isExitCmd {
121+
return true, 0 // Exit shell normally
122+
}
123+
runCmd(conf.ShellCmd, line) // Run the command
124+
return false, 0 // Continue shell
113125
}
114126

115-
if err := scanner.Err(); err != nil {
116-
printLogErr(os.Stderr, "reading from stdin: %v", err)
117-
return 1
127+
// Not record mode
128+
cmdWasAllowed, _ := evaluateAndRun(line, jailFile, conf.ShellCmd)
129+
if cmdWasAllowed && isExitCmd {
130+
return true, 0 // Exit shell normally
118131
}
119132

120-
fmt.Println() // Print a newline on exit (e.g., Ctrl+D)
121-
return 0
133+
return false, 0 // Continue shell
122134
}
123135

124136
func evaluateAndRun(intentCmd string, jailFile JailFile, shellCmd []string) (bool, int) {
@@ -142,9 +154,36 @@ func evaluateAndRun(intentCmd string, jailFile JailFile, shellCmd []string) (boo
142154
func runCheckMode(conf Config, jailFile JailFile) int {
143155
printMsg(os.Stdout, "Jail file '%s' syntax is valid.", conf.JailFile)
144156

145-
var commands []string
146-
var err error
147-
source := "command line"
157+
commands, source, err := loadCommandsForCheckMode(conf)
158+
if err != nil {
159+
// Error already logged by loadCommandsForCheckMode if it's critical
160+
return 1
161+
}
162+
163+
if len(commands) == 0 {
164+
printMsg(os.Stdout, "No commands provided to check. Exiting.")
165+
return 0
166+
}
167+
168+
printMsg(os.Stdout, "\nTesting commands from %s...", source)
169+
blockedCount := 0
170+
for _, cmd := range commands {
171+
result := evaluateCmd(cmd, jailFile)
172+
printCheckCommandResult(result)
173+
if !result.Allowed {
174+
blockedCount++
175+
}
176+
}
177+
178+
printMsg(os.Stdout, "\nCheck complete. %d/%d commands would be blocked.", blockedCount, len(commands))
179+
if blockedCount > 0 {
180+
return 1
181+
}
182+
return 0
183+
}
184+
185+
func loadCommandsForCheckMode(conf Config) (commands []string, source string, err error) {
186+
source = "command line" // Default source
148187

149188
if conf.CheckIntentCmdsFile != "" {
150189
var r io.Reader
@@ -157,50 +196,32 @@ func runCheckMode(conf Config, jailFile JailFile) int {
157196
file, fileErr := os.Open(conf.CheckIntentCmdsFile)
158197
if fileErr != nil {
159198
printLogErr(os.Stderr, "reading test file %s: %v", conf.CheckIntentCmdsFile, fileErr)
160-
return 1
199+
return nil, source, fileErr
161200
}
162201
defer file.Close()
163202
r = file
164203
}
165204
commands, err = readLines(r)
166205
if err != nil {
167206
printLogErr(os.Stderr, "reading test commands from %s: %v", source, err)
168-
return 1
207+
return nil, source, err
169208
}
170209
} else if conf.IntentCmd != "" {
171210
commands = []string{conf.IntentCmd}
172211
}
212+
return commands, source, nil
213+
}
173214

174-
if len(commands) == 0 {
175-
printMsg(os.Stdout, "No commands provided to check. Exiting.")
176-
return 0
177-
}
178-
179-
printMsg(os.Stdout, "\nTesting commands from %s...", source)
180-
blockedCount := 0
181-
for _, cmd := range commands {
182-
result := evaluateCmd(cmd, jailFile)
183-
if result.Allowed {
184-
printMsg(os.Stdout, "\n[ALLOWED] '%s'", cmd)
185-
printMsg(os.Stdout, " Reason: %s", result.Reason)
186-
if result.Matcher != nil {
187-
printMsg(os.Stdout, " Matcher: %s", result.Matcher.Raw())
188-
}
189-
} else {
190-
blockedCount++
191-
printMsg(os.Stdout, "\n[BLOCKED] '%s'", cmd)
192-
printMsg(os.Stdout, " Reason: %s", result.Reason)
193-
if result.Matcher != nil {
194-
printMsg(os.Stdout, " Matcher: %s", result.Matcher.Raw())
195-
}
196-
}
215+
func printCheckCommandResult(result CheckResult) {
216+
if result.Allowed {
217+
printMsg(os.Stdout, "\n[ALLOWED] '%s'", result.Cmd)
218+
} else {
219+
printMsg(os.Stdout, "\n[BLOCKED] '%s'", result.Cmd)
197220
}
198-
199-
printMsg(os.Stdout, "\nCheck complete. %d/%d commands would be blocked.", blockedCount, len(commands))
200-
if blockedCount > 0 {
201-
return 1
221+
printMsg(os.Stdout, " Reason: %s", result.Reason)
222+
if result.Matcher != nil {
223+
printMsg(os.Stdout, " Matcher: %s", result.Matcher.Raw())
202224
}
203-
return 0
204225
}
205226

206227
// readLines reads from a reader and returns its lines as a slice of strings.
@@ -229,24 +250,33 @@ func getConfig() Config {
229250
return conf
230251
}
231252

232-
func getJailFile(conf Config) JailFile {
233-
var jailFileReader io.Reader
234-
var err error
235-
236-
if conf.JailFile == "-" && !conf.Shell {
253+
func openJailFileReader(conf Config) (io.Reader, error) {
254+
if conf.JailFile == "-" && !conf.Shell { // In shell mode, stdin is for commands, not jailfile
237255
printLogDebug(os.Stdout, "reading jail file from: <stdin>")
238-
jailFileReader = os.Stdin
239-
} else {
240-
printLogDebug(os.Stdout, "reading jail file from: %s", conf.JailFile)
241-
jailFileReader, err = os.Open(conf.JailFile)
242-
if err != nil {
243-
if errors.Is(err, fs.ErrNotExist) {
244-
printLogErr(os.Stderr, "finding jail file: %s", conf.JailFile)
245-
} else {
246-
printLogErr(os.Stderr, "opening jail file: %s: %s", conf.JailFile, err.Error())
247-
}
248-
os.Exit(1)
256+
return os.Stdin, nil
257+
}
258+
259+
printLogDebug(os.Stdout, "reading jail file from: %s", conf.JailFile)
260+
file, err := os.Open(conf.JailFile)
261+
if err != nil {
262+
if errors.Is(err, fs.ErrNotExist) {
263+
printLogErr(os.Stderr, "finding jail file: %s", conf.JailFile)
264+
} else {
265+
printLogErr(os.Stderr, "opening jail file: %s: %s", conf.JailFile, err.Error())
249266
}
267+
return nil, err // Return error to be handled by caller
268+
}
269+
return file, nil
270+
}
271+
272+
func getJailFile(conf Config) JailFile {
273+
jailFileReader, err := openJailFileReader(conf)
274+
if err != nil {
275+
os.Exit(1) // Exit if file couldn't be opened
276+
}
277+
// If the reader is a file, ensure it's closed.
278+
if closer, ok := jailFileReader.(io.Closer); ok && jailFileReader != os.Stdin {
279+
defer closer.Close()
250280
}
251281

252282
jailFile, err := parseJailFile(conf, jailFileReader)

0 commit comments

Comments
 (0)