Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions Sources/tart/Commands/Clone.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ struct Clone: AsyncParsableCommand {
@Flag(help: .hidden)
var deduplicate: Bool = false

@Option(help: .hidden)
var proxy: String?

@Option(help: .hidden)
var caCert: String?

@Option(help: .hidden)
var maxRetries: UInt = 5

func validate() throws {
if newName.contains("/") {
throw ValidationError("<new-name> should be a local name")
Expand All @@ -47,8 +56,8 @@ struct Clone: AsyncParsableCommand {

if let remoteName = try? RemoteName(sourceName), !ociStorage.exists(remoteName) {
// Pull the VM in case it's OCI-based and doesn't exist locally yet
let registry = try Registry(host: remoteName.host, namespace: remoteName.namespace, insecure: insecure)
try await ociStorage.pull(remoteName, registry: registry, concurrency: concurrency, deduplicate: deduplicate)
let registry = try Registry(host: remoteName.host, namespace: remoteName.namespace, insecure: insecure, proxy: proxy, caCert: caCert)
try await ociStorage.pull(remoteName, registry: registry, concurrency: concurrency, deduplicate: deduplicate, maxRetries: maxRetries)
}

let sourceVM = try VMStorageHelper.open(sourceName)
Expand Down
13 changes: 11 additions & 2 deletions Sources/tart/Commands/Pull.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ struct Pull: AsyncParsableCommand {
@Flag(help: .hidden)
var deduplicate: Bool = false

@Option(help: .hidden)
var proxy: String?

@Option(help: .hidden)
var caCert: String?

@Option(help: .hidden)
var maxRetries: UInt = 5

func validate() throws {
if concurrency < 1 {
throw ValidationError("network concurrency cannot be less than 1")
Expand All @@ -42,10 +51,10 @@ struct Pull: AsyncParsableCommand {
}

let remoteName = try RemoteName(remoteName)
let registry = try Registry(host: remoteName.host, namespace: remoteName.namespace, insecure: insecure)
let registry = try Registry(host: remoteName.host, namespace: remoteName.namespace, insecure: insecure, proxy: proxy, caCert: caCert)

defaultLogger.appendNewLine("pulling \(remoteName)...")

try await VMStorageOCI().pull(remoteName, registry: registry, concurrency: concurrency, deduplicate: deduplicate)
try await VMStorageOCI().pull(remoteName, registry: registry, concurrency: concurrency, deduplicate: deduplicate, maxRetries: maxRetries)
}
}
138 changes: 118 additions & 20 deletions Sources/tart/Fetcher.swift
Original file line number Diff line number Diff line change
@@ -1,26 +1,61 @@
import Foundation

fileprivate var urlSession: URLSession = {
let config = URLSessionConfiguration.default

// Harbor expects a CSRF token to be present if the HTTP client
// carries a session cookie between its requests[1] and fails if
// it was not present[2].
//
// To fix that, we disable the automatic cookies carry in URLSession.
//
// [1]: https://github.com/goharbor/harbor/blob/a4c577f9ec4f18396207a5e686433a6ba203d4ef/src/server/middleware/csrf/csrf.go#L78
// [2]: https://github.com/cirruslabs/tart/issues/295
config.httpShouldSetCookies = false

return URLSession(configuration: config)
}()

class Fetcher {
static func fetch(_ request: URLRequest, viaFile: Bool = false) async throws -> (AsyncThrowingStream<Data, Error>, HTTPURLResponse) {
let task = urlSession.dataTask(with: request)
let urlSession: URLSession
let caCert: SecCertificate?

init(proxy: String? = nil, caCert: String? = nil) throws {
// Configure URLSession
let config = URLSessionConfiguration.default

// Harbor expects a CSRF token to be present if the HTTP client
// carries a session cookie between its requests[1] and fails if
// it was not present[2].
//
// To fix that, we disable the automatic cookies carry in URLSession.
//
// [1]: https://github.com/goharbor/harbor/blob/a4c577f9ec4f18396207a5e686433a6ba203d4ef/src/server/middleware/csrf/csrf.go#L78
// [2]: https://github.com/cirruslabs/tart/issues/295
config.httpShouldSetCookies = false

if let proxy {
let (host, port) = try Self.parseProxy(proxy)

config.connectionProxyDictionary = [
kCFNetworkProxiesHTTPEnable: true,
kCFNetworkProxiesHTTPProxy: host,
kCFNetworkProxiesHTTPPort: port,

kCFNetworkProxiesHTTPSEnable: true,
kCFNetworkProxiesHTTPSProxy: host,
kCFNetworkProxiesHTTPSPort: port,
]
}

self.urlSession = URLSession(configuration: config)

// Load CA certificate, if any
if let caCert {
let caCertString = try String(contentsOf: URL(filePath: caCert), encoding:. utf8)

let caCertBase64Lines = caCertString.components(separatedBy: .newlines).filter { line in
!line.hasPrefix("-----BEGIN") && !line.hasPrefix("-----END")
}

let delegate = Delegate()
guard let caCertData = Data(base64Encoded: caCertBase64Lines.joined()) else {
throw RuntimeError.FailedToLoadCACertificate("failed to parse Base64-encoded PEM data")
}

self.caCert = SecCertificateCreateWithData(nil, caCertData as CFData)!
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it an additional one? It seems it's possible to pass a certificate without proxy and then I wonder if regular requests will work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not, but can be made so by calling SecTrustSetAnchorCertificatesOnly().

} else {
self.caCert = nil
}
}

func fetch(_ request: URLRequest, viaFile: Bool = false) async throws -> (AsyncThrowingStream<Data, Error>, HTTPURLResponse) {
let task = self.urlSession.dataTask(with: request)

let delegate = Delegate(caCert: self.caCert)
task.delegate = delegate

let stream = AsyncThrowingStream<Data, Error> { continuation in
Expand All @@ -34,15 +69,78 @@ class Fetcher {

return (stream, response as! HTTPURLResponse)
}

private static func parseProxy(_ proxy: String) throws -> (String, Int) {
// Assume that the scheme is specified
var url = URL(string: proxy)

// Fall back to HTTP scheme when not specified
if url?.scheme == nil {
url = URL(string: "http://\(proxy)")
}

guard let url else {
throw RuntimeError.InvalidProxyString
}

guard let host = url.host() else {
throw RuntimeError.InvalidProxyString
}

guard let port = url.port else {
throw RuntimeError.InvalidProxyString
}

return (host, port)
}
}

fileprivate class Delegate: NSObject, URLSessionDataDelegate {
fileprivate class Delegate: NSObject, URLSessionDelegate, URLSessionDataDelegate {
let caCert: SecCertificate?
var responseContinuation: CheckedContinuation<URLResponse, Error>?
var streamContinuation: AsyncThrowingStream<Data, Error>.Continuation?

private var buffer: Data = Data()
private let bufferFlushSize = 16 * 1024 * 1024

init(caCert: SecCertificate?) {
self.caCert = caCert
}

func urlSession(
_ session: URLSession,
didReceive challenge: URLAuthenticationChallenge,
completionHandler: @escaping @Sendable (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
) {
if let caCert {
// Ensure that we're performing server trust authentication
guard challenge.protectionSpace.authenticationMethod == NSURLAuthenticationMethodServerTrust,
let serverTrust = challenge.protectionSpace.serverTrust else {
completionHandler(.performDefaultHandling, nil)

return
}

// Set the provided CA certificate as the only anchor
if SecTrustSetAnchorCertificates(serverTrust, [caCert] as CFArray) != errSecSuccess {
completionHandler(.cancelAuthenticationChallenge, nil)

return
}

// Evaluate the trust
if SecTrustEvaluateWithError(serverTrust, nil) {
completionHandler(.useCredential, URLCredential(trust: serverTrust))
} else {
completionHandler(.rejectProtectionSpace, nil)
}

return
}

completionHandler(.performDefaultHandling, nil)
}

func urlSession(
_ session: URLSession,
dataTask: URLSessionDataTask,
Expand Down
2 changes: 1 addition & 1 deletion Sources/tart/OCI/Layerizer/Disk.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ import Foundation

protocol Disk {
static func push(diskURL: URL, registry: Registry, chunkSizeMb: Int, concurrency: UInt, progress: Progress) async throws -> [OCIManifestLayer]
static func pull(registry: Registry, diskLayers: [OCIManifestLayer], diskURL: URL, concurrency: UInt, progress: Progress, localLayerCache: LocalLayerCache?, deduplicate: Bool) async throws
static func pull(registry: Registry, diskLayers: [OCIManifestLayer], diskURL: URL, concurrency: UInt, progress: Progress, localLayerCache: LocalLayerCache?, deduplicate: Bool, maxRetries: UInt) async throws
}
2 changes: 1 addition & 1 deletion Sources/tart/OCI/Layerizer/DiskV1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class DiskV1: Disk {
return pushedLayers
}

static func pull(registry: Registry, diskLayers: [OCIManifestLayer], diskURL: URL, concurrency: UInt, progress: Progress, localLayerCache: LocalLayerCache? = nil, deduplicate: Bool = false) async throws {
static func pull(registry: Registry, diskLayers: [OCIManifestLayer], diskURL: URL, concurrency: UInt, progress: Progress, localLayerCache: LocalLayerCache? = nil, deduplicate: Bool = false, maxRetries: UInt = 5) async throws {
if !FileManager.default.createFile(atPath: diskURL.path, contents: nil) {
throw OCIError.FailedToCreateVmFile
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/tart/OCI/Layerizer/DiskV2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class DiskV2: Disk {
}
}

static func pull(registry: Registry, diskLayers: [OCIManifestLayer], diskURL: URL, concurrency: UInt, progress: Progress, localLayerCache: LocalLayerCache? = nil, deduplicate: Bool = false) async throws {
static func pull(registry: Registry, diskLayers: [OCIManifestLayer], diskURL: URL, concurrency: UInt, progress: Progress, localLayerCache: LocalLayerCache? = nil, deduplicate: Bool = false, maxRetries: UInt = 5) async throws {
// Support resumable pulls
let pullResumed = FileManager.default.fileExists(atPath: diskURL.path)

Expand Down Expand Up @@ -210,7 +210,7 @@ class DiskV2: Disk {

var rangeStart: Int64 = 0

try await retry(maxAttempts: 5) {
try await retry(maxAttempts: Int(maxRetries)) {
try await registry.pullBlob(diskLayer.digest, rangeStart: rangeStart) { data in
try filter.write(data)

Expand Down
12 changes: 9 additions & 3 deletions Sources/tart/OCI/Registry.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Registry {
let namespace: String
let credentialsProviders: [CredentialsProvider]
let authenticationKeeper = AuthenticationKeeper()
let fetcher: Fetcher

var host: String? {
guard let host = baseURL.host else { return nil }
Expand All @@ -128,17 +129,22 @@ class Registry {

init(baseURL: URL,
namespace: String,
credentialsProviders: [CredentialsProvider] = [EnvironmentCredentialsProvider(), DockerConfigCredentialsProvider(), KeychainCredentialsProvider()]
credentialsProviders: [CredentialsProvider] = [EnvironmentCredentialsProvider(), DockerConfigCredentialsProvider(), KeychainCredentialsProvider()],
proxy: String? = nil,
caCert: String? = nil
) throws {
self.baseURL = baseURL
self.namespace = namespace
self.credentialsProviders = credentialsProviders
self.fetcher = try Fetcher(proxy: proxy, caCert: caCert)
}

convenience init(
host: String,
namespace: String,
insecure: Bool = false,
proxy: String? = nil,
caCert: String? = nil,
credentialsProviders: [CredentialsProvider] = [EnvironmentCredentialsProvider(), DockerConfigCredentialsProvider(), KeychainCredentialsProvider()]
) throws {
let proto = insecure ? "http" : "https"
Expand All @@ -154,7 +160,7 @@ class Registry {
throw RuntimeError.ImproperlyFormattedHost(host, hint)
}

try self.init(baseURL: baseURL, namespace: namespace, credentialsProviders: credentialsProviders)
try self.init(baseURL: baseURL, namespace: namespace, credentialsProviders: credentialsProviders, proxy: proxy, caCert: caCert)
}

func ping() async throws {
Expand Down Expand Up @@ -448,6 +454,6 @@ class Registry {
request.setValue("Tart/\(CI.version) (\(DeviceInfo.os); \(DeviceInfo.model))",
forHTTPHeaderField: "User-Agent")

return try await Fetcher.fetch(request, viaFile: viaFile)
return try await self.fetcher.fetch(request, viaFile: viaFile)
}
}
4 changes: 2 additions & 2 deletions Sources/tart/VM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class VM: NSObject, VZVirtualMachineDelegate, ObservableObject {
// Check if we already have this IPSW in cache
var headRequest = URLRequest(url: remoteURL)
headRequest.httpMethod = "HEAD"
let (_, headResponse) = try await Fetcher.fetch(headRequest, viaFile: false)
let (_, headResponse) = try await Fetcher().fetch(headRequest, viaFile: false)

if let hash = headResponse.value(forHTTPHeaderField: "x-amz-meta-digest-sha256") {
let ipswLocation = try IPSWCache().locationFor(fileName: "sha256:\(hash).ipsw")
Expand All @@ -100,7 +100,7 @@ class VM: NSObject, VZVirtualMachineDelegate, ObservableObject {
defaultLogger.appendNewLine("Fetching \(remoteURL.lastPathComponent)...")

let request = URLRequest(url: remoteURL)
let (channel, response) = try await Fetcher.fetch(request, viaFile: true)
let (channel, response) = try await Fetcher().fetch(request, viaFile: true)

let temporaryLocation = try Config().tartTmpDir.appendingPathComponent(UUID().uuidString + ".ipsw")

Expand Down
12 changes: 10 additions & 2 deletions Sources/tart/VMDirectory+OCI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ enum OCIError: Error {
}

extension VMDirectory {
func pullFromRegistry(registry: Registry, manifest: OCIManifest, concurrency: UInt, localLayerCache: LocalLayerCache?, deduplicate: Bool) async throws {
func pullFromRegistry(
registry: Registry,
manifest: OCIManifest,
concurrency: UInt,
localLayerCache: LocalLayerCache?,
deduplicate: Bool,
maxRetries: UInt
) async throws {
// Pull VM's config file layer and re-serialize it into a config file
let configLayers = manifest.layers.filter {
$0.mediaType == configMediaType
Expand Down Expand Up @@ -55,7 +62,8 @@ extension VMDirectory {
try await diskImplType.pull(registry: registry, diskLayers: layers, diskURL: diskURL,
concurrency: concurrency, progress: progress,
localLayerCache: localLayerCache,
deduplicate: deduplicate)
deduplicate: deduplicate,
maxRetries: maxRetries)
} catch let error where error is FilterError {
throw RuntimeError.PullFailed("failed to decompress disk: \(error.localizedDescription)")
}
Expand Down
6 changes: 6 additions & 0 deletions Sources/tart/VMStorageHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ enum RuntimeError : Error {
case SuspendFailed(_ message: String)
case PullFailed(_ message: String)
case VirtualMachineLimitExceeded(_ hint: String)
case InvalidProxyString
case FailedToLoadCACertificate(_ message: String)
}

protocol HasExitCode {
Expand Down Expand Up @@ -136,6 +138,10 @@ extension RuntimeError : CustomStringConvertible {
return message
case .VirtualMachineLimitExceeded(let hint):
return "The number of VMs exceeds the system limit\(hint)"
case .InvalidProxyString:
return "Invalid proxy string, should be in the form of host:port"
case .FailedToLoadCACertificate(let message):
return "Failed to load CA certificate: \(message)"
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions Sources/tart/VMStorageOCI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class VMStorageOCI: PrunableStorage {
try list().filter { (_, _, isSymlink) in !isSymlink }.map { (_, vmDir, _) in vmDir }
}

func pull(_ name: RemoteName, registry: Registry, concurrency: UInt, deduplicate: Bool) async throws {
func pull(_ name: RemoteName, registry: Registry, concurrency: UInt, deduplicate: Bool, maxRetries: UInt) async throws {
SentrySDK.configureScope { scope in
scope.setContext(value: ["imageName": name.description], key: "OCI")
}
Expand Down Expand Up @@ -196,7 +196,7 @@ class VMStorageOCI: PrunableStorage {
}

try await withTaskCancellationHandler(operation: {
try await retry(maxAttempts: 5) {
try await retry(maxAttempts: Int(maxRetries)) {
// Choose the best base image which has the most deduplication ratio
let localLayerCache = try await chooseLocalLayerCache(name, manifest, registry)

Expand All @@ -210,7 +210,7 @@ class VMStorageOCI: PrunableStorage {
}
}

try await tmpVMDir.pullFromRegistry(registry: registry, manifest: manifest, concurrency: concurrency, localLayerCache: localLayerCache, deduplicate: deduplicate)
try await tmpVMDir.pullFromRegistry(registry: registry, manifest: manifest, concurrency: concurrency, localLayerCache: localLayerCache, deduplicate: deduplicate, maxRetries: maxRetries)
} recoverFromFailure: { error in
if error is URLError {
print("Error pulling image: \"\(error.localizedDescription)\", attempting to re-try...")
Expand Down