diff --git a/client.go b/client.go index 65d84758..11c36a8b 100644 --- a/client.go +++ b/client.go @@ -892,6 +892,19 @@ func (cl *Client) hasExtension(ext *sshfx.ExtensionPair) bool { return cl.exts[ext.Name] == ext.Data } +// StatVFS retrieves VFS statistics from a remote host. +// +// It implements the statvfs@openssh.com SSH_FXP_EXTENDED feature from +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +func (cl *Client) StatVFS(path string) (*openssh.StatVFSExtendedReplyPacket, error) { + resp, err := getPacket[*openssh.StatVFSExtendedReplyPacket](context.Background(), nil, cl, + &openssh.StatVFSExtendedPacket{ + Path: path, + }, + ) + return valOrPathError("statvfs", path, resp, err) +} + // Link creates newname as a hard link to oldname file. // // If the server did not announce support for the "hardlink@openssh.com" extension, diff --git a/localfs/localfs_integration_test.go b/localfs/localfs_integration_test.go index 77538f32..38b7cfc8 100644 --- a/localfs/localfs_integration_test.go +++ b/localfs/localfs_integration_test.go @@ -751,6 +751,32 @@ func TestReadFrom(t *testing.T) { } } +func TestStatVFS(t *testing.T) { + if !*testServerImpl { + t.Skip("not testing against localfs server implementation") + } + + if _, ok := any(handler).(sftp.StatVFSServerHandler); !ok { + t.Skip("handler does not implement statvfs") + } + + dir := t.TempDir() + + targetNotExist := filepath.Join(dir, "statvfs-does-not-exist") + + _, err := cl.StatVFS(toRemotePath(targetNotExist)) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("unexpected error, got %v, should be fs.ErrNotFound", err) + } + + resp, err := cl.StatVFS(toRemotePath(dir)) + if err != nil { + t.Fatal(err) + } + + t.Logf("%+v", resp) +} + var benchBuf []byte func benchHelperWriteTo(b *testing.B, length int) { diff --git a/server.go b/server.go index 691743a1..d9bd472a 100644 --- a/server.go +++ b/server.go @@ -344,7 +344,12 @@ func Hijack[REQ sshfx.Packet](srv *Server, fn func(context.Context, REQ) error) // This is really only useful for supporting newer versions of the SFTP standard. func HijackWithResponse[REQ, RESP sshfx.Packet](srv *Server, fn func(context.Context, REQ) (RESP, error)) error { wrap := wrapHandler(func(ctx context.Context, req sshfx.Packet) (sshfx.Packet, error) { - return fn(ctx, req.(REQ)) + resp, err := fn(ctx, req.(REQ)) + if err != nil { + // We have to convert maybe typed-zero to untyped-nil. + return nil, err + } + return resp, nil }) var pkt REQ @@ -515,6 +520,7 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh if len(srv.hijacks) > 0 { if fn := srv.hijacks[req.Type()]; fn != nil { + // Hijack takes care of wrapping the getter into an untyped-nil on error. return get(srv, req, fn) } } @@ -595,7 +601,13 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh case *openssh.StatVFSExtendedPacket: if statvfser, ok := srv.Handler.(StatVFSServerHandler); ok { - return get(srv, req, statvfser.StatVFS) + resp, err := get(srv, req, statvfser.StatVFS) + if err != nil { + // We have to convert typed-nil to untyped-nil. + return nil, err + } + + return resp, nil } case interface{ GetHandle() string }: @@ -610,7 +622,13 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh case *openssh.FStatVFSExtendedPacket: if statvfser, ok := file.(StatVFSFileHandler); ok { - return statvfser.StatVFS() + resp, err := statvfser.StatVFS() + if err != nil { + // We have to convert typed-nil to untyped-nil. + return nil, err + } + + return resp, nil } if statvfser, ok := srv.Handler.(StatVFSServerHandler); ok { @@ -618,7 +636,13 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh Path: file.Name(), } - return get(srv, req, statvfser.StatVFS) + resp, err := get(srv, req, statvfser.StatVFS) + if err != nil { + // We have to convert typed-nil to untyped-nil. + return nil, err + } + + return resp, nil } } } @@ -701,7 +725,7 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh } hint = slices.Grow(hint[:0], int(req.Length))[:req.Length] - + n, err := file.ReadAt(hint, int64(req.Offset)) if err != nil { // We cannot return results AND a status like SSH_FX_EOF, @@ -729,7 +753,8 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh return nil, io.ErrShortWrite } - return nil, nil + // explicitly return statusOK here, rather than both nil. + return statusOK, nil case *sshfx.FStatPacket: attrs, err := file.Stat()