diff --git a/packages/tiny-agents/src/cli.ts b/packages/tiny-agents/src/cli.ts index 2a670ee8d7..ba8d608bfb 100644 --- a/packages/tiny-agents/src/cli.ts +++ b/packages/tiny-agents/src/cli.ts @@ -1,11 +1,13 @@ #!/usr/bin/env node import { parseArgs } from "node:util"; +import * as readline from "node:readline/promises"; +import { stdin, stdout } from "node:process"; import { z } from "zod"; import { PROVIDERS_OR_POLICIES } from "@huggingface/inference"; import { Agent } from "@huggingface/mcp-client"; import { version as packageVersion } from "../package.json"; -import { ServerConfigSchema } from "./lib/types"; -import { debug, error } from "./lib/utils"; +import { InputConfigSchema, ServerConfigSchema } from "./lib/types"; +import { debug, error, ANSI } from "./lib/utils"; import { mainCliLoop } from "./lib/mainCliLoop"; import { loadConfigFrom } from "./lib/loadConfigFrom"; @@ -70,6 +72,7 @@ async function main() { provider: z.enum(PROVIDERS_OR_POLICIES).optional(), endpointUrl: z.string().optional(), apiKey: z.string().optional(), + inputs: z.array(InputConfigSchema).optional(), servers: z.array(ServerConfigSchema), }) .refine((data) => data.provider !== undefined || data.endpointUrl !== undefined, { @@ -85,6 +88,111 @@ async function main() { process.exit(1); } + // Handle inputs (i.e. env variables injection) + if (config.inputs && config.inputs.length > 0) { + const rl = readline.createInterface({ input: stdin, output: stdout }); + + stdout.write(ANSI.BLUE); + stdout.write("Some initial inputs are required by the agent. "); + stdout.write("Please provide a value or leave empty to load from env."); + stdout.write(ANSI.RESET); + stdout.write("\n"); + + for (const inputItem of config.inputs) { + const inputId = inputItem.id; + const description = inputItem.description; + const envSpecialValue = `\${input:${inputId}}`; // Special value to indicate env variable injection + + // Check env variables that will use this input + const inputVars = new Set(); + for (const server of config.servers) { + if (server.type === "stdio" && server.config.env) { + for (const [key, value] of Object.entries(server.config.env)) { + if (value === envSpecialValue) { + inputVars.add(key); + } + } + } + if ((server.type === "http" || server.type === "sse") && server.config.options?.requestInit?.headers) { + for (const [key, value] of Object.entries(server.config.options.requestInit.headers)) { + if (value.includes(envSpecialValue)) { + inputVars.add(key); + } + } + } + } + + if (inputVars.size === 0) { + stdout.write(ANSI.YELLOW); + stdout.write(`Input ${inputId} defined in config but not used by any server.`); + stdout.write(ANSI.RESET); + stdout.write("\n"); + continue; + } + + // Prompt user for input + stdout.write(ANSI.BLUE); + stdout.write(` • ${inputId}`); + stdout.write(ANSI.RESET); + stdout.write(`: ${description}. (default: load from ${Array.from(inputVars).join(", ")}) `); + + const userInput = (await rl.question("")).trim(); + + // Inject user input (or env variable) into servers' env + for (const server of config.servers) { + if (server.type === "stdio" && server.config.env) { + for (const [key, value] of Object.entries(server.config.env)) { + if (value === envSpecialValue) { + if (userInput) { + server.config.env[key] = userInput; + } else { + const valueFromEnv = process.env[key] || ""; + server.config.env[key] = valueFromEnv; + if (valueFromEnv) { + stdout.write(ANSI.GREEN); + stdout.write(`Value successfully loaded from '${key}'`); + stdout.write(ANSI.RESET); + stdout.write("\n"); + } else { + stdout.write(ANSI.YELLOW); + stdout.write(`No value found for '${key}' in environment variables. Continuing.`); + stdout.write(ANSI.RESET); + stdout.write("\n"); + } + } + } + } + } + if ((server.type === "http" || server.type === "sse") && server.config.options?.requestInit?.headers) { + for (const [key, value] of Object.entries(server.config.options.requestInit.headers)) { + if (value.includes(envSpecialValue)) { + if (userInput) { + server.config.options.requestInit.headers[key] = value.replace(envSpecialValue, userInput); + } else { + const valueFromEnv = process.env[key] || ""; + server.config.options.requestInit.headers[key] = value.replace(envSpecialValue, valueFromEnv); + if (valueFromEnv) { + stdout.write(ANSI.GREEN); + stdout.write(`Value successfully loaded from '${key}'`); + stdout.write(ANSI.RESET); + stdout.write("\n"); + } else { + stdout.write(ANSI.YELLOW); + stdout.write(`No value found for '${key}' in environment variables. Continuing.`); + stdout.write(ANSI.RESET); + stdout.write("\n"); + } + } + } + } + } + } + } + + stdout.write("\n"); + rl.close(); + } + const agent = new Agent( config.endpointUrl ? { diff --git a/packages/tiny-agents/src/lib/types.ts b/packages/tiny-agents/src/lib/types.ts index a2c6bf3d35..14f0b479d8 100644 --- a/packages/tiny-agents/src/lib/types.ts +++ b/packages/tiny-agents/src/lib/types.ts @@ -21,6 +21,14 @@ export const ServerConfigSchema = z.discriminatedUnion("type", [ url: z.union([z.string(), z.string().url()]), options: z .object({ + /** + * Customizes HTTP requests to the server. + */ + requestInit: z + .object({ + headers: z.record(z.string()).optional(), + }) + .optional(), /** * Session ID for the connection. This is used to identify the session on the server. * When not provided and connecting to a server that supports session IDs, the server will generate a new session ID. @@ -34,9 +42,29 @@ export const ServerConfigSchema = z.discriminatedUnion("type", [ type: z.literal("sse"), config: z.object({ url: z.union([z.string(), z.string().url()]), - options: z.object({}).optional(), + options: z + .object({ + /** + * Customizes HTTP requests to the server. + */ + requestInit: z + .object({ + headers: z.record(z.string()).optional(), + }) + .optional(), + }) + .optional(), }), }), ]); export type ServerConfig = z.infer; + +export const InputConfigSchema = z.object({ + id: z.string(), + description: z.string(), + type: z.string().optional(), + password: z.boolean().optional(), +}); + +export type InputConfig = z.infer; diff --git a/packages/tiny-agents/src/lib/utils.ts b/packages/tiny-agents/src/lib/utils.ts index 069aa75717..93aa2cf7f4 100644 --- a/packages/tiny-agents/src/lib/utils.ts +++ b/packages/tiny-agents/src/lib/utils.ts @@ -16,4 +16,5 @@ export const ANSI = { GREEN: "\x1b[32m", RED: "\x1b[31m", RESET: "\x1b[0m", + YELLOW: "\x1b[33m", };