Skip to content
This repository was archived by the owner on Feb 24, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ public enum NetworkProtectionSubfeature: String, Equatable, PrivacySubfeature {
/// Enforce routes for the VPN to fix TunnelVision
/// https://app.asana.com/0/72649045549333/1208617860225199/f
case enforceRoutes

/// Risky Domain Protection for VPN
/// https://app.asana.com/0/1204186595873227/1206489252288889
case riskyDomainsProtection
}

public enum SyncSubfeature: String, PrivacySubfeature {
Expand Down
38 changes: 32 additions & 6 deletions Sources/NetworkProtection/NetworkProtectionDeviceManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public enum NetworkProtectionServerSelectionMethod: CustomDebugStringConvertible
}

public enum NetworkProtectionDNSSettings: Codable, Equatable, CustomStringConvertible {
case `default`
case ddg(blockRiskyDomains: Bool)
case custom([String])

public var usesCustomDNS: Bool {
Expand All @@ -55,7 +55,7 @@ public enum NetworkProtectionDNSSettings: Codable, Equatable, CustomStringConver

public var description: String {
switch self {
case .default: return "DuckDuckGo"
case .ddg: return "DuckDuckGo"
case .custom(let servers): return servers.joined(separator: ", ")
}
}
Expand Down Expand Up @@ -278,8 +278,12 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {

let dns: [DNSServer]
switch dnsSettings {
case .default:
dns = [DNSServer(address: server.serverInfo.internalIP.ipAddress)]
case .ddg(let blockRiskyDomains):
var ipAddress: IPAddress = server.serverInfo.internalIP.ipAddress
if blockRiskyDomains {
ipAddress = ipAddress.computeBlockRiskyDomainsDnsOrSame()
}
dns = [DNSServer(address: ipAddress)]
case .custom(let servers):
dns = servers
.compactMap { IPv4Address($0) }
Expand All @@ -290,8 +294,6 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
dnsServers: dns,
excludeLocalNetworks: excludeLocalNetworks)

Logger.networkProtection.log("Routing table information:\nL Included Routes: \(routingTableResolver.includedRoutes, privacy: .public)\nL Excluded Routes: \(routingTableResolver.excludedRoutes, privacy: .public)")

let interface = InterfaceConfiguration(privateKey: interfacePrivateKey,
addresses: [interfaceAddressRange],
includedRoutes: routingTableResolver.includedRoutes,
Expand Down Expand Up @@ -333,3 +335,27 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
}
}
}

extension IPAddress {
/// Returns a new IP address by left-shifting the last octet of the IPv4 address.
///
/// if the new address cannot be created, the original address is returned.
func computeBlockRiskyDomainsDnsOrSame() -> Self {
// Extracts the last byte
let data = self.rawValue
var bytes = [UInt8](data)
guard let lastOctet = bytes.last else { return self }

// Perform a left-shift on the last octet.
// We cast to UInt16 to avoid overflow, then mask with 0xFF to ensure the result fits in 8 bits.
let shiftedOctet = UInt8((UInt16(lastOctet) << 1) & 0xFF)

// Update the last element with the shifted value.
bytes[bytes.count - 1] = shiftedOctet

// Attempt to create a new IPAddress with the updated raw data, preserving the interface.
// If creation fails, return the original address.
let newData = Data(bytes)
return Self(newData, self.interface) ?? self
}
}
2 changes: 1 addition & 1 deletion Sources/NetworkProtection/PacketTunnelProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
case .useExisting:
break
case .reset:
settings.dnsSettings = .default
settings.dnsSettings = .ddg(blockRiskyDomains: settings.isBlockRiskyDomainsOn)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,30 @@ extension UserDefaults {
final class StorableDNSSettings: NSObject, Codable {
let usesCustomDNS: Bool
let dnsServers: [String]
let isBlockRiskyDomainsOn: Bool

init(usesCustomDNS: Bool = false, dnsServers: [String] = []) {
init(usesCustomDNS: Bool = false, dnsServers: [String] = [], isBlockRiskyDomainsOn: Bool = true) {
self.usesCustomDNS = usesCustomDNS
self.dnsServers = dnsServers
self.isBlockRiskyDomainsOn = isBlockRiskyDomainsOn
}

private enum CodingKeys: String, CodingKey {
case usesCustomDNS, dnsServers, isBlockRiskyDomainsOn
}

required init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.usesCustomDNS = try container.decodeIfPresent(Bool.self, forKey: .usesCustomDNS) ?? false
self.dnsServers = try container.decodeIfPresent([String].self, forKey: .dnsServers) ?? []
self.isBlockRiskyDomainsOn = try container.decodeIfPresent(Bool.self, forKey: .isBlockRiskyDomainsOn) ?? true
}

func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(usesCustomDNS, forKey: .usesCustomDNS)
try container.encode(dnsServers, forKey: .dnsServers)
try container.encode(isBlockRiskyDomainsOn, forKey: .isBlockRiskyDomainsOn)
}
}

Expand All @@ -35,7 +55,7 @@ extension UserDefaults {
}

private static func dnsSettingsFromStorageValue(_ value: StorableDNSSettings) -> NetworkProtectionDNSSettings {
guard value.usesCustomDNS, !value.dnsServers.isEmpty else { return .default }
guard value.usesCustomDNS, !value.dnsServers.isEmpty else { return .ddg(blockRiskyDomains: value.isBlockRiskyDomainsOn) }
return .custom(value.dnsServers)
}

Expand All @@ -53,21 +73,31 @@ extension UserDefaults {
}
}

