Skip to content

Commit 0eaf6f8

Browse files
committed
feat: support overriding workingDirectory bash param
1 parent fe02380 commit 0eaf6f8

File tree

1 file changed

+24
-12
lines changed
  • src/main/kotlin/ee/carlrobert/codegpt/agent/tools

1 file changed

+24
-12
lines changed

src/main/kotlin/ee/carlrobert/codegpt/agent/tools/BashTool.kt

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import kotlinx.coroutines.channels.ReceiveChannel
2121
import kotlinx.coroutines.selects.select
2222
import kotlinx.serialization.SerialName
2323
import kotlinx.serialization.Serializable
24+
import java.io.File
2425
import java.io.IOException
2526
import java.nio.file.Paths
2627
import java.util.*
@@ -34,7 +35,7 @@ fun interface BashCommandConfirmationHandler {
3435
class BashTool(
3536
private val project: Project,
3637
private val confirmationHandler: BashCommandConfirmationHandler,
37-
private val sessionId: String = "global",
38+
private val sessionId: String,
3839
private val hookManager: HookManager
3940
) : BaseTool<BashTool.Args, BashTool.Result>(
4041
workingDirectory = project.basePath ?: "",
@@ -72,7 +73,6 @@ class BashTool(
7273
- It is very helpful if you write a clear, concise description of what this command does in 5-10 words.
7374
- If the output exceeds 30000 characters, output will be truncated before being returned to you.
7475
- You can use the `run_in_background` parameter to run the command in the background, which allows you to continue working while the command runs. You can monitor the output using the Bash tool as it becomes available. You do not need to use '&' at the end of the command when using this parameter.
75-
7676
- Avoid using Bash with the `find`, `grep`, `cat`, `head`, `tail`, `sed`, `awk`, or `echo` commands, unless explicitly instructed or when these commands are truly necessary for the task. Instead, always prefer using the dedicated tools for these commands:
7777
- Content search: Use Grep (NOT grep or rg)
7878
- Read files: Use Read (NOT cat/head/tail)
@@ -91,6 +91,7 @@ class BashTool(
9191
<bad-example>
9292
find /src -type f -name "*.kt"
9393
</bad-example>
94+
- Try to maintain your current working directory throughout the session by using absolute paths and avoiding usage of cd. You may use cd if the User explicitly requests it. pytest /foo/bar/tests cd /foo/bar && pytest tests
9495
9596
# Committing changes with git
9697
@@ -166,7 +167,6 @@ class BashTool(
166167
</example>
167168
168169
Important:
169-
- DO NOT use the TodoWrite or Task tools
170170
- Return the PR URL when you're done, so the user can see it
171171
172172
# Other common operations
@@ -183,6 +183,8 @@ class BashTool(
183183
"The command to execute"
184184
)
185185
val command: String,
186+
@property:LLMDescription("Optional working directory. If not specified, defaults to the project base directory.")
187+
val workingDirectory: String? = null,
186188
@property:LLMDescription(
187189
"Optional timeout in milliseconds (max 600000)"
188190
)
@@ -221,6 +223,7 @@ class BashTool(
221223
}
222224

223225
override suspend fun doExecute(args: Args): Result {
226+
val workingDirectory = args.workingDirectory ?: super.workingDirectory
224227
val toolId = ToolRunContext.getToolId(sessionId)
225228
if (project.service<ProxyAISettingsService>()
226229
.isToolInvocationDenied(this, args.command)
@@ -271,15 +274,15 @@ class BashTool(
271274
is ShellCommandConfirmation.Approved -> {
272275
try {
273276
if (args.runInBackground == true) {
274-
val bashId = executeBackground(args.command)
277+
val bashId = executeBackground(workingDirectory, args.command)
275278
Result(
276279
args.command,
277280
null,
278281
"Background process started with ID: $bashId",
279282
bashId
280283
)
281284
} else {
282-
val result = runForegroundWithStreaming(resolvedToolId, args)
285+
val result = runForegroundWithStreaming(resolvedToolId, args, workingDirectory)
283286

284287
val postPayload = mapOf(
285288
"command" to args.command,
@@ -341,7 +344,11 @@ class BashTool(
341344
)
342345
}
343346

344-
private suspend fun runForegroundWithStreaming(toolId: String, args: Args): Result {
347+
private suspend fun runForegroundWithStreaming(
348+
toolId: String,
349+
args: Args,
350+
workingDirectory: String
351+
): Result {
345352
val publisher = ApplicationManager.getApplication().messageBus
346353
.syncPublisher(AgentToolOutputNotifier.AGENT_TOOL_OUTPUT_TOPIC)
347354

@@ -353,7 +360,7 @@ class BashTool(
353360
val shellCommand = buildShellCommand(args.command)
354361
val process = ProcessBuilder(shellCommand)
355362
.apply {
356-
directory(java.io.File(workingDirectory))
363+
directory(File(workingDirectory))
357364
redirectErrorStream(false)
358365
}
359366
.start()
@@ -514,12 +521,12 @@ class BashTool(
514521
raw.truncateToolResult()
515522
}
516523

517-
private fun executeBackground(command: String): String {
524+
private fun executeBackground(workingDirectory: String, command: String): String {
518525
val bashId = UUID.randomUUID().toString()
519526
val shellCommand = buildShellCommand(command)
520527
val process = ProcessBuilder(shellCommand)
521528
.apply {
522-
directory(java.io.File(workingDirectory))
529+
directory(File(workingDirectory))
523530
redirectErrorStream(false)
524531
}
525532
.start()
@@ -531,9 +538,14 @@ class BashTool(
531538
private fun buildShellCommand(command: String): List<String> {
532539
val osName = System.getProperty("os.name").lowercase()
533540
return when {
534-
osName.contains("windows") -> listOf("cmd", "/c", command)
535-
osName.contains("linux") -> listOf("bash", "-c", command)
536-
else -> listOf("sh", "-c", command)
541+
osName.contains("win") -> {
542+
val systemRoot = System.getenv("SystemRoot")
543+
?: System.getenv("WINDIR")
544+
?: "C:\\Windows"
545+
listOf("$systemRoot\\System32\\cmd.exe", "/c", command)
546+
}
547+
548+
else -> listOf("bash", "-c", command)
537549
}
538550
}
539551

0 commit comments

Comments
 (0)