Skip to content

Commit 5b560fd

Browse files
committed
add rsync flag option to copy files using rsync
Signed-off-by: olalekan odukoya <[email protected]>
1 parent 93ef4ce commit 5b560fd

File tree

3 files changed

+218
-40
lines changed

3 files changed

+218
-40
lines changed

cmd/limactl/copy.go

Lines changed: 214 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@ Example: limactl copy default:/etc/os-release .
3030
Not to be confused with 'limactl clone'.
3131
`
3232

33+
type copyTool string
34+
35+
const (
36+
rsync copyTool = "rsync"
37+
scp copyTool = "scp"
38+
)
39+
40+
type copyPath struct {
41+
instanceName string
42+
path string
43+
isRemote bool
44+
instance *store.Instance
45+
}
46+
3347
func newCopyCommand() *cobra.Command {
3448
copyCommand := &cobra.Command{
3549
Use: "copy SOURCE ... TARGET",
@@ -58,13 +72,6 @@ func copyAction(cmd *cobra.Command, args []string) error {
5872
return err
5973
}
6074

61-
arg0, err := exec.LookPath("scp")
62-
if err != nil {
63-
return err
64-
}
65-
instances := make(map[string]*store.Instance)
66-
scpFlags := []string{}
67-
scpArgs := []string{}
6875
debug, err := cmd.Flags().GetBool("debug")
6976
if err != nil {
7077
return err
@@ -74,87 +81,254 @@ func copyAction(cmd *cobra.Command, args []string) error {
7481
verbose = true
7582
}
7683

77-
if verbose {
78-
scpFlags = append(scpFlags, "-v")
79-
} else {
80-
scpFlags = append(scpFlags, "-q")
84+
copyPaths, err := parseArgs(args)
85+
if err != nil {
86+
return err
8187
}
8288

83-
if recursive {
84-
scpFlags = append(scpFlags, "-r")
89+
cpTool, toolPath, err := selectCopyTool(copyPaths)
90+
if err != nil {
91+
return err
8592
}
86-
// this assumes that ssh and scp come from the same place, but scp has no -V
87-
legacySSH := sshutil.DetectOpenSSHVersion("ssh").LessThan(*semver.New("8.0.0"))
93+
94+
logrus.Infof("using copy tool %q", toolPath)
95+
96+
var copyCmd *exec.Cmd
97+
switch cpTool {
98+
case scp:
99+
copyCmd, err = scpCommand(toolPath, copyPaths, verbose, recursive)
100+
case rsync:
101+
copyCmd, err = rsyncCommand(toolPath, copyPaths, verbose, recursive)
102+
default:
103+
err = fmt.Errorf("invalid copy tool %q", cpTool)
104+
}
105+
if err != nil {
106+
return err
107+
}
108+
109+
copyCmd.Stdin = cmd.InOrStdin()
110+
copyCmd.Stdout = cmd.OutOrStdout()
111+
copyCmd.Stderr = cmd.ErrOrStderr()
112+
logrus.Debugf("executing %v (may take a long time)", copyCmd)
113+
114+
// TODO: use syscall.Exec directly (results in losing tty?)
115+
return copyCmd.Run()
116+
}
117+
118+
func parseArgs(args []string) ([]*copyPath, error) {
119+
var copyPaths []*copyPath
120+
88121
for _, arg := range args {
122+
cp := &copyPath{}
123+
89124
if runtime.GOOS == "windows" {
90125
if filepath.IsAbs(arg) {
126+
var err error
91127
arg, err = ioutilx.WindowsSubsystemPath(arg)
92128
if err != nil {
93-
return err
129+
return nil, err
94130
}
95131
} else {
96132
arg = filepath.ToSlash(arg)
97133
}
98134
}
99-
path := strings.Split(arg, ":")
100-
switch len(path) {
135+
136+
parts := strings.SplitN(arg, ":", 2)
137+
switch len(parts) {
101138
case 1:
102-
scpArgs = append(scpArgs, arg)
139+
cp.path = arg
140+
cp.isRemote = false
103141
case 2:
104-
instName := path[0]
105-
inst, err := store.Inspect(instName)
142+
cp.instanceName = parts[0]
143+
cp.path = parts[1]
144+
cp.isRemote = true
145+
146+
inst, err := store.Inspect(cp.instanceName)
106147
if err != nil {
107148
if errors.Is(err, os.ErrNotExist) {
108-
return fmt.Errorf("instance %q does not exist, run `limactl create %s` to create a new instance", instName, instName)
149+
return nil, fmt.Errorf("instance %q does not exist, run `limactl create %s` to create a new instance", cp.instanceName, cp.instanceName)
109150
}
110-
return err
151+
return nil, err
111152
}
112153
if inst.Status == store.StatusStopped {
113-
return fmt.Errorf("instance %q is stopped, run `limactl start %s` to start the instance", instName, instName)
154+
return nil, fmt.Errorf("instance %q is stopped, run `limactl start %s` to start the instance", cp.instanceName, cp.instanceName)
114155
}
156+
cp.instance = inst
157+
default:
158+
return nil, fmt.Errorf("path %q contains multiple colons", arg)
159+
}
160+
161+
copyPaths = append(copyPaths, cp)
162+
}
163+
164+
return copyPaths, nil
165+
}
166+
167+
func selectCopyTool(copyPaths []*copyPath) (copyTool, string, error) {
168+
if rsyncPath, err := exec.LookPath("rsync"); err == nil {
169+
if rsyncAvailableOnGuests(copyPaths) {
170+
return rsync, rsyncPath, nil
171+
}
172+
logrus.Debugf("rsync not available on guest(s), falling back to scp")
173+
} else {
174+
logrus.Debugf("rsync not found on host, falling back to scp: %v", err)
175+
}
176+
177+
scpPath, err := exec.LookPath("scp")
178+
if err != nil {
179+
return "", "", fmt.Errorf("neither rsync nor scp found on host: %w", err)
180+
}
181+
182+
return scp, scpPath, nil
183+
}
184+
185+
func rsyncAvailableOnGuests(copyPaths []*copyPath) bool {
186+
instances := make(map[string]*store.Instance)
187+
188+
for _, cp := range copyPaths {
189+
if cp.isRemote {
190+
instances[cp.instanceName] = cp.instance
191+
}
192+
}
193+
194+
for instName, inst := range instances {
195+
if !checkRsyncOnGuest(inst) {
196+
logrus.Debugf("rsync not available on instance %q", instName)
197+
return false
198+
}
199+
}
200+
201+
return true
202+
}
203+
204+
func checkRsyncOnGuest(inst *store.Instance) bool {
205+
sshOpts, err := sshutil.SSHOpts("ssh", inst.Dir, *inst.Config.User.Name, false, false, false, false)
206+
if err != nil {
207+
logrus.Debugf("failed to get SSH options for rsync check: %v", err)
208+
return false
209+
}
210+
211+
sshArgs := sshutil.SSHArgsFromOpts(sshOpts)
212+
checkCmd := exec.Command("ssh")
213+
checkCmd.Args = append(checkCmd.Args, sshArgs...)
214+
checkCmd.Args = append(checkCmd.Args,
215+
"-p", fmt.Sprintf("%d", inst.SSHLocalPort),
216+
fmt.Sprintf("%[email protected]", *inst.Config.User.Name),
217+
"command -v rsync >/dev/null 2>&1",
218+
)
219+
220+
err = checkCmd.Run()
221+
return err == nil
222+
}
223+
224+
func scpCommand(command string, copyPaths []*copyPath, verbose, recursive bool) (*exec.Cmd, error) {
225+
instances := make(map[string]*store.Instance)
226+
scpFlags := []string{}
227+
scpArgs := []string{}
228+
229+
if verbose {
230+
scpFlags = append(scpFlags, "-v")
231+
} else {
232+
scpFlags = append(scpFlags, "-q")
233+
}
234+
235+
if recursive {
236+
scpFlags = append(scpFlags, "-r")
237+
}
238+
239+
// this assumes that ssh and scp come from the same place, but scp has no -V
240+
legacySSH := sshutil.DetectOpenSSHVersion("ssh").LessThan(*semver.New("8.0.0"))
241+
242+
for _, cp := range copyPaths {
243+
if cp.isRemote {
115244
if legacySSH {
116-
scpFlags = append(scpFlags, "-P", fmt.Sprintf("%d", inst.SSHLocalPort))
117-
scpArgs = append(scpArgs, fmt.Sprintf("%[email protected]:%s", *inst.Config.User.Name, path[1]))
245+
scpFlags = append(scpFlags, "-P", fmt.Sprintf("%d", cp.instance.SSHLocalPort))
246+
scpArgs = append(scpArgs, fmt.Sprintf("%[email protected]:%s", *cp.instance.Config.User.Name, cp.path))
118247
} else {
119-
scpArgs = append(scpArgs, fmt.Sprintf("scp://%[email protected]:%d/%s", *inst.Config.User.Name, inst.SSHLocalPort, path[1]))
248+
scpArgs = append(scpArgs, fmt.Sprintf("scp://%[email protected]:%d/%s", *cp.instance.Config.User.Name, cp.instance.SSHLocalPort, cp.path))
120249
}
121-
instances[instName] = inst
122-
default:
123-
return fmt.Errorf("path %q contains multiple colons", arg)
250+
instances[cp.instanceName] = cp.instance
251+
} else {
252+
scpArgs = append(scpArgs, cp.path)
124253
}
125254
}
255+
126256
if legacySSH && len(instances) > 1 {
127-
return errors.New("more than one (instance) host is involved in this command, this is only supported for openSSH v8.0 or higher")
257+
return nil, errors.New("more than one (instance) host is involved in this command, this is only supported for openSSH v8.0 or higher")
128258
}
259+
129260
scpFlags = append(scpFlags, "-3", "--")
130261
scpArgs = append(scpFlags, scpArgs...)
131262

132263
var sshOpts []string
264+
var err error
133265
if len(instances) == 1 {
134266
// Only one (instance) host is involved; we can use the instance-specific
135267
// arguments such as ControlPath. This is preferred as we can multiplex
136268
// sessions without re-authenticating (MaxSessions permitting).
137269
for _, inst := range instances {
138270
sshOpts, err = sshutil.SSHOpts("ssh", inst.Dir, *inst.Config.User.Name, false, false, false, false)
139271
if err != nil {
140-
return err
272+
return nil, err
141273
}
142274
}
143275
} else {
144276
// Copying among multiple hosts; we can't pass in host-specific options.
145277
sshOpts, err = sshutil.CommonOpts("ssh", false)
146278
if err != nil {
147-
return err
279+
return nil, err
148280
}
149281
}
150282
sshArgs := sshutil.SSHArgsFromOpts(sshOpts)
151283

152-
sshCmd := exec.Command(arg0, append(sshArgs, scpArgs...)...)
153-
sshCmd.Stdin = cmd.InOrStdin()
154-
sshCmd.Stdout = cmd.OutOrStdout()
155-
sshCmd.Stderr = cmd.ErrOrStderr()
156-
logrus.Debugf("executing scp (may take a long time): %+v", sshCmd.Args)
284+
return exec.Command(command, append(sshArgs, scpArgs...)...), nil
285+
}
157286

158-
// TODO: use syscall.Exec directly (results in losing tty?)
159-
return sshCmd.Run()
287+
func rsyncCommand(command string, copyPaths []*copyPath, verbose, recursive bool) (*exec.Cmd, error) {
288+
rsyncFlags := []string{"-a"}
289+
290+
if verbose {
291+
rsyncFlags = append(rsyncFlags, "-v", "--progress")
292+
} else {
293+
rsyncFlags = append(rsyncFlags, "-q")
294+
}
295+
296+
if recursive {
297+
rsyncFlags = append(rsyncFlags, "-r")
298+
}
299+
300+
rsyncArgs := make([]string, 0, len(rsyncFlags)+len(copyPaths))
301+
rsyncArgs = append(rsyncArgs, rsyncFlags...)
302+
303+
var sshCmd string
304+
var remoteInstance *store.Instance
305+
306+
for _, cp := range copyPaths {
307+
if cp.isRemote {
308+
if remoteInstance == nil {
309+
remoteInstance = cp.instance
310+
sshOpts, err := sshutil.SSHOpts("ssh", cp.instance.Dir, *cp.instance.Config.User.Name, false, false, false, false)
311+
if err != nil {
312+
return nil, err
313+
}
314+
315+
sshArgs := sshutil.SSHArgsFromOpts(sshOpts)
316+
sshCmd = fmt.Sprintf("ssh -p %d %s", cp.instance.SSHLocalPort, strings.Join(sshArgs, " "))
317+
}
318+
}
319+
}
320+
321+
if sshCmd != "" {
322+
rsyncArgs = append(rsyncArgs, "-e", sshCmd)
323+
}
324+
325+
for _, cp := range copyPaths {
326+
if cp.isRemote {
327+
rsyncArgs = append(rsyncArgs, fmt.Sprintf("%[email protected]:%s", *cp.instance.Config.User.Name, cp.path))
328+
} else {
329+
rsyncArgs = append(rsyncArgs, cp.path)
330+
}
331+
}
332+
333+
return exec.Command(command, rsyncArgs...), nil
160334
}

pkg/cidata/cidata.TEMPLATE.d/user-data

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ package_upgrade: true
1111
package_reboot_if_required: true
1212
{{- end }}
1313

14+
packages:
15+
- rsync
16+
1417
{{- if or .RosettaEnabled (and .Mounts (or (eq .MountType "9p") (eq .MountType "virtiofs"))) }}
1518
mounts:
1619
{{- if .RosettaEnabled }}{{/* Mount the rosetta volume before systemd-binfmt.service(8) starts */}}

pkg/hostagent/hostagent.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ func (a *HostAgent) startHostAgentRoutines(ctx context.Context) error {
433433
if err := a.waitForRequirements("essential", a.essentialRequirements()); err != nil {
434434
errs = append(errs, err)
435435
}
436+
436437
if *a.instConfig.SSH.ForwardAgent {
437438
faScript := `#!/bin/bash
438439
set -eux -o pipefail

0 commit comments

Comments
 (0)