|
15 | 15 | package download
|
16 | 16 |
|
17 | 17 | import (
|
| 18 | + "archive/tar" |
| 19 | + "archive/zip" |
18 | 20 | "bytes"
|
| 21 | + "compress/gzip" |
19 | 22 | "io"
|
20 | 23 | "io/ioutil"
|
21 | 24 | "os"
|
@@ -452,3 +455,182 @@ func Test_extractArchive(t *testing.T) {
|
452 | 455 | })
|
453 | 456 | }
|
454 | 457 | }
|
| 458 | + |
| 459 | +func Test_suspiciousPath(t *testing.T) { |
| 460 | + tests := []struct { |
| 461 | + path string |
| 462 | + shouldErr bool |
| 463 | + }{ |
| 464 | + { |
| 465 | + path: `/foo`, |
| 466 | + shouldErr: true, |
| 467 | + }, |
| 468 | + { |
| 469 | + path: `\foo`, |
| 470 | + shouldErr: true, |
| 471 | + }, |
| 472 | + { |
| 473 | + path: `//foo`, |
| 474 | + shouldErr: true, |
| 475 | + }, |
| 476 | + { |
| 477 | + path: `/\foo`, |
| 478 | + shouldErr: true, |
| 479 | + }, |
| 480 | + { |
| 481 | + path: `\\foo`, |
| 482 | + shouldErr: true, |
| 483 | + }, |
| 484 | + { |
| 485 | + path: `./foo`, |
| 486 | + }, |
| 487 | + { |
| 488 | + path: `././foo`, |
| 489 | + }, |
| 490 | + { |
| 491 | + path: `.//foo`, |
| 492 | + }, |
| 493 | + { |
| 494 | + path: `../foo`, |
| 495 | + shouldErr: true, |
| 496 | + }, |
| 497 | + { |
| 498 | + path: `a/../foo`, |
| 499 | + shouldErr: true, |
| 500 | + }, |
| 501 | + { |
| 502 | + path: `a/././foo`, |
| 503 | + }, |
| 504 | + } |
| 505 | + |
| 506 | + for _, tt := range tests { |
| 507 | + t.Run(tt.path, func(t *testing.T) { |
| 508 | + err := suspiciousPath(tt.path) |
| 509 | + if tt.shouldErr && err == nil { |
| 510 | + t.Errorf("Expected suspiciousPath to fail") |
| 511 | + } |
| 512 | + if !tt.shouldErr && err != nil { |
| 513 | + t.Errorf("Expected suspiciousPath not to fail, got %s", err) |
| 514 | + } |
| 515 | + }) |
| 516 | + } |
| 517 | +} |
| 518 | + |
| 519 | +func Test_extractMaliciousArchive(t *testing.T) { |
| 520 | + const testContent = "some file content" |
| 521 | + |
| 522 | + tests := []struct { |
| 523 | + name string |
| 524 | + path string |
| 525 | + }{ |
| 526 | + { |
| 527 | + name: "absolute file", |
| 528 | + path: "/foo", |
| 529 | + }, |
| 530 | + { |
| 531 | + name: "contains ..", |
| 532 | + path: "a/../foo", |
| 533 | + }, |
| 534 | + } |
| 535 | + |
| 536 | + for _, tt := range tests { |
| 537 | + t.Run("tar.gz "+tt.name, func(t *testing.T) { |
| 538 | + tmpDir, cleanup := testutil.NewTempDir(t) |
| 539 | + defer cleanup() |
| 540 | + |
| 541 | + // do not use filepath.Join here, because it calls filepath.Clean on the result |
| 542 | + reader, err := tarGZArchiveForTesting(map[string]string{tt.path: testContent}) |
| 543 | + if err != nil { |
| 544 | + t.Fatal(err) |
| 545 | + } |
| 546 | + |
| 547 | + err = extractTARGZ(tmpDir.Root(), reader, reader.Size()) |
| 548 | + if err == nil { |
| 549 | + t.Errorf("Expected extractTARGZ to fail") |
| 550 | + } else if !strings.HasPrefix(err.Error(), "refusing to unpack archive") { |
| 551 | + t.Errorf("Found the wrong error: %s", err) |
| 552 | + } |
| 553 | + }) |
| 554 | + } |
| 555 | + |
| 556 | + for _, tt := range tests { |
| 557 | + t.Run("zip "+tt.name, func(t *testing.T) { |
| 558 | + tmpDir, cleanup := testutil.NewTempDir(t) |
| 559 | + defer cleanup() |
| 560 | + |
| 561 | + // do not use filepath.Join here, because it calls filepath.Clean on the result |
| 562 | + reader, err := zipArchiveReaderForTesting(map[string]string{tt.path: testContent}) |
| 563 | + if err != nil { |
| 564 | + t.Fatal(err) |
| 565 | + } |
| 566 | + |
| 567 | + err = extractZIP(tmpDir.Root(), reader, reader.Size()) |
| 568 | + if err == nil { |
| 569 | + t.Errorf("Expected extractZIP to fail") |
| 570 | + } else if !strings.HasPrefix(err.Error(), "refusing to unpack archive") { |
| 571 | + t.Errorf("Found the wrong error: %s", err) |
| 572 | + } |
| 573 | + }) |
| 574 | + } |
| 575 | +} |
| 576 | + |
| 577 | +// tarGZArchiveForTesting creates an in-memory zip archive with entries from |
| 578 | +// the files map, where keys are the paths and values are the contents. |
| 579 | +// For example, to create an empty file `a` and another file `b/c`: |
| 580 | +// tarGZArchiveForTesting(map[string]string{ |
| 581 | +// "a": "", |
| 582 | +// "b/c": "nested content", |
| 583 | +// }) |
| 584 | +func tarGZArchiveForTesting(files map[string]string) (*bytes.Reader, error) { |
| 585 | + archiveBuffer := &bytes.Buffer{} |
| 586 | + gzArchiveBuffer := gzip.NewWriter(archiveBuffer) |
| 587 | + tw := tar.NewWriter(gzArchiveBuffer) |
| 588 | + for path, content := range files { |
| 589 | + header := &tar.Header{ |
| 590 | + Name: path, |
| 591 | + Size: int64(len(content)), |
| 592 | + Mode: 0600, |
| 593 | + } |
| 594 | + if err := tw.WriteHeader(header); err != nil { |
| 595 | + return nil, err |
| 596 | + } |
| 597 | + if _, err := tw.Write([]byte(content)); err != nil { |
| 598 | + return nil, err |
| 599 | + } |
| 600 | + |
| 601 | + } |
| 602 | + if err := tw.Close(); err != nil { |
| 603 | + return nil, err |
| 604 | + } |
| 605 | + if err := gzArchiveBuffer.Close(); err != nil { |
| 606 | + return nil, err |
| 607 | + } |
| 608 | + return bytes.NewReader(archiveBuffer.Bytes()), nil |
| 609 | +} |
| 610 | + |
| 611 | +// zipArchiveReaderForTesting creates an in-memory zip archive with entries from |
| 612 | +// the files map, where keys are the paths and values are the contents. Note that |
| 613 | +// entries with empty content just create a directory. The zip spec requires that |
| 614 | +// parent directories are explicitly listed in the archive, so this must be done |
| 615 | +// for nested entries. For example, to create a file at `a/b/c`, you must pass: |
| 616 | +// map[string]string{"a": "", "a/b": "", "a/b/c": "nested content"} |
| 617 | +func zipArchiveReaderForTesting(files map[string]string) (*bytes.Reader, error) { |
| 618 | + archiveBuffer := &bytes.Buffer{} |
| 619 | + zw := zip.NewWriter(archiveBuffer) |
| 620 | + for path, content := range files { |
| 621 | + f, err := zw.Create(path) |
| 622 | + if err != nil { |
| 623 | + return nil, err |
| 624 | + } |
| 625 | + if content == "" { |
| 626 | + continue |
| 627 | + } |
| 628 | + if _, err := f.Write([]byte(content)); err != nil { |
| 629 | + return nil, err |
| 630 | + } |
| 631 | + } |
| 632 | + if err := zw.Close(); err != nil { |
| 633 | + return nil, err |
| 634 | + } |
| 635 | + return bytes.NewReader(archiveBuffer.Bytes()), nil |
| 636 | +} |
0 commit comments