Skip to content

Commit faf2dd1

Browse files
mingulovdavidvincze
authored andcommitted
imgtool: fixed keys/general.py to pass existing unittests
keys.KeyClass._emit is able to use 'file' parameter not as a file but some object (not only sys.stdout but io.StringIO, like by tests). Fixed all explicit checks for sys.stdio usage in favor of io.TextIOBase, also improve a single unit test to cover also all the changed methods. Signed-off-by: Denis Mingulov <[email protected]>
1 parent 1202604 commit faf2dd1

File tree

2 files changed

+56
-21
lines changed

2 files changed

+56
-21
lines changed

scripts/imgtool/keys/general.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,39 @@
22

33
# SPDX-License-Identifier: Apache-2.0
44

5+
import binascii
6+
import io
7+
import os
58
import sys
69
from cryptography.hazmat.primitives.hashes import Hash, SHA256
710

811
AUTOGEN_MESSAGE = "/* Autogenerated by imgtool.py, do not edit. */"
912

1013

14+
class FileHandler(object):
15+
def __init__(self, file, *args, **kwargs):
16+
self.file_in = file
17+
self.args = args
18+
self.kwargs = kwargs
19+
20+
def __enter__(self):
21+
if isinstance(self.file_in, (str, bytes, os.PathLike)):
22+
self.file = open(self.file_in, *self.args, **self.kwargs)
23+
else:
24+
self.file = self.file_in
25+
return self.file
26+
27+
def __exit__(self, *args):
28+
if self.file != self.file_in:
29+
self.file.close()
30+
31+
1132
class KeyClass(object):
1233
def _emit(self, header, trailer, encoded_bytes, indent, file=sys.stdout,
1334
len_format=None):
14-
if file and file is not sys.stdout:
15-
with open(file, 'w') as file:
16-
self._emit_to_output(header, trailer, encoded_bytes, indent,
17-
file, len_format)
18-
else:
35+
with FileHandler(file, 'w') as file:
1936
self._emit_to_output(header, trailer, encoded_bytes, indent,
20-
sys.stdout, len_format)
37+
file, len_format)
2138

2239
def _emit_to_output(self, header, trailer, encoded_bytes, indent, file,
2340
len_format):
@@ -33,6 +50,16 @@ def _emit_to_output(self, header, trailer, encoded_bytes, indent, file,
3350
if len_format is not None:
3451
print(len_format.format(len(encoded_bytes)), file=file)
3552

53+
def _emit_raw(self, encoded_bytes, file):
54+
with FileHandler(file, 'wb') as file:
55+
try:
56+
# file.buffer is not part of the TextIOBase API
57+
# and may not exist in some implementations.
58+
file.buffer.write(encoded_bytes)
59+
except AttributeError:
60+
# raw binary data, can be for example io.BytesIO
61+
file.write(encoded_bytes)
62+
3663
def emit_c_public(self, file=sys.stdout):
3764
self._emit(
3865
header="const unsigned char {}_pub_key[] = {{"
@@ -58,20 +85,12 @@ def emit_c_public_hash(self, file=sys.stdout):
5885
file=file)
5986

6087
def emit_raw_public(self, file=sys.stdout):
61-
if file and file is not sys.stdout:
62-
with open(file, 'wb') as file:
63-
file.write(self.get_public_bytes())
64-
else:
65-
sys.stdout.buffer.write(self.get_public_bytes())
88+
self._emit_raw(self.get_public_bytes(), file=file)
6689

6790
def emit_raw_public_hash(self, file=sys.stdout):
6891
digest = Hash(SHA256())
6992
digest.update(self.get_public_bytes())
70-
if file and file is not sys.stdout:
71-
with open(file, 'wb') as file:
72-
file.write(digest.finalize())
73-
else:
74-
sys.stdout.buffer.write(digest.finalize())
93+
self._emit_raw(digest.finalize(), file=file)
7594

7695
def emit_rust_public(self, file=sys.stdout):
7796
self._emit(
@@ -83,11 +102,8 @@ def emit_rust_public(self, file=sys.stdout):
83102
file=file)
84103

85104
def emit_public_pem(self, file=sys.stdout):
86-
if file and file is not sys.stdout:
87-
with open(file, 'w') as file:
88-
print(str(self.get_public_pem(), 'utf-8'), file=file, end='')
89-
else:
90-
print(str(self.get_public_pem(), 'utf-8'), file=sys.stdout, end='')
105+
with FileHandler(file, 'w') as file:
106+
print(str(self.get_public_pem(), 'utf-8'), file=file, end='')
91107

92108
def emit_private(self, minimal, format, file=sys.stdout):
93109
self._emit(

scripts/imgtool/keys/rsa_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,34 @@ def test_emit(self):
6363
for key_size in RSA_KEY_SIZES:
6464
k = RSA.generate(key_size=key_size)
6565

66+
pubpem = io.StringIO()
67+
k.emit_public_pem(pubpem)
68+
self.assertIn("BEGIN PUBLIC KEY", pubpem.getvalue())
69+
self.assertIn("END PUBLIC KEY", pubpem.getvalue())
70+
6671
ccode = io.StringIO()
6772
k.emit_c_public(ccode)
6873
self.assertIn("rsa_pub_key", ccode.getvalue())
6974
self.assertIn("rsa_pub_key_len", ccode.getvalue())
7075

76+
hashccode = io.StringIO()
77+
k.emit_c_public_hash(hashccode)
78+
self.assertIn("rsa_pub_key_hash", hashccode.getvalue())
79+
self.assertIn("rsa_pub_key_hash_len", hashccode.getvalue())
80+
7181
rustcode = io.StringIO()
7282
k.emit_rust_public(rustcode)
7383
self.assertIn("RSA_PUB_KEY", rustcode.getvalue())
7484

85+
# raw data - bytes
86+
pubraw = io.BytesIO()
87+
k.emit_raw_public(pubraw)
88+
self.assertTrue(len(pubraw.getvalue()) > 0)
89+
90+
hashraw = io.BytesIO()
91+
k.emit_raw_public_hash(hashraw)
92+
self.assertTrue(len(hashraw.getvalue()) > 0)
93+
7594
def test_emit_pub(self):
7695
"""Basic sanity check on the code emitters, from public key."""
7796
pubname = self.tname("public.pem")

0 commit comments

Comments
 (0)