Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Sources/ApplicationProtobuf/PirConversion.swift
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ extension Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters {
entrySizeInBytes: Int(entrySize),
dimensions: dimensions.map(Int.init),
batchSize: Int(batchSize),
evaluationKeyConfig: evaluationKeyConfig.native())
evaluationKeyConfig: evaluationKeyConfig.native(),
encodingEntrySize: encodingEntrySize)
}

/// Converts the protobuf object into a native type.
Expand Down
10 changes: 7 additions & 3 deletions Sources/ApplicationProtobuf/PirConversionApi.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
// Copyright 2024-2026 Apple Inc. and the Swift Homomorphic Encryption project authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,14 +50,18 @@ extension Apple_SwiftHomomorphicEncryption_Api_Pir_V1_PIRShardConfig {
/// - Parameters:
/// - batchSize: Number of queries in a batch.
/// - evaluationKeyConfig: Evaluation key configuration
/// - encodingEntrySize: Whether to encode the size.
/// - Returns: The converted native type.
public func native(batchSize: Int, evaluationKeyConfig: EvaluationKeyConfig) -> IndexPirParameter {
public func native(batchSize: Int, evaluationKeyConfig: EvaluationKeyConfig, encodingEntrySize: Bool)
-> IndexPirParameter
{
IndexPirParameter(
entryCount: Int(numEntries),
entrySizeInBytes: Int(entrySize),
dimensions: dimensions.map(Int.init),
batchSize: batchSize,
evaluationKeyConfig: evaluationKeyConfig)
evaluationKeyConfig: evaluationKeyConfig,
encodingEntrySize: encodingEntrySize)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// For information on using the generated types, please see the documentation:
// https://github.com/apple/swift-protobuf/

// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
// Copyright 2024-2026 Apple Inc. and the Swift Homomorphic Encryption project authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -196,6 +196,12 @@ public struct Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters: @unchecked
set {_uniqueStorage()._keyCompressionStrategy = newValue}
}

/// Whether to encode the entry size as part of the Index PIR response.
public var encodingEntrySize: Bool {
get {return _storage._encodingEntrySize}
set {_uniqueStorage()._encodingEntrySize = newValue}
}

public var unknownFields = SwiftProtobuf.UnknownStorage()

public init() {}
Expand Down Expand Up @@ -358,7 +364,7 @@ extension Apple_SwiftHomomorphicEncryption_Pir_V1_SymmetricPirConfigType: SwiftP

extension Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
public static let protoMessageName: String = _protobuf_package + ".PirParameters"
public static let _protobuf_nameMap = SwiftProtobuf._NameMap(bytecode: "\0\u{3}encryption_parameters\0\u{3}num_entries\0\u{3}entry_size\0\u{1}dimensions\0\u{3}keyword_pir_params\0\u{1}algorithm\0\u{3}batch_size\0\u{3}evaluation_key_config\0\u{3}key_compression_strategy\0\u{c}\u{a}\u{1}\u{c}\u{b}\u{1}\u{c}\u{c}\u{1}")
public static let _protobuf_nameMap = SwiftProtobuf._NameMap(bytecode: "\0\u{3}encryption_parameters\0\u{3}num_entries\0\u{3}entry_size\0\u{1}dimensions\0\u{3}keyword_pir_params\0\u{1}algorithm\0\u{3}batch_size\0\u{3}evaluation_key_config\0\u{3}key_compression_strategy\0\u{4}\u{4}encoding_entry_size\0\u{c}\u{a}\u{1}\u{c}\u{b}\u{1}\u{c}\u{c}\u{1}")

fileprivate class _StorageClass {
var _encryptionParameters: HomomorphicEncryptionProtobuf.Apple_SwiftHomomorphicEncryption_V1_EncryptionParameters? = nil
Expand All @@ -370,6 +376,7 @@ extension Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters: SwiftProtobuf.M
var _batchSize: UInt64 = 0
var _evaluationKeyConfig: HomomorphicEncryptionProtobuf.Apple_SwiftHomomorphicEncryption_V1_EvaluationKeyConfig? = nil
var _keyCompressionStrategy: Apple_SwiftHomomorphicEncryption_Pir_V1_KeyCompressionStrategy = .unspecified
var _encodingEntrySize: Bool = false

// This property is used as the initial default value for new instances of the type.
// The type itself is protecting the reference to its storage via CoW semantics.
Expand All @@ -389,6 +396,7 @@ extension Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters: SwiftProtobuf.M
_batchSize = source._batchSize
_evaluationKeyConfig = source._evaluationKeyConfig
_keyCompressionStrategy = source._keyCompressionStrategy
_encodingEntrySize = source._encodingEntrySize
}
}

