//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Foundation

/// A dictionary in which keys are case insensitive.
///
/// Two views of the same entries are maintained:
/// - `originalDictionary` keeps every key exactly as it was supplied, for
///   callers that need the case-sensitive spelling.
/// - a private index keyed by `key.lowercased()` backs the subscript and
///   `count`.
///
/// `DataFrameReader` and `DataFrameWriter` store their `extraOptions` in this
/// type and pass `toStringDictionary()` on as their data source / write options.
public struct CaseInsensitiveDictionary: Sendable {
  /// Entries exactly as supplied, original key spelling included.
  private var storage: [String: Sendable]

  /// Lookup index holding one entry per distinct `key.lowercased()`.
  private var keyLowerCasedDictionary: [String: Sendable]

  /// The entries with their original, case-sensitive keys.
  ///
  /// May contain several keys that differ only by case when the dictionary was
  /// created from (or assigned) such input. Assigning to this property rebuilds
  /// the case-insensitive index, so lookups never observe stale data.
  public var originalDictionary: [String: Sendable] {
    get { storage }
    set {
      storage = newValue
      keyLowerCasedDictionary = Self.lowerCasedIndex(of: newValue)
    }
  }

  /// Creates a dictionary from `originalDictionary`.
  ///
  /// - Parameter originalDictionary: The initial entries. Keys that differ only
  ///   by case are all retained in `originalDictionary`, but they count as a
  ///   single entry for lookups; see `lowerCasedIndex(of:)` for which one wins.
  init(_ originalDictionary: [String: Sendable] = [:]) {
    self.storage = originalDictionary
    self.keyLowerCasedDictionary = Self.lowerCasedIndex(of: originalDictionary)
  }

  /// Builds the lower-cased lookup index for `dictionary`.
  ///
  /// Keys are visited in ascending order so that, when several keys collapse to
  /// the same lower-cased form, the key that sorts last always wins. Iterating
  /// the dictionary directly would make the winner depend on `Dictionary`'s
  /// unspecified iteration order.
  private static func lowerCasedIndex(of dictionary: [String: Sendable]) -> [String: Sendable] {
    var index: [String: Sendable] = [:]
    index.reserveCapacity(dictionary.count)
    for (key, value) in dictionary.sorted(by: { $0.key < $1.key }) {
      index[key.lowercased()] = value
    }
    return index
  }

  /// Reads or writes the value for `key`, ignoring letter case.
  ///
  /// - Getting returns the value stored under any spelling of `key`, or `nil`.
  /// - Setting a non-`nil` value replaces every existing spelling of `key` in
  ///   `originalDictionary` with the spelling used in this call.
  /// - Setting `nil` removes every spelling of `key`.
  subscript(key: String) -> Sendable? {
    get {
      keyLowerCasedDictionary[key.lowercased()]
    }
    set {
      let lowerKey = key.lowercased()
      // The index has an entry exactly when some spelling of the key is stored,
      // so the linear scan is only needed in that case.
      if keyLowerCasedDictionary[lowerKey] != nil {
        storage = storage.filter { $0.key.lowercased() != lowerKey }
      }
      // Assigning `nil` removes the index entry.
      keyLowerCasedDictionary[lowerKey] = newValue
      if let newValue {
        storage[key] = newValue
      }
    }
  }

  /// Returns the entries with their original, case-sensitive keys.
  public func toDictionary() -> [String: Sendable] {
    return storage
  }

  /// Returns the original-key entries with each value rendered through
  /// `String(describing:)`.
  public func toStringDictionary() -> [String: String] {
    return storage.mapValues { String(describing: $0) }
  }

  /// The number of distinct keys when compared case-insensitively.
  public var count: Int {
    return keyLowerCasedDictionary.count
  }
}
partitioningColumns: [String]? = nil @@ -146,7 +145,7 @@ public actor DataFrameWriter: Sendable { write.bucketBy = bucketBy } - for option in self.extraOptions { + for option in self.extraOptions.toStringDictionary() { write.options[option.key] = option.value } diff --git a/Tests/SparkConnectTests/CaseInsensitiveDictionaryTests.swift b/Tests/SparkConnectTests/CaseInsensitiveDictionaryTests.swift new file mode 100644 index 0000000..481dcc3 --- /dev/null +++ b/Tests/SparkConnectTests/CaseInsensitiveDictionaryTests.swift @@ -0,0 +1,87 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
//

import Foundation
import Testing

@testable import SparkConnect

/// A test suite for `CaseInsensitiveDictionary`.
///
/// Values are compared through `as?` rather than `as!`, so a value of an
/// unexpected type is reported as a failed expectation instead of trapping
/// the test process.
struct CaseInsensitiveDictionaryTests {
  /// An empty input produces an empty dictionary.
  @Test
  func empty() async throws {
    let dict = CaseInsensitiveDictionary([:])
    #expect(dict.count == 0)
  }

  /// Keys differing only by case count once, but both stay in `originalDictionary`.
  @Test
  func originalDictionary() async throws {
    let dict = CaseInsensitiveDictionary([
      "key1": "value1",
      "KEY1": "VALUE1",
    ])
    #expect(dict.count == 1)
    #expect(dict.originalDictionary.count == 2)
  }

  /// `toDictionary()` exposes the case-sensitive entries.
  @Test
  func toDictionary() async throws {
    let dict = CaseInsensitiveDictionary([
      "key1": "value1",
      "KEY1": "VALUE1",
    ])
    #expect(dict.toDictionary().count == 2)
  }

  /// Values written under one spelling are readable under any spelling.
  @Test
  func `subscript`() async throws {
    var dict = CaseInsensitiveDictionary([:])
    #expect(dict.count == 0)

    dict["KEY1"] = "value1"
    #expect(dict.count == 1)
    #expect(dict["key1"] as? String == "value1")
    #expect(dict["KEY1"] as? String == "value1")
    #expect(dict["KeY1"] as? String == "value1")

    dict["key2"] = false
    #expect(dict.count == 2)
    #expect(dict["kEy2"] as? Bool == false)

    dict["key3"] = 2025
    #expect(dict.count == 3)
    #expect(dict["key3"] as? Int == 2025)
  }

  /// Writing through the subscript collapses all spellings of the key into one.
  @Test
  func updatedOriginalDictionary() async throws {
    var dict = CaseInsensitiveDictionary([
      "key1": "value1",
      "KEY1": "VALUE1",
    ])
    #expect(dict.count == 1)
    #expect(dict.originalDictionary.count == 2)

    dict["KEY1"] = "Swift"
    #expect(dict["KEY1"] as? String == "Swift")
    #expect(dict.count == 1)
    #expect(dict.originalDictionary.count == 1)
    #expect(dict.toDictionary().count == 1)
  }

  /// Assigning `nil` removes every spelling of the key.
  @Test
  func removeValue() async throws {
    var dict = CaseInsensitiveDictionary([
      "key1": "value1",
      "KEY1": "VALUE1",
    ])

    dict["Key1"] = nil
    #expect(dict.count == 0)
    #expect(dict["key1"] == nil)
    #expect(dict.originalDictionary.isEmpty)
  }

  /// `toStringDictionary()` renders every value as a string under its original key.
  @Test
  func toStringDictionary() async throws {
    let dict = CaseInsensitiveDictionary([
      "Name": "spark",
      "flag": false,
      "year": 2025,
    ])
    #expect(dict.toStringDictionary() == ["Name": "spark", "flag": "false", "year": "2025"])
  }
}