|
43 | 43 | import org.bouncycastle.pqc.crypto.mldsa.MLDSAPublicKeyParameters; |
44 | 44 | import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters; |
45 | 45 | import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters; |
| 46 | +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; |
46 | 47 | import org.bouncycastle.pqc.crypto.newhope.NHPrivateKeyParameters; |
47 | 48 | import org.bouncycastle.pqc.crypto.ntru.NTRUParameters; |
48 | 49 | import org.bouncycastle.pqc.crypto.ntru.NTRUPrivateKeyParameters; |
@@ -147,22 +148,12 @@ else if (algOID.on(BCObjectIdentifiers.sphincsPlus) || algOID.on(BCObjectIdentif |
147 | 148 | return new SPHINCSPlusPrivateKeyParameters(spParams, ASN1OctetString.getInstance(obj).getOctets()); |
148 | 149 | } |
149 | 150 | } |
150 | | - else if (Utils.shldsaParams.containsKey(algOID)) |
| 151 | + else if (Utils.slhdsaParams.containsKey(algOID)) |
151 | 152 | { |
152 | 153 | SLHDSAParameters spParams = Utils.slhdsaParamsLookup(algOID); |
| 154 | + ASN1OctetString slhdsaKey = parseOctetString(keyInfo.getPrivateKey(), spParams.getN() * 4); |
153 | 155 |
|
154 | | - ASN1Encodable obj = keyInfo.parsePrivateKey(); |
155 | | - if (obj instanceof ASN1Sequence) |
156 | | - { |
157 | | - SPHINCSPLUSPrivateKey spKey = SPHINCSPLUSPrivateKey.getInstance(obj); |
158 | | - SPHINCSPLUSPublicKey publicKey = spKey.getPublicKey(); |
159 | | - return new SLHDSAPrivateKeyParameters(spParams, spKey.getSkseed(), spKey.getSkprf(), |
160 | | - publicKey.getPkseed(), publicKey.getPkroot()); |
161 | | - } |
162 | | - else |
163 | | - { |
164 | | - return new SLHDSAPrivateKeyParameters(spParams, ASN1OctetString.getInstance(obj).getOctets()); |
165 | | - } |
| 156 | + return new SLHDSAPrivateKeyParameters(spParams, slhdsaKey.getOctets()); |
166 | 157 | } |
167 | 158 | else if (algOID.on(BCObjectIdentifiers.picnic)) |
168 | 159 | { |
@@ -203,10 +194,37 @@ else if (algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_512) || |
203 | 194 | algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_768) || |
204 | 195 | algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_1024)) |
205 | 196 | { |
206 | | - ASN1OctetString kyberKey = ASN1OctetString.getInstance(keyInfo.parsePrivateKey()); |
207 | | - MLKEMParameters kyberParams = Utils.mlkemParamsLookup(algOID); |
| 197 | + ASN1Primitive mlkemKey = parsePrimitiveString(keyInfo.getPrivateKey(), 64); |
| 198 | + MLKEMParameters mlkemParams = Utils.mlkemParamsLookup(algOID); |
208 | 199 |
|
209 | | - return new MLKEMPrivateKeyParameters(kyberParams, kyberKey.getOctets()); |
| 200 | + MLKEMPublicKeyParameters pubParams = null; |
| 201 | + if (keyInfo.getPublicKeyData() != null) |
| 202 | + { |
| 203 | + pubParams = PublicKeyFactory.MLKEMConverter.getPublicKeyParams(mlkemParams, keyInfo.getPublicKeyData()); |
| 204 | + } |
| 205 | + |
| 206 | + if (mlkemKey instanceof ASN1OctetString) |
| 207 | + { |
| 208 | + // TODO This should be explicitly EXPANDED_KEY or SEED (tag already removed) but is length-flexible |
| 209 | + return new MLKEMPrivateKeyParameters(mlkemParams, ((ASN1OctetString)mlkemKey).getOctets(), pubParams); |
| 210 | + } |
| 211 | + else if (mlkemKey instanceof ASN1Sequence) |
| 212 | + { |
| 213 | + ASN1Sequence keySeq = (ASN1Sequence)mlkemKey; |
| 214 | + byte[] seed = ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(); |
| 215 | + byte[] encoding = ASN1OctetString.getInstance(keySeq.getObjectAt(1)).getOctets(); |
| 216 | + |
| 217 | + // TODO This should only allow seed but is length-flexible |
| 218 | + MLKEMPrivateKeyParameters mlkemPriv = new MLKEMPrivateKeyParameters(mlkemParams, seed, pubParams); |
| 219 | + if (!Arrays.constantTimeAreEqual(mlkemPriv.getEncoded(), encoding)) |
| 220 | + { |
| 221 | + throw new IllegalArgumentException("inconsistent " + mlkemParams.getName() + " private key"); |
| 222 | + } |
| 223 | + |
| 224 | + return mlkemPriv; |
| 225 | + } |
| 226 | + |
| 227 | + throw new IllegalArgumentException("invalid " + mlkemParams.getName() + " private key"); |
210 | 228 | } |
211 | 229 | else if (algOID.on(BCObjectIdentifiers.pqc_kem_ntrulprime)) |
212 | 230 | { |
@@ -235,58 +253,37 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_sntruprime)) |
235 | 253 | } |
236 | 254 | else if (Utils.mldsaParams.containsKey(algOID)) |
237 | 255 | { |
238 | | - ASN1Encodable keyObj = keyInfo.parsePrivateKey(); |
239 | | - MLDSAParameters spParams = Utils.mldsaParamsLookup(algOID); |
| 256 | + ASN1Encodable mldsaKey = parsePrimitiveString(keyInfo.getPrivateKey(), 32); |
| 257 | + MLDSAParameters mldsaParams = Utils.mldsaParamsLookup(algOID); |
240 | 258 |
|
241 | | - if (keyObj instanceof ASN1Sequence) |
| 259 | + MLDSAPublicKeyParameters pubParams = null; |
| 260 | + if (keyInfo.getPublicKeyData() != null) |
242 | 261 | { |
243 | | - ASN1Sequence keyEnc = ASN1Sequence.getInstance(keyObj); |
244 | | - |
245 | | - int version = ASN1Integer.getInstance(keyEnc.getObjectAt(0)).intValueExact(); |
246 | | - if (version != 0) |
247 | | - { |
248 | | - throw new IOException("unknown private key version: " + version); |
249 | | - } |
250 | | - |
251 | | - if (keyInfo.getPublicKeyData() != null) |
252 | | - { |
253 | | - MLDSAPublicKeyParameters pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(spParams, keyInfo.getPublicKeyData()); |
| 262 | + pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(mldsaParams, keyInfo.getPublicKeyData()); |
| 263 | + } |
254 | 264 |
|
255 | | - return new MLDSAPrivateKeyParameters(spParams, |
256 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(), |
257 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(), |
258 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(), |
259 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(), |
260 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets(), |
261 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(6)).getOctets(), |
262 | | - pubParams.getT1()); // encT1 |
263 | | - } |
264 | | - else |
265 | | - { |
266 | | - return new MLDSAPrivateKeyParameters(spParams, |
267 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(), |
268 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(), |
269 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(), |
270 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(), |
271 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets(), |
272 | | - ASN1BitString.getInstance(keyEnc.getObjectAt(6)).getOctets(), |
273 | | - null); |
274 | | - } |
| 265 | + if (mldsaKey instanceof ASN1OctetString) |
| 266 | + { |
| 267 | + // TODO This should be explicitly EXPANDED_KEY or SEED (tag already removed) but is length-flexible |
| 268 | + return new MLDSAPrivateKeyParameters(mldsaParams, ((ASN1OctetString)mldsaKey).getOctets(), pubParams); |
275 | 269 | } |
276 | | - else if (keyObj instanceof DEROctetString) |
| 270 | + else if (mldsaKey instanceof ASN1Sequence) |
277 | 271 | { |
278 | | - byte[] data = ASN1OctetString.getInstance(keyObj).getOctets(); |
279 | | - if (keyInfo.getPublicKeyData() != null) |
| 272 | + ASN1Sequence keySeq = (ASN1Sequence)mldsaKey; |
| 273 | + byte[] seed = ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(); |
| 274 | + byte[] encoding = ASN1OctetString.getInstance(keySeq.getObjectAt(1)).getOctets(); |
| 275 | + |
| 276 | + // TODO This should only allow seed but is length-flexible |
| 277 | + MLDSAPrivateKeyParameters mldsaPriv = new MLDSAPrivateKeyParameters(mldsaParams, seed, pubParams); |
| 278 | + if (!Arrays.constantTimeAreEqual(mldsaPriv.getEncoded(), encoding)) |
280 | 279 | { |
281 | | - MLDSAPublicKeyParameters pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(spParams, keyInfo.getPublicKeyData()); |
282 | | - return new MLDSAPrivateKeyParameters(spParams, data, pubParams); |
| 280 | + throw new IllegalArgumentException("inconsistent " + mldsaParams.getName() + " private key"); |
283 | 281 | } |
284 | | - return new MLDSAPrivateKeyParameters(spParams, data); |
285 | | - } |
286 | | - else |
287 | | - { |
288 | | - throw new IOException("not supported"); |
| 282 | + |
| 283 | + return mldsaPriv; |
289 | 284 | } |
| 285 | + |
| 286 | + throw new IllegalArgumentException("invalid " + mldsaParams.getName() + " private key"); |
290 | 287 | } |
291 | 288 | else if (algOID.equals(BCObjectIdentifiers.dilithium2) |
292 | 289 | || algOID.equals(BCObjectIdentifiers.dilithium3) || algOID.equals(BCObjectIdentifiers.dilithium5)) |
@@ -380,6 +377,66 @@ else if (algOID.equals(PQCObjectIdentifiers.mcElieceCca2)) |
380 | 377 | } |
381 | 378 | } |
382 | 379 |
|
| 380 | + /** |
| 381 | + * So it seems for the new PQC algorithms, there's a couple of approaches to what goes in the OCTET STRING |
| 382 | + */ |
| 383 | + private static ASN1OctetString parseOctetString(ASN1OctetString octStr, int expectedLength) |
| 384 | + throws IOException |
| 385 | + { |
| 386 | + byte[] data = octStr.getOctets(); |
| 387 | + // |
| 388 | + // it's the right length for a RAW encoding, just return it. |
| 389 | + // |
| 390 | + if (data.length == expectedLength) |
| 391 | + { |
| 392 | + return octStr; |
| 393 | + } |
| 394 | + |
| 395 | + // |
| 396 | + // possible internal OCTET STRING, possibly long form with or without the internal OCTET STRING |
| 397 | + ASN1OctetString obj = Utils.parseOctetData(data); |
| 398 | + |
| 399 | + if (obj != null) |
| 400 | + { |
| 401 | + return ASN1OctetString.getInstance(obj); |
| 402 | + } |
| 403 | + |
| 404 | + return octStr; |
| 405 | + } |
| 406 | + |
| 407 | + /** |
| 408 | + * So it seems for the new PQC algorithms, there's a couple of approaches to what goes in the OCTET STRING |
| 409 | + * and in this case there may also be SEQUENCE. |
| 410 | + */ |
| 411 | + private static ASN1Primitive parsePrimitiveString(ASN1OctetString octStr, int expectedLength) |
| 412 | + throws IOException |
| 413 | + { |
| 414 | + byte[] data = octStr.getOctets(); |
| 415 | + // |
| 416 | + // it's the right length for a RAW encoding, just return it. |
| 417 | + // |
| 418 | + if (data.length == expectedLength) |
| 419 | + { |
| 420 | + return octStr; |
| 421 | + } |
| 422 | + |
| 423 | + // |
| 424 | + // possible internal OCTET STRING, possibly long form with or without the internal OCTET STRING |
| 425 | + // or possible SEQUENCE |
| 426 | + ASN1Encodable obj = Utils.parseData(data); |
| 427 | + |
| 428 | + if (obj instanceof ASN1OctetString) |
| 429 | + { |
| 430 | + return ASN1OctetString.getInstance(obj); |
| 431 | + } |
| 432 | + if (obj instanceof ASN1Sequence) |
| 433 | + { |
| 434 | + return ASN1Sequence.getInstance(obj); |
| 435 | + } |
| 436 | + |
| 437 | + return octStr; |
| 438 | + } |
| 439 | + |
383 | 440 | private static short[] convert(byte[] octets) |
384 | 441 | { |
385 | 442 | short[] rv = new short[octets.length / 2]; |
|
0 commit comments