var isBlockRiskyDomainsOn: Bool {
dnsSettingStorageValue.isBlockRiskyDomainsOn
}

var customDnsServers: [String] {
dnsSettingStorageValue.dnsServers
}

var dnsSettings: NetworkProtectionDNSSettings {
get {
Self.dnsSettingsFromStorageValue(dnsSettingStorageValue)
}

set {
switch newValue {
case .default:
dnsSettingStorageValue = StorableDNSSettings()
case .ddg(let isBlockRiskyDomainsOn):
let dnsServers = dnsSettingStorageValue.dnsServers
dnsSettingStorageValue = StorableDNSSettings(dnsServers: dnsServers, isBlockRiskyDomainsOn: isBlockRiskyDomainsOn)
case .custom(let dnsServers):
let hosts = dnsServers.compactMap(\.toIPv4Host)
let isBlockRiskyDomainsOn = dnsSettingStorageValue.isBlockRiskyDomainsOn
if hosts.isEmpty {
dnsSettingStorageValue = StorableDNSSettings()
dnsSettingStorageValue = StorableDNSSettings(isBlockRiskyDomainsOn: isBlockRiskyDomainsOn)
} else {
dnsSettingStorageValue = StorableDNSSettings(usesCustomDNS: true, dnsServers: hosts)
dnsSettingStorageValue = StorableDNSSettings(usesCustomDNS: true, dnsServers: hosts, isBlockRiskyDomainsOn: isBlockRiskyDomainsOn)
}
}
}
Expand All @@ -80,6 +110,6 @@ extension UserDefaults {
}

func resetDNSSettings() {
dnsSettings = .default
dnsSettings = .ddg(blockRiskyDomains: isBlockRiskyDomainsOn)
}
}
8 changes: 8 additions & 0 deletions Sources/NetworkProtection/Settings/VPNSettings.swift
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,14 @@ public final class VPNSettings {
defaults.dnsSettingsPublisher
}

public var isBlockRiskyDomainsOn: Bool {
defaults.isBlockRiskyDomainsOn
}

public var customDnsServers: [String] {
defaults.customDnsServers
}

