@@ -385,6 +385,47 @@ def get_continuation(width, line_number, is_soft_wrap):
385385 def show_suggestion_tip ():
386386 return iterations < 2
387387
388+ def output_res (res , start ):
389+ result_count = 0
390+ mutating = False
391+ for title , cur , headers , status in res :
392+ logger .debug ("headers: %r" , headers )
393+ logger .debug ("rows: %r" , cur )
394+ logger .debug ("status: %r" , status )
395+ threshold = 1000
396+ if is_select (status ) and cur and cur .rowcount > threshold :
397+ self .echo (
398+ "The result set has more than {} rows." .format (threshold ),
399+ fg = "red" ,
400+ )
401+ if not confirm ("Do you want to continue?" ):
402+ self .echo ("Aborted!" , err = True , fg = "red" )
403+ break
404+
405+ if self .auto_vertical_output :
406+ max_width = self .prompt_app .output .get_size ().columns
407+ else :
408+ max_width = None
409+
410+ formatted = self .format_output (title , cur , headers , special .is_expanded_output (), max_width )
411+
412+ t = time () - start
413+ try :
414+ if result_count > 0 :
415+ self .echo ("" )
416+ try :
417+ self .output (formatted , status )
418+ except KeyboardInterrupt :
419+ pass
420+ self .echo ("Time: %0.03fs" % t )
421+ except KeyboardInterrupt :
422+ pass
423+
424+ start = time ()
425+ result_count += 1
426+ mutating = mutating or is_mutating (status )
427+ return mutating
428+
388429 def one_iteration (text = None ):
389430 if text is None :
390431 try :
@@ -402,6 +443,24 @@ def one_iteration(text=None):
402443 self .echo (str (e ), err = True , fg = "red" )
403444 return
404445
446+ if special .is_llm_command (text ):
447+ try :
448+ start = time ()
449+ cur = self .sqlexecute .conn .cursor ()
450+ context , sql = special .handle_llm (text , cur )
451+ if context :
452+ click .echo (context )
453+ text = self .prompt_app .prompt (default = sql )
454+ except KeyboardInterrupt :
455+ return
456+ except special .FinishIteration as e :
457+ return output_res (e .results , start ) if e .results else None
458+ except RuntimeError as e :
459+ logger .error ("sql: %r, error: %r" , text , e )
460+ logger .error ("traceback: %r" , traceback .format_exc ())
461+ self .echo (str (e ), err = True , fg = "red" )
462+ return
463+
405464 if not text .strip ():
406465 return
407466
@@ -415,9 +474,6 @@ def one_iteration(text=None):
415474 self .echo ("Wise choice!" )
416475 return
417476
418- # Keep track of whether or not the query is mutating. In case
419- # of a multi-statement query, the overall query is considered
420- # mutating if any one of the component statements is mutating
421477 mutating = False
422478
423479 try :
@@ -434,44 +490,11 @@ def one_iteration(text=None):
434490 res = sqlexecute .run (text )
435491 self .formatter .query = text
436492 successful = True
437- result_count = 0
438- for title , cur , headers , status in res :
439- logger .debug ("headers: %r" , headers )
440- logger .debug ("rows: %r" , cur )
441- logger .debug ("status: %r" , status )
442- threshold = 1000
443- if is_select (status ) and cur and cur .rowcount > threshold :
444- self .echo (
445- "The result set has more than {} rows." .format (threshold ),
446- fg = "red" ,
447- )
448- if not confirm ("Do you want to continue?" ):
449- self .echo ("Aborted!" , err = True , fg = "red" )
450- break
451-
452- if self .auto_vertical_output :
453- max_width = self .prompt_app .output .get_size ().columns
454- else :
455- max_width = None
456-
457- formatted = self .format_output (title , cur , headers , special .is_expanded_output (), max_width )
458-
459- t = time () - start
460- try :
461- if result_count > 0 :
462- self .echo ("" )
463- try :
464- self .output (formatted , status )
465- except KeyboardInterrupt :
466- pass
467- self .echo ("Time: %0.03fs" % t )
468- except KeyboardInterrupt :
469- pass
470-
471- start = time ()
472- result_count += 1
473- mutating = mutating or is_mutating (status )
474493 special .unset_once_if_written ()
494+ # Keep track of whether or not the query is mutating. In case
495+ # of a multi-statement query, the overall query is considered
496+ # mutating if any one of the component statements is mutating
497+ mutating = output_res (res , start )
475498 special .unset_pipe_once_if_written ()
476499 except EOFError as e :
477500 raise e
0 commit comments