1- // Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
1+ // Copyright 2024-2026 Apple Inc. and the Swift Homomorphic Encryption project authors
22//
33// Licensed under the Apache License, Version 2.0 (the "License");
44// you may not use this file except in compliance with the License.
@@ -35,9 +35,9 @@ public enum MulPir<Scheme: HeScheme>: IndexPirProtocol {
3535 public static func generateParameter( config: IndexPirConfig ,
3636 with context: Scheme . Context ) -> IndexPirParameter
3737 {
38- let entrySizeInBytes = config. entrySizeInBytes
39- let perChunkPlaintextCount = if entrySizeInBytes <= context. bytesPerPlaintext {
40- config. entryCount. dividingCeil ( context. bytesPerPlaintext / entrySizeInBytes , variableTime: true )
38+ let encodedEntrySize = config. encodedEntrySize
39+ let perChunkPlaintextCount = if encodedEntrySize <= context. bytesPerPlaintext {
40+ config. entryCount. dividingCeil ( context. bytesPerPlaintext / encodedEntrySize , variableTime: true )
4141 } else {
4242 config. entryCount
4343 }
@@ -74,9 +74,10 @@ public enum MulPir<Scheme: HeScheme>: IndexPirProtocol {
7474 keyCompression: config. keyCompression)
7575 return IndexPirParameter (
7676 entryCount: config. entryCount,
77- entrySizeInBytes: entrySizeInBytes,
77+ entrySizeInBytes: config . entrySizeInBytes,
7878 dimensions: dimensions, batchSize: config. batchSize,
79- evaluationKeyConfig: evalKeyConfig)
79+ evaluationKeyConfig: evalKeyConfig,
80+ encodingEntrySize: config. encodingEntrySize)
8081 }
8182
8283 @inlinable
@@ -134,9 +135,13 @@ public final class MulPirClient<PirUtil: PirUtilProtocol>: IndexPirClient {
134135
135136 @usableFromInline var entrySizeInBytes : Int { parameter. entrySizeInBytes }
136137
138+ @usableFromInline var encodingEntrySize : Bool { parameter. encodingEntrySize }
139+
140+ @usableFromInline var encodedEntrySize : Int { parameter. encodedEntrySize }
141+
137142 @usableFromInline var entryChunksPerPlaintext : Int {
138- if context. bytesPerPlaintext >= entrySizeInBytes {
139- return context. bytesPerPlaintext / entrySizeInBytes
143+ if context. bytesPerPlaintext >= encodedEntrySize {
144+ return context. bytesPerPlaintext / encodedEntrySize
140145 }
141146 return 1
142147 }
@@ -214,12 +219,12 @@ extension MulPirClient {
214219
215220extension MulPirClient {
216221 var expectedResponseCiphertextCount : Int {
217- entrySizeInBytes . dividingCeil ( context. bytesPerPlaintext, variableTime: true )
222+ encodedEntrySize . dividingCeil ( context. bytesPerPlaintext, variableTime: true )
218223 }
219224
220225 private func computeResponseRangeInBytes( at index: Int ) -> Range < Int > {
221226 let position = index % entryChunksPerPlaintext
222- return position * entrySizeInBytes ..< ( position + 1 ) * entrySizeInBytes
227+ return position * encodedEntrySize ..< ( position + 1 ) * encodedEntrySize
223228 }
224229
225230 /// Decrypts an encrypted response.
@@ -248,7 +253,12 @@ extension MulPirClient {
248253 bitsPerCoeff: context. plaintextModulus. log2)
249254 }
250255
251- return Array ( bytes [ computeResponseRangeInBytes ( at: entryIndex) ] )
256+ let responseBytes = bytes [ computeResponseRangeInBytes ( at: entryIndex) ]
257+ if encodingEntrySize {
258+ let ( entrySize, bytesConsumed) : ( UInt32 , Int ) = try VarInt . decode ( responseBytes)
259+ return Array ( responseBytes [ ( responseBytes. startIndex + bytesConsumed) ... ] . prefix ( Int ( entrySize) ) )
260+ }
261+ return Array ( responseBytes)
252262 }
253263 }
254264
@@ -336,7 +346,7 @@ public final class MulPirServer<PirUtil: PirUtilProtocol>: IndexPirServer {
336346
337347 @inlinable
338348 package static func chunkCount( parameter: IndexPirParameter , context: Scheme . Context ) -> Int {
339- parameter. entrySizeInBytes . dividingCeil ( context. bytesPerPlaintext, variableTime: true )
349+ parameter. encodedEntrySize . dividingCeil ( context. bytesPerPlaintext, variableTime: true )
340350 }
341351}
342352
@@ -437,12 +447,12 @@ extension MulPirServer {
437447 throw PirError
438448 . invalidDatabaseEntryCount ( entryCount: database. count, expected: parameter. entryCount)
439449 }
440- let maximumElementSize = database. map ( \. count) . max ( ) ?? 0
441- guard maximumElementSize <= parameter. entrySizeInBytes else {
450+ let maxEntrySize = database. map ( \. count) . max ( ) ?? 0
451+ guard maxEntrySize <= parameter. entrySizeInBytes else {
442452 throw PirError
443- . invalidDatabaseEntrySize ( maximumEntrySize: maximumElementSize , expected: parameter. entrySizeInBytes)
453+ . invalidDatabaseEntrySize ( maximumEntrySize: maxEntrySize , expected: parameter. entrySizeInBytes)
444454 }
445- let chunkCount = parameter. entrySizeInBytes . dividingCeil ( context. bytesPerPlaintext, variableTime: true )
455+ let chunkCount = parameter. encodedEntrySize . dividingCeil ( context. bytesPerPlaintext, variableTime: true )
446456 if chunkCount > 1 {
447457 return try await processSplitLargeEntries ( database: database, with: context, using: parameter)
448458 }
@@ -457,14 +467,22 @@ extension MulPirServer {
457467 {
458468 let chunkCount = Self . chunkCount ( parameter: parameter, context: context)
459469 var plaintexts : [ [ Plaintext < Scheme , Eval > ? ] ] = try await . init( database. async . map { entry in
460- try await . init( stride ( from: 0 , to: parameter. entrySizeInBytes, by: context. bytesPerPlaintext) . async
470+ let encoded = VarInt . encode ( UInt32 ( entry. count) )
471+ let entryEncodingSize = if parameter. encodingEntrySize { encoded. count } else { 0 }
472+ return try await . init( stride ( from: 0 , to: parameter. encodedEntrySize, by: context. bytesPerPlaintext) . async
461473 . map { startIndex in
462- let endIndex = min ( startIndex + context. bytesPerPlaintext, entry. count)
474+ let entryStartIndex = startIndex - entryEncodingSize
475+ let endIndex = min ( entryStartIndex + context. bytesPerPlaintext, entry. count)
463476 // Avoid computing on padding plaintexts
464- guard startIndex < endIndex else {
477+ guard entryStartIndex < endIndex else {
465478 return nil
466479 }
467- let bytes = Array ( entry [ startIndex..< endIndex] )
480+ let bytes = if startIndex == 0 , parameter. encodingEntrySize {
481+ encoded + entry[ 0 ..< endIndex]
482+ } else {
483+ Array ( entry [ entryStartIndex..< endIndex] )
484+ }
485+
468486 let coefficients : [ Scheme . Scalar ] = try CoefficientPacking . bytesToCoefficients (
469487 bytes: bytes,
470488 bitsPerCoeff: context. plaintextModulus. log2,
@@ -504,12 +522,16 @@ extension MulPirServer {
504522 assert ( database. count == parameter. entryCount)
505523 let flatDatabase : [ UInt8 ] = database. flatMap { entry in
506524 var entry = entry
507- let pad = parameter. entrySizeInBytes - entry. count
525+ if parameter. encodingEntrySize {
526+ let encoded = VarInt . encode ( UInt32 ( entry. count) )
527+ entry = encoded + entry
528+ }
529+ let pad = parameter. encodedEntrySize - entry. count
508530 entry. append ( contentsOf: repeatElement ( 0 , count: pad) )
509531 return entry
510532 }
511- let entriesPerPlaintext = context. bytesPerPlaintext / parameter. entrySizeInBytes
512- let bytesPerPlaintext = entriesPerPlaintext * parameter. entrySizeInBytes
533+ let entriesPerPlaintext = context. bytesPerPlaintext / parameter. encodedEntrySize
534+ let bytesPerPlaintext = entriesPerPlaintext * parameter. encodedEntrySize
513535 let plaintextIndices = stride ( from: 0 , to: flatDatabase. count, by: bytesPerPlaintext)
514536 var plaintexts : [ Plaintext < Scheme , Eval > ? ] = try await . init( plaintextIndices. async
515537 . map { startIndex in
0 commit comments