public var dnsSettings: NetworkProtectionDNSSettings {
get {
defaults.dnsSettings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import Foundation
import XCTest
@testable import NetworkProtection
@testable import NetworkProtectionTestUtils
import Network

final class NetworkProtectionDeviceManagerTests: XCTestCase {
var tokenHandler: SubscriptionTokenHandlingMock!
Expand Down Expand Up @@ -204,16 +205,56 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase {
XCTAssertEqual(firstKey, secondKey) // Check that the key did NOT change, even though we tried to regenerate it
XCTAssertNotNil(networkClient.spyRegister)
}

func testDNSConfigurationWhenProtectionIsActive() async {
// GIVEN
let server = NetworkProtectionServer.mockRegisteredServer
networkClient.stubRegister = .success([server])
let protectionActive = true
let configuration: NetworkProtectionDeviceManager.GenerateTunnelConfigurationResult
let expectedIPAddress = IPv4Address("10.11.12.2")!

// WHEN
do {
configuration = try await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false, protectionActive: protectionActive)
} catch {
XCTFail("Unexpected error \(error.localizedDescription)")
return
}

// THEN
XCTAssertEqual(configuration.0.interface.dns.first?.address.rawValue, expectedIPAddress.rawValue)
}

func testDNSConfigurationWhenProtectionIsNotActive() async {
// GIVEN
let server = NetworkProtectionServer.mockRegisteredServer
networkClient.stubRegister = .success([server])
let protectionActive = false
let configuration: NetworkProtectionDeviceManager.GenerateTunnelConfigurationResult
let expectedIPAddress = IPv4Address("10.11.12.1")!

// WHEN
do {
configuration = try await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false, protectionActive: protectionActive)
} catch {
XCTFail("Unexpected error \(error.localizedDescription)")
return
}

// THEN
XCTAssertEqual(configuration.0.interface.dns.first?.address.rawValue, expectedIPAddress.rawValue)
}
}

extension NetworkProtectionDeviceManager {

func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
regenerateKey: Bool) async throws -> NetworkProtectionDeviceManager.GenerateTunnelConfigurationResult {
regenerateKey: Bool, protectionActive: Bool = false) async throws -> NetworkProtectionDeviceManager.GenerateTunnelConfigurationResult {
try await generateTunnelConfiguration(
resolvedSelectionMethod: selectionMethod,
excludeLocalNetworks: false,
dnsSettings: .default,
dnsSettings: .ddg(blockRiskyDomains: protectionActive),
regenerateKey: regenerateKey
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ final class FailureRecoveryHandlerTests: XCTestCase {
await failureRecoveryHandler.attemptRecovery(
to: server,
excludeLocalNetworks: expectedExcludeLocalNetworks,
dnsSettings: .default
dnsSettings: .ddg(blockRiskyDomains: false)
) {_ in }
guard let spyGenerateTunnelConfiguration = deviceManager.spyGenerateTunnelConfiguration else {
XCTFail("attemptRecovery not called")
Expand Down Expand Up @@ -122,7 +122,7 @@ final class FailureRecoveryHandlerTests: XCTestCase {
await failureRecoveryHandler.attemptRecovery(
to: .mockRegisteredServer,
excludeLocalNetworks: false,
dnsSettings: .default
dnsSettings: .ddg(blockRiskyDomains: false)
) {_ in }

XCTAssertEqual(startedCount, 1)
Expand Down Expand Up @@ -307,7 +307,7 @@ final class FailureRecoveryHandlerTests: XCTestCase {
await failureRecoveryHandler.attemptRecovery(
to: .mockRegisteredServer,
excludeLocalNetworks: false,
dnsSettings: .default
dnsSettings: .ddg(blockRiskyDomains: false)
) {_ in }
}

Expand All @@ -323,7 +323,7 @@ final class FailureRecoveryHandlerTests: XCTestCase {
await failureRecoveryHandler.attemptRecovery(
to: .mockRegisteredServer,
excludeLocalNetworks: false,
dnsSettings: .default
dnsSettings: .ddg(blockRiskyDomains: false)
) { _ in
let underlyingError = NSError(domain: "test", code: 1)
throw WireGuardAdapterError.startWireGuardBackend(underlyingError)
Expand All @@ -342,7 +342,7 @@ final class FailureRecoveryHandlerTests: XCTestCase {

var newConfigResult: NetworkProtectionDeviceManagement.GenerateTunnelConfigurationResult?

await failureRecoveryHandler.attemptRecovery(to: lastServer, excludeLocalNetworks: false, dnsSettings: .default) { configResult in
await failureRecoveryHandler.attemptRecovery(to: lastServer, excludeLocalNetworks: false, dnsSettings: .ddg(blockRiskyDomains: false)) { configResult in
newConfigResult = configResult
}
return newConfigResult
Expand Down
Loading