Expand Down Expand Up @@ -416,6 +424,7 @@ extension Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters: SwiftProtobuf.M
case 7: try { try decoder.decodeSingularUInt64Field(value: &_storage._batchSize) }()
case 8: try { try decoder.decodeSingularMessageField(value: &_storage._evaluationKeyConfig) }()
case 9: try { try decoder.decodeSingularEnumField(value: &_storage._keyCompressionStrategy) }()
case 13: try { try decoder.decodeSingularBoolField(value: &_storage._encodingEntrySize) }()
default: break
}
}
Expand Down Expand Up @@ -455,6 +464,9 @@ extension Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters: SwiftProtobuf.M
if _storage._keyCompressionStrategy != .unspecified {
try visitor.visitSingularEnumField(value: _storage._keyCompressionStrategy, fieldNumber: 9)
}
if _storage._encodingEntrySize != false {
try visitor.visitSingularBoolField(value: _storage._encodingEntrySize, fieldNumber: 13)
}
}
try unknownFields.traverse(visitor: &visitor)
}
Expand All @@ -473,6 +485,7 @@ extension Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters: SwiftProtobuf.M
if _storage._batchSize != rhs_storage._batchSize {return false}
if _storage._evaluationKeyConfig != rhs_storage._evaluationKeyConfig {return false}
if _storage._keyCompressionStrategy != rhs_storage._keyCompressionStrategy {return false}
if _storage._encodingEntrySize != rhs_storage._encodingEntrySize {return false}
return true
}
if !storagesAreEqual {return false}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ public enum PirKeyCompressionStrategy: String, CaseIterable, Codable, CodingKeyR
/// Configuration for an Index PIR database.
public struct IndexPirConfig: Hashable, Codable, Sendable {
/// Number of entries in the database.
public let entryCount: Int
/// Byte size of each entry in the database.
public let entrySizeInBytes: Int
public var entryCount: Int
/// Byte size of the largest entry in the database.
public var entrySizeInBytes: Int
/// Number of dimensions in the database.
public let dimensionCount: Int
/// Number of indices in a query to the database.
Expand All @@ -54,23 +54,38 @@ public struct IndexPirConfig: Hashable, Codable, Sendable {
public let unevenDimensions: Bool
/// Evaluation key compression.
public let keyCompression: PirKeyCompressionStrategy
/// Whether to encode the entry size.
public var encodingEntrySize: Bool

/// Size of the largest entry in bytes after encoding.
public var encodedEntrySize: Int {
if encodingEntrySize {
// VarInt is monotonic, i.e. the largest entry will always have the largest encoded entry size.
// So we can take an upper bound here.
VarInt.encodedSize(UInt32(entrySizeInBytes)) + entrySizeInBytes
} else {
entrySizeInBytes
}
}

/// Initializes an ``IndexPirConfig``.
/// - Parameters:
/// - entryCount: Number of entries in the database.
/// - entrySizeInBytes: Byte size of each entry in the database.
/// - entrySizeInBytes: Byte size of the largest entry in the database.
/// - dimensionCount: Number of dimensions in database.
/// - batchSize: Number of indices in a query to the database.
/// - unevenDimensions: Whether or not to enable `uneven dimensions` optimization.
/// - keyCompression: Evaluation key compression.
/// - encodingEntrySize: Whether or not to encode each entry's size.
/// - Throws: Error upon invalid configuration parameters.
public init(
entryCount: Int,
entrySizeInBytes: Int,
dimensionCount: Int,
batchSize: Int,
unevenDimensions: Bool,
keyCompression: PirKeyCompressionStrategy) throws
keyCompression: PirKeyCompressionStrategy,
encodingEntrySize: Bool) throws
{
let validDimensionsCount = [1, 2]
guard validDimensionsCount.contains(dimensionCount) else {
Expand All @@ -82,6 +97,7 @@ public struct IndexPirConfig: Hashable, Codable, Sendable {
self.batchSize = batchSize
self.unevenDimensions = unevenDimensions
self.keyCompression = keyCompression
self.encodingEntrySize = encodingEntrySize
}
}

Expand All @@ -91,14 +107,27 @@ public struct IndexPirConfig: Hashable, Codable, Sendable {
public struct IndexPirParameter: Hashable, Codable, Sendable {
/// Number of entries in the database.
public let entryCount: Int
/// Byte size of each entry in the database.
/// Byte size of the largest entry in the database, excluding any encoding of the entry size.
public let entrySizeInBytes: Int
/// Number of plaintexts in each dimension of the database.
public let dimensions: [Int]
/// Number of indices in a query to the database.
public let batchSize: Int
/// Evaluation key configuration.
public let evaluationKeyConfig: EvaluationKeyConfig
/// Whether to encode the entry size.
public var encodingEntrySize: Bool

/// Size of the largest entry in bytes after encoding.
public var encodedEntrySize: Int {
if encodingEntrySize {
// VarInt is monotonic, i.e. the largest entry will always have the largest encoded entry size.
// So we can take an upper bound here.
VarInt.encodedSize(UInt32(entrySizeInBytes)) + entrySizeInBytes
} else {
entrySizeInBytes
}
}

/// The number of dimensions in the database.
@usableFromInline package var dimensionCount: Int { dimensions.count }
Expand All @@ -108,22 +137,25 @@ public struct IndexPirParameter: Hashable, Codable, Sendable {
/// Initializes an ``IndexPirParameter``.
/// - Parameters:
/// - entryCount: Number of entries in the database.
/// - entrySizeInBytes: Byte size of each entry in the database.
/// - entrySizeInBytes: Byte size of the largest entry in the database, without encoding entry size.
/// - dimensions: Number of plaintexts in each dimension of the database.
/// - batchSize: Number of indices in a query to the database.
/// - evaluationKeyConfig: Evaluation key configuration.
/// - encodingEntrySize: Whether to encode the entry size.
public init(
entryCount: Int,
entrySizeInBytes: Int,
dimensions: [Int],
batchSize: Int,
evaluationKeyConfig: EvaluationKeyConfig)
evaluationKeyConfig: EvaluationKeyConfig,
encodingEntrySize: Bool)
{
self.entryCount = entryCount
self.entrySizeInBytes = entrySizeInBytes
self.dimensions = dimensions
self.batchSize = batchSize
self.evaluationKeyConfig = evaluationKeyConfig
self.encodingEntrySize = encodingEntrySize
}
}

Expand Down
68 changes: 45 additions & 23 deletions Sources/PrivateInformationRetrieval/IndexPir/MulPir.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
// Copyright 2024-2026 Apple Inc. and the Swift Homomorphic Encryption project authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,9 +35,9 @@ public enum MulPir<Scheme: HeScheme>: IndexPirProtocol {
public static func generateParameter(config: IndexPirConfig,
with context: Scheme.Context) -> IndexPirParameter
{
let entrySizeInBytes = config.entrySizeInBytes
let perChunkPlaintextCount = if entrySizeInBytes <= context.bytesPerPlaintext {
config.entryCount.dividingCeil(context.bytesPerPlaintext / entrySizeInBytes, variableTime: true)
let encodedEntrySize = config.encodedEntrySize
let perChunkPlaintextCount = if encodedEntrySize <= context.bytesPerPlaintext {
config.entryCount.dividingCeil(context.bytesPerPlaintext / encodedEntrySize, variableTime: true)
} else {
config.entryCount
}
Expand Down Expand Up @@ -74,9 +74,10 @@ public enum MulPir<Scheme: HeScheme>: IndexPirProtocol {
keyCompression: config.keyCompression)
return IndexPirParameter(
entryCount: config.entryCount,
entrySizeInBytes: entrySizeInBytes,
entrySizeInBytes: config.entrySizeInBytes,
dimensions: dimensions, batchSize: config.batchSize,
evaluationKeyConfig: evalKeyConfig)
evaluationKeyConfig: evalKeyConfig,
encodingEntrySize: config.encodingEntrySize)
}

@inlinable
Expand Down Expand Up @@ -134,9 +135,13 @@ public final class MulPirClient<PirUtil: PirUtilProtocol>: IndexPirClient {

@usableFromInline var entrySizeInBytes: Int { parameter.entrySizeInBytes }

@usableFromInline var encodingEntrySize: Bool { parameter.encodingEntrySize }

@usableFromInline var encodedEntrySize: Int { parameter.encodedEntrySize }

@usableFromInline var entryChunksPerPlaintext: Int {
if context.bytesPerPlaintext >= entrySizeInBytes {
return context.bytesPerPlaintext / entrySizeInBytes
if context.bytesPerPlaintext >= encodedEntrySize {
return context.bytesPerPlaintext / encodedEntrySize
}
return 1
}
Expand Down Expand Up @@ -214,12 +219,12 @@ extension MulPirClient {

extension MulPirClient {
var expectedResponseCiphertextCount: Int {
entrySizeInBytes.dividingCeil(context.bytesPerPlaintext, variableTime: true)
encodedEntrySize.dividingCeil(context.bytesPerPlaintext, variableTime: true)
}

private func computeResponseRangeInBytes(at index: Int) -> Range<Int> {
let position = index % entryChunksPerPlaintext
return position * entrySizeInBytes..<(position + 1) * entrySizeInBytes
return position * encodedEntrySize..<(position + 1) * encodedEntrySize
}

/// Decrypts an encrypted response.
Expand Down Expand Up @@ -248,7 +253,12 @@ extension MulPirClient {
bitsPerCoeff: context.plaintextModulus.log2)
}

return Array(bytes[computeResponseRangeInBytes(at: entryIndex)])
let responseBytes = bytes[computeResponseRangeInBytes(at: entryIndex)]
if encodingEntrySize {
let (entrySize, bytesConsumed): (UInt32, Int) = try VarInt.decode(responseBytes)
return Array(responseBytes[(responseBytes.startIndex + bytesConsumed)...].prefix(Int(entrySize)))
}
return Array(responseBytes)
}
}

Expand Down Expand Up @@ -336,7 +346,7 @@ public final class MulPirServer<PirUtil: PirUtilProtocol>: IndexPirServer {

@inlinable
package static func chunkCount(parameter: IndexPirParameter, context: Scheme.Context) -> Int {
parameter.entrySizeInBytes.dividingCeil(context.bytesPerPlaintext, variableTime: true)
parameter.encodedEntrySize.dividingCeil(context.bytesPerPlaintext, variableTime: true)
}
}

Expand Down Expand Up @@ -437,12 +447,12 @@ extension MulPirServer {
throw PirError
.invalidDatabaseEntryCount(entryCount: database.count, expected: parameter.entryCount)
}
let maximumElementSize = database.map(\.count).max() ?? 0
guard maximumElementSize <= parameter.entrySizeInBytes else {
let maxEntrySize = database.map(\.count).max() ?? 0
guard maxEntrySize <= parameter.entrySizeInBytes else {
throw PirError
.invalidDatabaseEntrySize(maximumEntrySize: maximumElementSize, expected: parameter.entrySizeInBytes)
.invalidDatabaseEntrySize(maximumEntrySize: maxEntrySize, expected: parameter.entrySizeInBytes)
}
let chunkCount = parameter.entrySizeInBytes.dividingCeil(context.bytesPerPlaintext, variableTime: true)
let chunkCount = parameter.encodedEntrySize.dividingCeil(context.bytesPerPlaintext, variableTime: true)
if chunkCount > 1 {
return try await processSplitLargeEntries(database: database, with: context, using: parameter)
}
Expand All @@ -457,14 +467,22 @@ extension MulPirServer {
{
let chunkCount = Self.chunkCount(parameter: parameter, context: context)
var plaintexts: [[Plaintext<Scheme, Eval>?]] = try await .init(database.async.map { entry in
try await .init(stride(from: 0, to: parameter.entrySizeInBytes, by: context.bytesPerPlaintext).async
let encoded = VarInt.encode(UInt32(entry.count))
let entryEncodingSize = if parameter.encodingEntrySize { encoded.count } else { 0 }
return try await .init(stride(from: 0, to: parameter.encodedEntrySize, by: context.bytesPerPlaintext).async
.map { startIndex in
let endIndex = min(startIndex + context.bytesPerPlaintext, entry.count)
let entryStartIndex = startIndex - entryEncodingSize
let endIndex = min(entryStartIndex + context.bytesPerPlaintext, entry.count)
// Avoid computing on padding plaintexts
guard startIndex < endIndex else {
guard entryStartIndex < endIndex else {
return nil
}
let bytes = Array(entry[startIndex..<endIndex])
let bytes = if startIndex == 0, parameter.encodingEntrySize {
encoded + entry[0..<endIndex]
} else {
Array(entry[entryStartIndex..<endIndex])
}

let coefficients: [Scheme.Scalar] = try CoefficientPacking.bytesToCoefficients(
bytes: bytes,
bitsPerCoeff: context.plaintextModulus.log2,
Expand Down Expand Up @@ -504,12 +522,16 @@ extension MulPirServer {
assert(database.count == parameter.entryCount)
let flatDatabase: [UInt8] = database.flatMap { entry in
var entry = entry
let pad = parameter.entrySizeInBytes - entry.count
if parameter.encodingEntrySize {
let encoded = VarInt.encode(UInt32(entry.count))
entry = encoded + entry
}
let pad = parameter.encodedEntrySize - entry.count
entry.append(contentsOf: repeatElement(0, count: pad))
return entry
}
let entriesPerPlaintext = context.bytesPerPlaintext / parameter.entrySizeInBytes
let bytesPerPlaintext = entriesPerPlaintext * parameter.entrySizeInBytes
let entriesPerPlaintext = context.bytesPerPlaintext / parameter.encodedEntrySize
let bytesPerPlaintext = entriesPerPlaintext * parameter.encodedEntrySize
let plaintextIndices = stride(from: 0, to: flatDatabase.count, by: bytesPerPlaintext)
var plaintexts: [Plaintext<Scheme, Eval>?] = try await .init(plaintextIndices.async
.map { startIndex in
Expand Down
Loading
Loading