diff --git a/commands.go b/commands.go index 2e88ebb..4640392 100644 --- a/commands.go +++ b/commands.go @@ -3,6 +3,7 @@ package sqlbless import ( "context" "database/sql" + "errors" "fmt" "io" "strings" @@ -14,7 +15,7 @@ import ( "github.com/hymkor/sqlbless/internal/misc" ) -func doSelect(ctx context.Context, ss *session, query string, v *spread.Viewer) error { +func doSelect(ctx context.Context, ss *session, query string, v *spread.Viewer, pilot commandIn) error { var rows *sql.Rows var err error if ss.tx != nil { @@ -34,9 +35,15 @@ func doSelect(ctx context.Context, ss *session, query string, v *spread.Viewer) v = newViewer(ss) } if ss.automatic() { - v.Pilot = misc.CsviNoOperation{} + v.Pilot = &misc.CsviNoOperation{} + } else if a, ok := pilot.AutoPilotForCsvi(); ok { + v.Pilot = a } - return v.View(ctx, query, _rows, ss.termOut) + err = v.View(ctx, query, _rows, ss.termOut) + if errors.Is(err, io.EOF) { + return nil + } + return err } type canExec interface { @@ -144,7 +151,7 @@ func doDescTables(ctx context.Context, ss *session, commandIn commandIn) error { rc, err := handler(e) if err == nil && rc.Quit && name != "" { action = func() error { - return doDescColumns(ctx, ss, name) + return doDescColumns(ctx, ss, name, commandIn) } } return rc, err @@ -156,7 +163,7 @@ func doDescTables(ctx context.Context, ss *session, commandIn commandIn) error { if ss.Debug { fmt.Println(query) } - err := doSelect(ctx, ss, query, v) + err := doSelect(ctx, ss, query, v, commandIn) if err == nil && name != "" { fmt.Fprintln(ss.termErr) misc.Echo(ss.spool, name) @@ -165,7 +172,7 @@ func doDescTables(ctx context.Context, ss *session, commandIn commandIn) error { return err } -func doDescColumns(ctx context.Context, ss *session, table string) error { +func doDescColumns(ctx context.Context, ss *session, table string, commandIn commandIn) error { if ss.Dialect.SQLForColumns == "" { return fmt.Errorf("desc table: %w", ErrNotSupported) } @@ -173,7 +180,7 @@ func doDescColumns(ctx context.Context, ss *session, table string) error { if ss.Debug { fmt.Println(query) } - return doSelect(ctx, ss, query, newViewer(ss)) + return doSelect(ctx, ss, query, newViewer(ss), commandIn) } func doDesc(ctx context.Context, ss *session, table string, commandIn commandIn) error { @@ -181,5 +188,5 @@ func doDesc(ctx context.Context, ss *session, table string, commandIn commandIn) if table == "" { return doDescTables(ctx, ss, commandIn) } - return doDescColumns(ctx, ss, table) + return doDescColumns(ctx, ss, table, commandIn) } diff --git a/internal/misc/csvinop.go b/internal/misc/csvinop.go index dfb5a89..e7e8158 100644 --- a/internal/misc/csvinop.go +++ b/internal/misc/csvinop.go @@ -11,28 +11,34 @@ type GetKeyAndSize interface { Size() (int, int, error) } -type CsviNoOperation struct{} - -func (CsviNoOperation) Size() (int, int, error) { - return 80, 25, nil +type CsviNoOperation struct { + text []string } -func (CsviNoOperation) Calibrate() error { - return nil +func (*CsviNoOperation) Size() (int, int, error) { + return 80, 25, nil } -func (CsviNoOperation) GetKey() (string, error) { - return "q", nil +func (c *CsviNoOperation) GetKey() (string, error) { + if len(c.text) <= 0 { + c.text = []string{">", "q", "y", ""} + } + v := c.text[0] + if v == "" { + return "", io.EOF + } + c.text = c.text[1:] + return v, nil } -func (CsviNoOperation) ReadLine(io.Writer, string, string, candidate.Candidate) (string, error) { +func (*CsviNoOperation) ReadLine(io.Writer, string, string, candidate.Candidate) (string, error) { return "", nil } -func (CsviNoOperation) GetFilename(io.Writer, string, string) (string, error) { +func (*CsviNoOperation) GetFilename(io.Writer, string, string) (string, error) { return "", nil } -func (CsviNoOperation) Close() error { +func (*CsviNoOperation) Close() error { return nil } diff --git a/loop.go b/loop.go index 6d8524a..137afcd 100644 --- a/loop.go +++ b/loop.go @@ -98,7 +98,7 @@ func (ss *session) Loop(ctx context.Context, commandIn commandIn) error { } lines, err := commandIn.Read(ctx) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { if ss.tx != nil && !commandIn.CanCloseInTransaction() { fmt.Fprintln(ss.termErr, ErrTransactionIsNotClosed.Error()) continue @@ -161,7 +161,7 @@ func (ss *session) Loop(ctx context.Context, commandIn commandIn) error { case "SELECT": misc.Echo(ss.spool, query) - err = doSelect(ctx, ss, query, nil) + err = doSelect(ctx, ss, query, nil, commandIn) case "ROLLBACK": misc.Echo(ss.spool, query) arg, _ = misc.CutField(arg) @@ -216,7 +216,7 @@ func (ss *session) Loop(ctx context.Context, commandIn commandIn) error { default: misc.Echo(ss.spool, query) if q := ss.Dialect.IsQuerySQL; q != nil && q(query) { - err = doSelect(ctx, ss, query, nil) + err = doSelect(ctx, ss, query, nil, commandIn) } else { if ss.tx == nil { _, err = ss.conn.ExecContext(ctx, query) diff --git a/release_note_en.md b/release_note_en.md index 2cadbdc..5b24e9c 100644 --- a/release_note_en.md +++ b/release_note_en.md @@ -5,6 +5,9 @@ - Support `SAVEPOINT` as a TCL command (#11) - Support `ROLLBACK TO` (or `ROLLBACK TRANSACTION`) as a TCL command (#11) - Require `;` after `ROLLBACK` to prevent accidental execution (#11) +- Fix: Correct handling of `io.EOF` during script execution (#12) +- Suppress output of empty lines and leading/trailing spaces in script output (#12) +- Fix: `CSVI` launched by `SELECT` in a script now terminates automatically with `>`, `q`, and `y` (#12) v0.25.0 ======= diff --git a/release_note_ja.md b/release_note_ja.md index df5a8a0..d3a4a0e 100644 --- a/release_note_ja.md +++ b/release_note_ja.md @@ -5,6 +5,9 @@ - `SAVEPOINT` を TCL コマンドとしてサポート (#11) - `ROLLBACK TO`(もしくは `ROLLBACK TRANSACTION`)を TCL コマンドとしてサポート (#11) - 誤操作による実行を防ぐため、`ROLLBACK` には `;` を必須とした (#11) +* スクリプト中の `SELECT` で起動した `CSVI` は、`>`, `q`, `y` の操作で自動的に終了するようにした。 (#12) +* スクリプト実行時の出力で、空行や行頭・行末の空白が出力されないようにした。 (#12) +* スクリプト実行時に `io.EOF` が誤ってエラーとして扱われていた問題を修正した。 (#12) v0.25.0 ======= diff --git a/script.go b/script.go index 38e75cd..dbd871d 100644 --- a/script.go +++ b/script.go @@ -3,6 +3,7 @@ package sqlbless import ( "bufio" "context" + "errors" "fmt" "io" "os" @@ -29,18 +30,25 @@ func (script *scriptIn) GetKey() (string, error) { } func (script *scriptIn) AutoPilotForCsvi() (csvi.Pilot, bool) { - return nil, false + return &misc.CsviNoOperation{}, true } func (script *scriptIn) Read(context.Context) ([]string, error) { + if script.br == nil { + return nil, io.EOF + } var buffer strings.Builder quoted := 0 for { ch, _, err := script.br.ReadRune() - if err != nil { + if errors.Is(err, io.EOF) { code := buffer.String() - fmt.Fprintln(script.echo, code) - return []string{code}, err + fmt.Fprintln(script.echo, strings.TrimSpace(code)) + script.br = nil + return []string{code}, nil + } + if err != nil { + return nil, err } if ch == '\r' { continue @@ -55,8 +63,7 @@ func (script *scriptIn) Read(context.Context) ([]string, error) { code := buffer.String() term := script.term if _, ok := misc.HasTerm(code, term); ok { - println(code) - fmt.Fprintln(script.echo, code) + fmt.Fprintln(script.echo, strings.TrimSpace(code)) return []string{code}, nil } }