7
7
from typing import Iterable
8
8
9
9
from .decode import decode_float
10
- from .round import encode_float , round_float
10
+ from .round import encode_float , round_float , RoundMode
11
11
from .types import FormatInfo
12
12
13
13
@@ -78,7 +78,10 @@ def decode_block(fi: BlockFormatInfo, block: Iterable[int]) -> Iterable[float]:
78
78
79
79
80
80
def encode_block (
81
- fi : BlockFormatInfo , scale : float , vals : Iterable [float ]
81
+ fi : BlockFormatInfo ,
82
+ scale : float ,
83
+ vals : Iterable [float ],
84
+ round : RoundMode = RoundMode .TiesToEven ,
82
85
) -> Iterable [int ]:
83
86
"""
84
87
Encode a :paramref:`block` of bytes into block Format descibed by :paramref:`fi`
@@ -93,22 +96,27 @@ def encode_block(
93
96
fi (BlockFormatInfo): Describes the target block format
94
97
scale (float): Scale to be recorded in the block
95
98
vals (Iterable[float]): Input block
99
+ round (RoundMode): Rounding mode to use, defaults to `TiesToEven`
96
100
97
101
Returns:
98
102
A sequence of ints representing the encoded values.
99
103
100
104
Raises:
101
105
ValueError: The scale overflows the target scale encoding format.
102
106
"""
103
- # TODO: this should not do any multiplication - the scale is to be recorded not applied.
107
+
108
+ # TODO: this should really not do any multiplication -
109
+ # the scale is to be recorded not applied.
104
110
recip_scale = 1 / scale
105
111
scale = 1 / recip_scale
106
112
107
113
if scale > fi .stype .max :
108
114
raise ValueError (f"Scaled { scale } too large for { fi .stype } " )
109
115
116
+ sat = True # Saturate elements if out of range
117
+
110
118
def enc (ty : FormatInfo , x : float ) -> int :
111
- return encode_float (ty , round_float (ty , x ))
119
+ return encode_float (ty , round_float (ty , x , round , sat ))
112
120
113
121
yield enc (fi .stype , scale )
114
122
0 commit comments