|
43 | 43 | import org.bouncycastle.pqc.crypto.mldsa.MLDSAPublicKeyParameters;
|
44 | 44 | import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters;
|
45 | 45 | import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters;
|
| 46 | +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; |
46 | 47 | import org.bouncycastle.pqc.crypto.newhope.NHPrivateKeyParameters;
|
47 | 48 | import org.bouncycastle.pqc.crypto.ntru.NTRUParameters;
|
48 | 49 | import org.bouncycastle.pqc.crypto.ntru.NTRUPrivateKeyParameters;
|
@@ -147,22 +148,12 @@ else if (algOID.on(BCObjectIdentifiers.sphincsPlus) || algOID.on(BCObjectIdentif
|
147 | 148 | return new SPHINCSPlusPrivateKeyParameters(spParams, ASN1OctetString.getInstance(obj).getOctets());
|
148 | 149 | }
|
149 | 150 | }
|
150 |
| - else if (Utils.shldsaParams.containsKey(algOID)) |
| 151 | + else if (Utils.slhdsaParams.containsKey(algOID)) |
151 | 152 | {
|
152 | 153 | SLHDSAParameters spParams = Utils.slhdsaParamsLookup(algOID);
|
| 154 | + ASN1OctetString slhdsaKey = parseOctetString(keyInfo.getPrivateKey(), spParams.getN() * 4); |
153 | 155 |
|
154 |
| - ASN1Encodable obj = keyInfo.parsePrivateKey(); |
155 |
| - if (obj instanceof ASN1Sequence) |
156 |
| - { |
157 |
| - SPHINCSPLUSPrivateKey spKey = SPHINCSPLUSPrivateKey.getInstance(obj); |
158 |
| - SPHINCSPLUSPublicKey publicKey = spKey.getPublicKey(); |
159 |
| - return new SLHDSAPrivateKeyParameters(spParams, spKey.getSkseed(), spKey.getSkprf(), |
160 |
| - publicKey.getPkseed(), publicKey.getPkroot()); |
161 |
| - } |
162 |
| - else |
163 |
| - { |
164 |
| - return new SLHDSAPrivateKeyParameters(spParams, ASN1OctetString.getInstance(obj).getOctets()); |
165 |
| - } |
| 156 | + return new SLHDSAPrivateKeyParameters(spParams, slhdsaKey.getOctets()); |
166 | 157 | }
|
167 | 158 | else if (algOID.on(BCObjectIdentifiers.picnic))
|
168 | 159 | {
|
@@ -203,10 +194,37 @@ else if (algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_512) ||
|
203 | 194 | algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_768) ||
|
204 | 195 | algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_1024))
|
205 | 196 | {
|
206 |
| - ASN1OctetString kyberKey = ASN1OctetString.getInstance(keyInfo.parsePrivateKey()); |
207 |
| - MLKEMParameters kyberParams = Utils.mlkemParamsLookup(algOID); |
| 197 | + ASN1Primitive mlkemKey = parsePrimitiveString(keyInfo.getPrivateKey(), 64); |
| 198 | + MLKEMParameters mlkemParams = Utils.mlkemParamsLookup(algOID); |
208 | 199 |
|
209 |
| - return new MLKEMPrivateKeyParameters(kyberParams, kyberKey.getOctets()); |
| 200 | + MLKEMPublicKeyParameters pubParams = null; |
| 201 | + if (keyInfo.getPublicKeyData() != null) |
| 202 | + { |
| 203 | + pubParams = PublicKeyFactory.MLKEMConverter.getPublicKeyParams(mlkemParams, keyInfo.getPublicKeyData()); |
| 204 | + } |
| 205 | + |
| 206 | + if (mlkemKey instanceof ASN1OctetString) |
| 207 | + { |
| 208 | + // TODO This should be explicitly EXPANDED_KEY or SEED (tag already removed) but is length-flexible |
| 209 | + return new MLKEMPrivateKeyParameters(mlkemParams, ((ASN1OctetString)mlkemKey).getOctets(), pubParams); |
| 210 | + } |
| 211 | + else if (mlkemKey instanceof ASN1Sequence) |
| 212 | + { |
| 213 | + ASN1Sequence keySeq = (ASN1Sequence)mlkemKey; |
| 214 | + byte[] seed = ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(); |
| 215 | + byte[] encoding = ASN1OctetString.getInstance(keySeq.getObjectAt(1)).getOctets(); |
| 216 | + |
| 217 | + // TODO This should only allow seed but is length-flexible |
| 218 | + MLKEMPrivateKeyParameters mlkemPriv = new MLKEMPrivateKeyParameters(mlkemParams, seed, pubParams); |
| 219 | + if (!Arrays.constantTimeAreEqual(mlkemPriv.getEncoded(), encoding)) |
| 220 | + { |
| 221 | + throw new IllegalArgumentException("inconsistent " + mlkemParams.getName() + " private key"); |
| 222 | + } |
| 223 | + |
| 224 | + return mlkemPriv; |
| 225 | + } |
| 226 | + |
| 227 | + throw new IllegalArgumentException("invalid " + mlkemParams.getName() + " private key"); |
210 | 228 | }
|
211 | 229 | else if (algOID.on(BCObjectIdentifiers.pqc_kem_ntrulprime))
|
212 | 230 | {
|
@@ -235,58 +253,37 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_sntruprime))
|
235 | 253 | }
|
236 | 254 | else if (Utils.mldsaParams.containsKey(algOID))
|
237 | 255 | {
|
238 |
| - ASN1Encodable keyObj = keyInfo.parsePrivateKey(); |
239 |
| - MLDSAParameters spParams = Utils.mldsaParamsLookup(algOID); |
| 256 | + ASN1Encodable mldsaKey = parsePrimitiveString(keyInfo.getPrivateKey(), 32); |
| 257 | + MLDSAParameters mldsaParams = Utils.mldsaParamsLookup(algOID); |
240 | 258 |
|
241 |
| - if (keyObj instanceof ASN1Sequence) |
| 259 | + MLDSAPublicKeyParameters pubParams = null; |
| 260 | + if (keyInfo.getPublicKeyData() != null) |
242 | 261 | {
|
243 |
| - ASN1Sequence keyEnc = ASN1Sequence.getInstance(keyObj); |
244 |
| - |
245 |
| - int version = ASN1Integer.getInstance(keyEnc.getObjectAt(0)).intValueExact(); |
246 |
| - if (version != 0) |
247 |
| - { |
248 |
| - throw new IOException("unknown private key version: " + version); |
249 |
| - } |
250 |
| - |
251 |
| - if (keyInfo.getPublicKeyData() != null) |
252 |
| - { |
253 |
| - MLDSAPublicKeyParameters pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(spParams, keyInfo.getPublicKeyData()); |
| 262 | + pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(mldsaParams, keyInfo.getPublicKeyData()); |
| 263 | + } |
254 | 264 |
|
255 |
| - return new MLDSAPrivateKeyParameters(spParams, |
256 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(), |
257 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(), |
258 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(), |
259 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(), |
260 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets(), |
261 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(6)).getOctets(), |
262 |
| - pubParams.getT1()); // encT1 |
263 |
| - } |
264 |
| - else |
265 |
| - { |
266 |
| - return new MLDSAPrivateKeyParameters(spParams, |
267 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(), |
268 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(), |
269 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(), |
270 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(), |
271 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets(), |
272 |
| - ASN1BitString.getInstance(keyEnc.getObjectAt(6)).getOctets(), |
273 |
| - null); |
274 |
| - } |
| 265 | + if (mldsaKey instanceof ASN1OctetString) |
| 266 | + { |
| 267 | + // TODO This should be explicitly EXPANDED_KEY or SEED (tag already removed) but is length-flexible |
| 268 | + return new MLDSAPrivateKeyParameters(mldsaParams, ((ASN1OctetString)mldsaKey).getOctets(), pubParams); |
275 | 269 | }
|
276 |
| - else if (keyObj instanceof DEROctetString) |
| 270 | + else if (mldsaKey instanceof ASN1Sequence) |
277 | 271 | {
|
278 |
| - byte[] data = ASN1OctetString.getInstance(keyObj).getOctets(); |
279 |
| - if (keyInfo.getPublicKeyData() != null) |
| 272 | + ASN1Sequence keySeq = (ASN1Sequence)mldsaKey; |
| 273 | + byte[] seed = ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(); |
| 274 | + byte[] encoding = ASN1OctetString.getInstance(keySeq.getObjectAt(1)).getOctets(); |
| 275 | + |
| 276 | + // TODO This should only allow seed but is length-flexible |
| 277 | + MLDSAPrivateKeyParameters mldsaPriv = new MLDSAPrivateKeyParameters(mldsaParams, seed, pubParams); |
| 278 | + if (!Arrays.constantTimeAreEqual(mldsaPriv.getEncoded(), encoding)) |
280 | 279 | {
|
281 |
| - MLDSAPublicKeyParameters pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(spParams, keyInfo.getPublicKeyData()); |
282 |
| - return new MLDSAPrivateKeyParameters(spParams, data, pubParams); |
| 280 | + throw new IllegalArgumentException("inconsistent " + mldsaParams.getName() + " private key"); |
283 | 281 | }
|
284 |
| - return new MLDSAPrivateKeyParameters(spParams, data); |
285 |
| - } |
286 |
| - else |
287 |
| - { |
288 |
| - throw new IOException("not supported"); |
| 282 | + |
| 283 | + return mldsaPriv; |
289 | 284 | }
|
| 285 | + |
| 286 | + throw new IllegalArgumentException("invalid " + mldsaParams.getName() + " private key"); |
290 | 287 | }
|
291 | 288 | else if (algOID.equals(BCObjectIdentifiers.dilithium2)
|
292 | 289 | || algOID.equals(BCObjectIdentifiers.dilithium3) || algOID.equals(BCObjectIdentifiers.dilithium5))
|
@@ -380,6 +377,66 @@ else if (algOID.equals(PQCObjectIdentifiers.mcElieceCca2))
|
380 | 377 | }
|
381 | 378 | }
|
382 | 379 |
|
| 380 | + /** |
| 381 | + * So it seems for the new PQC algorithms, there's a couple of approaches to what goes in the OCTET STRING |
| 382 | + */ |
| 383 | + private static ASN1OctetString parseOctetString(ASN1OctetString octStr, int expectedLength) |
| 384 | + throws IOException |
| 385 | + { |
| 386 | + byte[] data = octStr.getOctets(); |
| 387 | + // |
| 388 | + // it's the right length for a RAW encoding, just return it. |
| 389 | + // |
| 390 | + if (data.length == expectedLength) |
| 391 | + { |
| 392 | + return octStr; |
| 393 | + } |
| 394 | + |
| 395 | + // |
| 396 | + // possible internal OCTET STRING, possibly long form with or without the internal OCTET STRING |
| 397 | + ASN1OctetString obj = Utils.parseOctetData(data); |
| 398 | + |
| 399 | + if (obj != null) |
| 400 | + { |
| 401 | + return ASN1OctetString.getInstance(obj); |
| 402 | + } |
| 403 | + |
| 404 | + return octStr; |
| 405 | + } |
| 406 | + |
| 407 | + /** |
| 408 | + * So it seems for the new PQC algorithms, there's a couple of approaches to what goes in the OCTET STRING |
| 409 | + * and in this case there may also be SEQUENCE. |
| 410 | + */ |
| 411 | + private static ASN1Primitive parsePrimitiveString(ASN1OctetString octStr, int expectedLength) |
| 412 | + throws IOException |
| 413 | + { |
| 414 | + byte[] data = octStr.getOctets(); |
| 415 | + // |
| 416 | + // it's the right length for a RAW encoding, just return it. |
| 417 | + // |
| 418 | + if (data.length == expectedLength) |
| 419 | + { |
| 420 | + return octStr; |
| 421 | + } |
| 422 | + |
| 423 | + // |
| 424 | + // possible internal OCTET STRING, possibly long form with or without the internal OCTET STRING |
| 425 | + // or possible SEQUENCE |
| 426 | + ASN1Encodable obj = Utils.parseData(data); |
| 427 | + |
| 428 | + if (obj instanceof ASN1OctetString) |
| 429 | + { |
| 430 | + return ASN1OctetString.getInstance(obj); |
| 431 | + } |
| 432 | + if (obj instanceof ASN1Sequence) |
| 433 | + { |
| 434 | + return ASN1Sequence.getInstance(obj); |
| 435 | + } |
| 436 | + |
| 437 | + return octStr; |
| 438 | + } |
| 439 | + |
383 | 440 | private static short[] convert(byte[] octets)
|
384 | 441 | {
|
385 | 442 | short[] rv = new short[octets.length / 2];
|
|
0 commit comments