diff --git a/cmd/limactl/copy.go b/cmd/limactl/copy.go index 6204d158969..07fdda832fb 100644 --- a/cmd/limactl/copy.go +++ b/cmd/limactl/copy.go @@ -30,6 +30,20 @@ Example: limactl copy default:/etc/os-release . Not to be confused with 'limactl clone'. ` +type copyTool string + +const ( + rsync copyTool = "rsync" + scp copyTool = "scp" +) + +type copyPath struct { + instanceName string + path string + isRemote bool + instance *store.Instance +} + func newCopyCommand() *cobra.Command { copyCommand := &cobra.Command{ Use: "copy SOURCE ... TARGET", @@ -58,13 +72,6 @@ func copyAction(cmd *cobra.Command, args []string) error { return err } - arg0, err := exec.LookPath("scp") - if err != nil { - return err - } - instances := make(map[string]*store.Instance) - scpFlags := []string{} - scpArgs := []string{} debug, err := cmd.Flags().GetBool("debug") if err != nil { return err @@ -74,62 +81,187 @@ func copyAction(cmd *cobra.Command, args []string) error { verbose = true } - if verbose { - scpFlags = append(scpFlags, "-v") - } else { - scpFlags = append(scpFlags, "-q") + copyPaths, err := parseArgs(args) + if err != nil { + return err } - if recursive { - scpFlags = append(scpFlags, "-r") + cpTool, toolPath, err := selectCopyTool(copyPaths) + if err != nil { + return err } - // this assumes that ssh and scp come from the same place, but scp has no -V - legacySSH := sshutil.DetectOpenSSHVersion("ssh").LessThan(*semver.New("8.0.0")) + + logrus.Infof("using copy tool %q", toolPath) + + var copyCmd *exec.Cmd + switch cpTool { + case scp: + copyCmd, err = scpCommand(toolPath, copyPaths, verbose, recursive) + case rsync: + copyCmd, err = rsyncCommand(toolPath, copyPaths, verbose, recursive) + default: + err = fmt.Errorf("invalid copy tool %q", cpTool) + } + if err != nil { + return err + } + + copyCmd.Stdin = cmd.InOrStdin() + copyCmd.Stdout = cmd.OutOrStdout() + copyCmd.Stderr = cmd.ErrOrStderr() + logrus.Debugf("executing %v (may take a long time)", copyCmd) + + // TODO: use syscall.Exec directly (results in losing tty?) + return copyCmd.Run() +} + +func parseArgs(args []string) ([]*copyPath, error) { + var copyPaths []*copyPath + for _, arg := range args { + cp := ©Path{} + if runtime.GOOS == "windows" { if filepath.IsAbs(arg) { + var err error arg, err = ioutilx.WindowsSubsystemPath(arg) if err != nil { - return err + return nil, err } } else { arg = filepath.ToSlash(arg) } } - path := strings.Split(arg, ":") - switch len(path) { + + parts := strings.SplitN(arg, ":", 2) + switch len(parts) { case 1: - scpArgs = append(scpArgs, arg) + cp.path = arg + cp.isRemote = false case 2: - instName := path[0] - inst, err := store.Inspect(instName) + cp.instanceName = parts[0] + cp.path = parts[1] + cp.isRemote = true + + inst, err := store.Inspect(cp.instanceName) if err != nil { if errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("instance %q does not exist, run `limactl create %s` to create a new instance", instName, instName) + return nil, fmt.Errorf("instance %q does not exist, run `limactl create %s` to create a new instance", cp.instanceName, cp.instanceName) } - return err + return nil, err } if inst.Status == store.StatusStopped { - return fmt.Errorf("instance %q is stopped, run `limactl start %s` to start the instance", instName, instName) + return nil, fmt.Errorf("instance %q is stopped, run `limactl start %s` to start the instance", cp.instanceName, cp.instanceName) } + cp.instance = inst + default: + return nil, fmt.Errorf("path %q contains multiple colons", arg) + } + + copyPaths = append(copyPaths, cp) + } + + return copyPaths, nil +} + +func selectCopyTool(copyPaths []*copyPath) (copyTool, string, error) { + if rsyncPath, err := exec.LookPath("rsync"); err == nil { + if rsyncAvailableOnGuests(copyPaths) { + return rsync, rsyncPath, nil + } + logrus.Debugf("rsync not available on guest(s), falling back to scp") + } else { + logrus.Debugf("rsync not found on host, falling back to scp: %v", err) + } + + scpPath, err := exec.LookPath("scp") + if err != nil { + return "", "", fmt.Errorf("neither rsync nor scp found on host: %w", err) + } + + return scp, scpPath, nil +} + +func rsyncAvailableOnGuests(copyPaths []*copyPath) bool { + instances := make(map[string]*store.Instance) + + for _, cp := range copyPaths { + if cp.isRemote { + instances[cp.instanceName] = cp.instance + } + } + + for instName, inst := range instances { + if !checkRsyncOnGuest(inst) { + logrus.Debugf("rsync not available on instance %q", instName) + return false + } + } + + return true +} + +func checkRsyncOnGuest(inst *store.Instance) bool { + sshOpts, err := sshutil.SSHOpts("ssh", inst.Dir, *inst.Config.User.Name, false, false, false, false) + if err != nil { + logrus.Debugf("failed to get SSH options for rsync check: %v", err) + return false + } + + sshArgs := sshutil.SSHArgsFromOpts(sshOpts) + checkCmd := exec.Command("ssh") + checkCmd.Args = append(checkCmd.Args, sshArgs...) + checkCmd.Args = append(checkCmd.Args, + "-p", fmt.Sprintf("%d", inst.SSHLocalPort), + fmt.Sprintf("%s@127.0.0.1", *inst.Config.User.Name), + "command -v rsync >/dev/null 2>&1", + ) + + err = checkCmd.Run() + return err == nil +} + +func scpCommand(command string, copyPaths []*copyPath, verbose, recursive bool) (*exec.Cmd, error) { + instances := make(map[string]*store.Instance) + scpFlags := []string{} + scpArgs := []string{} + + if verbose { + scpFlags = append(scpFlags, "-v") + } else { + scpFlags = append(scpFlags, "-q") + } + + if recursive { + scpFlags = append(scpFlags, "-r") + } + + // this assumes that ssh and scp come from the same place, but scp has no -V + legacySSH := sshutil.DetectOpenSSHVersion("ssh").LessThan(*semver.New("8.0.0")) + + for _, cp := range copyPaths { + if cp.isRemote { if legacySSH { - scpFlags = append(scpFlags, "-P", fmt.Sprintf("%d", inst.SSHLocalPort)) - scpArgs = append(scpArgs, fmt.Sprintf("%s@127.0.0.1:%s", *inst.Config.User.Name, path[1])) + scpFlags = append(scpFlags, "-P", fmt.Sprintf("%d", cp.instance.SSHLocalPort)) + scpArgs = append(scpArgs, fmt.Sprintf("%s@127.0.0.1:%s", *cp.instance.Config.User.Name, cp.path)) } else { - scpArgs = append(scpArgs, fmt.Sprintf("scp://%s@127.0.0.1:%d/%s", *inst.Config.User.Name, inst.SSHLocalPort, path[1])) + scpArgs = append(scpArgs, fmt.Sprintf("scp://%s@127.0.0.1:%d/%s", *cp.instance.Config.User.Name, cp.instance.SSHLocalPort, cp.path)) } - instances[instName] = inst - default: - return fmt.Errorf("path %q contains multiple colons", arg) + instances[cp.instanceName] = cp.instance + } else { + scpArgs = append(scpArgs, cp.path) } } + if legacySSH && len(instances) > 1 { - return errors.New("more than one (instance) host is involved in this command, this is only supported for openSSH v8.0 or higher") + return nil, errors.New("more than one (instance) host is involved in this command, this is only supported for openSSH v8.0 or higher") } + scpFlags = append(scpFlags, "-3", "--") scpArgs = append(scpFlags, scpArgs...) var sshOpts []string + var err error if len(instances) == 1 { // Only one (instance) host is involved; we can use the instance-specific // arguments such as ControlPath. This is preferred as we can multiplex @@ -137,24 +269,66 @@ func copyAction(cmd *cobra.Command, args []string) error { for _, inst := range instances { sshOpts, err = sshutil.SSHOpts("ssh", inst.Dir, *inst.Config.User.Name, false, false, false, false) if err != nil { - return err + return nil, err } } } else { // Copying among multiple hosts; we can't pass in host-specific options. sshOpts, err = sshutil.CommonOpts("ssh", false) if err != nil { - return err + return nil, err } } sshArgs := sshutil.SSHArgsFromOpts(sshOpts) - sshCmd := exec.Command(arg0, append(sshArgs, scpArgs...)...) - sshCmd.Stdin = cmd.InOrStdin() - sshCmd.Stdout = cmd.OutOrStdout() - sshCmd.Stderr = cmd.ErrOrStderr() - logrus.Debugf("executing scp (may take a long time): %+v", sshCmd.Args) + return exec.Command(command, append(sshArgs, scpArgs...)...), nil +} - // TODO: use syscall.Exec directly (results in losing tty?) - return sshCmd.Run() +func rsyncCommand(command string, copyPaths []*copyPath, verbose, recursive bool) (*exec.Cmd, error) { + rsyncFlags := []string{"-a"} + + if verbose { + rsyncFlags = append(rsyncFlags, "-v", "--progress") + } else { + rsyncFlags = append(rsyncFlags, "-q") + } + + if recursive { + rsyncFlags = append(rsyncFlags, "-r") + } + + rsyncArgs := make([]string, 0, len(rsyncFlags)+len(copyPaths)) + rsyncArgs = append(rsyncArgs, rsyncFlags...) + + var sshCmd string + var remoteInstance *store.Instance + + for _, cp := range copyPaths { + if cp.isRemote { + if remoteInstance == nil { + remoteInstance = cp.instance + sshOpts, err := sshutil.SSHOpts("ssh", cp.instance.Dir, *cp.instance.Config.User.Name, false, false, false, false) + if err != nil { + return nil, err + } + + sshArgs := sshutil.SSHArgsFromOpts(sshOpts) + sshCmd = fmt.Sprintf("ssh -p %d %s", cp.instance.SSHLocalPort, strings.Join(sshArgs, " ")) + } + } + } + + if sshCmd != "" { + rsyncArgs = append(rsyncArgs, "-e", sshCmd) + } + + for _, cp := range copyPaths { + if cp.isRemote { + rsyncArgs = append(rsyncArgs, fmt.Sprintf("%s@127.0.0.1:%s", *cp.instance.Config.User.Name, cp.path)) + } else { + rsyncArgs = append(rsyncArgs, cp.path) + } + } + + return exec.Command(command, rsyncArgs...), nil } diff --git a/pkg/cidata/cidata.TEMPLATE.d/user-data b/pkg/cidata/cidata.TEMPLATE.d/user-data index e2e13045396..b2c50265b09 100644 --- a/pkg/cidata/cidata.TEMPLATE.d/user-data +++ b/pkg/cidata/cidata.TEMPLATE.d/user-data @@ -11,6 +11,9 @@ package_upgrade: true package_reboot_if_required: true {{- end }} +packages: + - rsync + {{- if or .RosettaEnabled (and .Mounts (or (eq .MountType "9p") (eq .MountType "virtiofs"))) }} mounts: {{- if .RosettaEnabled }}{{/* Mount the rosetta volume before systemd-binfmt.service(8) starts */}} diff --git a/pkg/hostagent/hostagent.go b/pkg/hostagent/hostagent.go index 98d0922a0bd..e643f80f71d 100644 --- a/pkg/hostagent/hostagent.go +++ b/pkg/hostagent/hostagent.go @@ -433,6 +433,7 @@ func (a *HostAgent) startHostAgentRoutines(ctx context.Context) error { if err := a.waitForRequirements("essential", a.essentialRequirements()); err != nil { errs = append(errs, err) } + if *a.instConfig.SSH.ForwardAgent { faScript := `#!/bin/bash set -eux -o pipefail