|
11 | 11 |
|
12 | 12 | from .bots import load_bot |
13 | 13 | from .common import PROGRAM, Config, UnreachableError, ensure_state_home |
14 | | -from .drafter import Drafter |
| 14 | +from .drafter import Accept, Drafter |
15 | 15 | from .editor import open_editor |
16 | 16 | from .prompt import Template, TemplatedPrompt, find_template, templates_table |
17 | 17 | from .store import Store |
@@ -64,6 +64,12 @@ def callback( |
64 | 64 | add_command("show-prompts", short="P", help="show prompt history") |
65 | 65 | add_command("show-templates", short="T", help="show template information") |
66 | 66 |
|
| 67 | + parser.add_option( |
| 68 | + "-a", |
| 69 | + "--accept", |
| 70 | + help="apply generated changes", |
| 71 | + action="count", |
| 72 | + ) |
67 | 73 | parser.add_option( |
68 | 74 | "-b", |
69 | 75 | "--bot", |
@@ -171,67 +177,69 @@ def main() -> None: # noqa: PLR0912 PLR0915 |
171 | 177 | logging.basicConfig(level=config.log_level, filename=str(log_path)) |
172 | 178 |
|
173 | 179 | drafter = Drafter.create(store=Store.persistent(), path=opts.root) |
174 | | - command = getattr(opts, "command", "generate") |
175 | | - if command == "generate": |
176 | | - bot_config = None |
177 | | - if opts.bot: |
178 | | - bot_configs = [c for c in config.bots if c.name == opts.bot] |
179 | | - if len(bot_configs) != 1: |
180 | | - raise ValueError(f"Found {len(bot_configs)} matching bots") |
181 | | - bot_config = bot_configs[0] |
182 | | - elif config.bots: |
183 | | - bot_config = config.bots[0] |
184 | | - bot = load_bot(bot_config) |
185 | | - |
186 | | - prompt: str | TemplatedPrompt |
187 | | - editable = opts.edit |
188 | | - if args: |
189 | | - prompt = TemplatedPrompt.parse(args[0], *args[1:]) |
190 | | - elif opts.edit: |
191 | | - editable = False |
192 | | - prompt = edit( |
193 | | - text=drafter.latest_draft_prompt() or _PROMPT_PLACEHOLDER |
| 180 | + match getattr(opts, "command", "generate"): |
| 181 | + case "generate": |
| 182 | + bot_config = None |
| 183 | + if opts.bot: |
| 184 | + bot_configs = [c for c in config.bots if c.name == opts.bot] |
| 185 | + if len(bot_configs) != 1: |
| 186 | + raise ValueError(f"Found {len(bot_configs)} matching bots") |
| 187 | + bot_config = bot_configs[0] |
| 188 | + elif config.bots: |
| 189 | + bot_config = config.bots[0] |
| 190 | + bot = load_bot(bot_config) |
| 191 | + |
| 192 | + prompt: str | TemplatedPrompt |
| 193 | + editable = opts.edit |
| 194 | + if args: |
| 195 | + prompt = TemplatedPrompt.parse(args[0], *args[1:]) |
| 196 | + elif opts.edit: |
| 197 | + editable = False |
| 198 | + prompt = edit( |
| 199 | + text=drafter.latest_draft_prompt() or _PROMPT_PLACEHOLDER |
| 200 | + ) |
| 201 | + else: |
| 202 | + prompt = sys.stdin.read() |
| 203 | + |
| 204 | + accept = Accept(opts.accept or 0) |
| 205 | + name = drafter.generate_draft( |
| 206 | + prompt, |
| 207 | + bot, |
| 208 | + accept=accept, |
| 209 | + bot_name=opts.bot, |
| 210 | + prompt_transform=open_editor if editable else None, |
| 211 | + tool_visitors=[ToolPrinter()], |
| 212 | + reset=config.auto_reset if opts.reset is None else opts.reset, |
| 213 | + sync=opts.sync, |
194 | 214 | ) |
195 | | - else: |
196 | | - prompt = sys.stdin.read() |
197 | | - |
198 | | - name = drafter.generate_draft( |
199 | | - prompt, |
200 | | - bot, |
201 | | - bot_name=opts.bot, |
202 | | - prompt_transform=open_editor if editable else None, |
203 | | - tool_visitors=[ToolPrinter()], |
204 | | - reset=config.auto_reset if opts.reset is None else opts.reset, |
205 | | - sync=opts.sync, |
206 | | - ) |
207 | | - print(f"Refined {name}.") |
208 | | - elif command == "finalize": |
209 | | - name = drafter.finalize_draft(delete=opts.delete) |
210 | | - print(f"Finalized {name}.") |
211 | | - elif command == "show-drafts": |
212 | | - table = drafter.history_table(args[0] if args else None) |
213 | | - if table: |
214 | | - print(table.to_json() if opts.json else table) |
215 | | - elif command == "show-prompts": |
216 | | - raise NotImplementedError() # TODO: Implement |
217 | | - elif command == "show-templates": |
218 | | - if args: |
219 | | - name = args[0] |
220 | | - tpl = find_template(name) |
221 | | - if opts.edit: |
222 | | - if tpl: |
223 | | - edit(path=tpl.local_path(), text=tpl.source) |
| 215 | + print(f"Generated change in {name}.") |
| 216 | + case "finalize": |
| 217 | + name = drafter.finalize_draft(delete=opts.delete) |
| 218 | + print(f"Finalized {name}.") |
| 219 | + case "show-drafts": |
| 220 | + table = drafter.history_table(args[0] if args else None) |
| 221 | + if table: |
| 222 | + print(table.to_json() if opts.json else table) |
| 223 | + case "show-prompts": |
| 224 | + raise NotImplementedError() # TODO: Implement |
| 225 | + case "show-templates": |
| 226 | + if args: |
| 227 | + name = args[0] |
| 228 | + tpl = find_template(name) |
| 229 | + if opts.edit: |
| 230 | + if tpl: |
| 231 | + edit(path=tpl.local_path(), text=tpl.source) |
| 232 | + else: |
| 233 | + edit(path=Template.local_path_for(name)) |
224 | 234 | else: |
225 | | - edit(path=Template.local_path_for(name)) |
| 235 | + if not tpl: |
| 236 | + raise ValueError(f"No template named {name!r}") |
| 237 | + print(tpl.source) |
226 | 238 | else: |
227 | | - if not tpl: |
228 | | - raise ValueError(f"No template named {name!r}") |
229 | | - print(tpl.source) |
230 | | - else: |
231 | | - table = templates_table() |
232 | | - print(table.to_json() if opts.json else table) |
233 | | - else: |
234 | | - raise UnreachableError() |
| 239 | + table = templates_table() |
| 240 | + print(table.to_json() if opts.json else table) |
| 241 | + case _: |
| 242 | + raise UnreachableError() |
235 | 243 |
|
236 | 244 |
|
237 | 245 | if __name__ == "__main__": |
|
0 commit comments