Skip to content

Commit 9ae6e80

Browse files
committed
Add support for formats without infinities
1 parent 99665da commit 9ae6e80

File tree

6 files changed

+222
-158
lines changed

6 files changed

+222
-158
lines changed

mex/cpfloat.c

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ void mexFunction(int nlhs,
4242
fpopts->emin = -14;
4343
fpopts->emax = 15;
4444
fpopts->explim = CPFLOAT_EXPRANGE_TARG;
45+
fpopts->infinity = CPFLOAT_INF_USE;
4546
fpopts->round = CPFLOAT_RND_NE;
4647
fpopts->saturation = CPFLOAT_SAT_NO;
4748
fpopts->subnormal = CPFLOAT_SUBN_USE;
@@ -57,6 +58,7 @@ void mexFunction(int nlhs,
5758
/* Parse second argument and populate fpopts structure. */
5859
if (nrhs > 1) {
5960
bool is_subn_rnd_default = false;
61+
bool is_inf_no_default = false;
6062
if(!mxIsEmpty(prhs[1]) && !mxIsStruct(prhs[1])) {
6163
mexErrMsgIdAndTxt("cpfloat:invalidstruct",
6264
"Second argument must be a struct.");
@@ -83,6 +85,7 @@ void mexFunction(int nlhs,
8385
fpopts->precision = 4;
8486
fpopts->emin = -6;
8587
fpopts->emax = 8;
88+
is_inf_no_default = true;
8689
} else if (!strcmp(fpopts->format, "q52") ||
8790
!strcmp(fpopts->format, "fp8-e5m2") ||
8891
!strcmp(fpopts->format, "E5M2")) {
@@ -137,6 +140,7 @@ void mexFunction(int nlhs,
137140
mexErrMsgIdAndTxt("cpfloat:invalidformat",
138141
"Invalid floating-point format specified.");
139142
}
143+
140144
/* Set default values to be compatible with MATLAB chop. */
141145
tmp = mxGetField(prhs[1], 0, "subnormal");
142146
if (tmp != NULL) {
@@ -147,32 +151,43 @@ void mexFunction(int nlhs,
147151
} else {
148152
if (is_subn_rnd_default)
149153
fpopts->subnormal = CPFLOAT_SUBN_RND; /* Default for bfloat16. */
150-
else
151-
fpopts->subnormal = CPFLOAT_SUBN_USE;
152154
}
155+
153156
tmp = mxGetField(prhs[1], 0, "explim");
154157
if (tmp != NULL) {
155158
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
156159
fpopts->explim = 1;
157160
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
158161
fpopts->explim = *((double *)mxGetData(tmp));
159162
}
163+
164+
tmp = mxGetField(prhs[1], 0, "infinity");
165+
if (tmp != NULL) {
166+
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
167+
fpopts->infinity = CPFLOAT_INF_USE;
168+
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
169+
fpopts->infinity = *((double *)mxGetData(tmp));
170+
} else {
171+
if (is_inf_no_default)
172+
fpopts->infinity = CPFLOAT_INF_NO; /* Default for E4M5. */
173+
}
174+
160175
tmp = mxGetField(prhs[1], 0, "round");
161176
if (tmp != NULL) {
162177
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
163178
fpopts->round = CPFLOAT_RND_NE;
164179
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
165180
fpopts->round = *((double *)mxGetData(tmp));
166181
}
182+
167183
tmp = mxGetField(prhs[1], 0, "saturation");
168184
if (tmp != NULL) {
169185
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
170186
fpopts->saturation = CPFLOAT_SAT_NO;
171187
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
172188
fpopts->saturation = *((double *)mxGetData(tmp));
173-
} else {
174-
fpopts->saturation = CPFLOAT_SAT_NO;
175189
}
190+
176191
tmp = mxGetField(prhs[1], 0, "subnormal");
177192
if (tmp != NULL) {
178193
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
@@ -313,11 +328,11 @@ void mexFunction(int nlhs,
313328

314329
/* Allocate and return second output. */
315330
if (nlhs > 1) {
316-
const char* field_names[] = {"format", "params", "explim",
331+
const char* field_names[] = {"format", "params", "explim", "infinity",
317332
"round", "saturation", "subnormal",
318333
"flip", "p"};
319334
mwSize dims[2] = {1, 1};
320-
plhs[1] = mxCreateStructArray(2, dims, 8, field_names);
335+
plhs[1] = mxCreateStructArray(2, dims, 9, field_names);
321336
mxSetFieldByNumber(plhs[1], 0, 0, mxCreateString(fpopts->format));
322337

323338
mxArray *outparams = mxCreateDoubleMatrix(1,3,mxREAL);
@@ -332,30 +347,35 @@ void mexFunction(int nlhs,
332347
outexplimptr[0] = fpopts->explim;
333348
mxSetFieldByNumber(plhs[1], 0, 2, outexplim);
334349

350+
mxArray *outinfinity = mxCreateDoubleMatrix(1, 1, mxREAL);
351+
double *outinfinityptr = mxGetData(outinfinity);
352+
outinfinityptr[0] = fpopts->infinity;
353+
mxSetFieldByNumber(plhs[1], 0, 3, outinfinity);
354+
335355
mxArray *outround = mxCreateDoubleMatrix(1,1,mxREAL);
336356
double *outroundptr = mxGetData(outround);
337357
outroundptr[0] = fpopts->round;
338-
mxSetFieldByNumber(plhs[1], 0, 3, outround);
358+
mxSetFieldByNumber(plhs[1], 0, 4, outround);
339359

340360
mxArray *outsaturation = mxCreateDoubleMatrix(1,1,mxREAL);
341361
double *outsaturationptr = mxGetData(outsaturation);
342362
outsaturationptr[0] = fpopts->saturation;
343-
mxSetFieldByNumber(plhs[1], 0, 4, outsaturation);
363+
mxSetFieldByNumber(plhs[1], 0, 5, outsaturation);
344364

345365
mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL);
346366
double *outsubnormalptr = mxGetData(outsubnormal);
347367
outsubnormalptr[0] = fpopts->subnormal;
348-
mxSetFieldByNumber(plhs[1], 0, 5, outsubnormal);
368+
mxSetFieldByNumber(plhs[1], 0, 6, outsubnormal);
349369

350370
mxArray *outflip = mxCreateDoubleMatrix(1,1,mxREAL);
351371
double *outflipptr = mxGetData(outflip);
352372
outflipptr[0] = fpopts->flip;
353-
mxSetFieldByNumber(plhs[1], 0, 6, outflip);
373+
mxSetFieldByNumber(plhs[1], 0, 7, outflip);
354374

355375
mxArray *outp = mxCreateDoubleMatrix(1,1,mxREAL);
356376
double *outpptr = mxGetData(outp);
357377
outpptr[0] = fpopts->p;
358-
mxSetFieldByNumber(plhs[1], 0, 7, outp);
378+
mxSetFieldByNumber(plhs[1], 0, 8, outp);
359379

360380
}
361381
if (nlhs > 2)

mex/cpfloat.m

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
% this field is set to 0, and the exponent range of the format specified in
4646
% FPOPTS.format otherwise. The default value for this field is 1.
4747
%
48+
% * The scalar FPOPTS.infinity specifies whether infinities are supported. The
49+
% target floating-point format will support infinities if this field is set
50+
% to 1, and they will be replaced by NaNs otherwise. The default value for
51+
% this field is 0 if the target format is 'E4M3' and 1 otherwise.
52+
%
4853
% * The scalar FPOPTS.round specifies the rounding mode. Possible values are:
4954
% -1 for round-to-nearest with ties-to-away;
5055
% 0 for round-to-nearest with ties-to-zero;

src/cpfloat_definitions.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
* defines the enumerated types
1010
*
1111
* + @ref cpfloat_explim_t,
12+
* + @ref cpfloat_infinity_t,
1213
* + @ref cpfloat_rounding_t,
1314
* + @ref cpfloat_saturation_t,
1415
* + @ref cpfloat_softerr_t,
@@ -63,6 +64,16 @@ typedef enum {
6364
CPFLOAT_EXPRANGE_TARG = 1
6465
} cpfloat_explim_t;
6566

67+
/**
68+
* @brief Infinity support modes available in CPFloat.
69+
*/
70+
typedef enum {
71+
/** Use infinities in target format. */
72+
CPFLOAT_INF_NO = 0,
73+
/** Replace infinities with NaNs in target format. */
74+
CPFLOAT_INF_USE = 1,
75+
} cpfloat_infinity_t;
76+
6677
/**
6778
* @brief Rounding modes available in CPFloat.
6879
*/
@@ -234,6 +245,14 @@ typedef struct {
234245
* `CPFLOAT_EXPRANGE_STOR`.
235246
*/
236247
cpfloat_explim_t explim;
248+
/**
249+
* @brief Support for infinities in target format.
250+
*
251+
* @details If this field is set to `CPFLOAT_INF_USE`, the target format
252+
* supports signed infinities. If the field is set to `CPFLOAT_INF_NO`,
253+
* infinities are replaced with a quiet NaN.
254+
*/
255+
cpfloat_infinity_t infinity;
237256
/**
238257
* @brief Rounding mode to be used for the conversion.
239258
*

src/cpfloat_template.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ typedef struct {
160160
cpfloat_precision_t precision;
161161
cpfloat_exponent_t emin;
162162
cpfloat_exponent_t emax;
163+
cpfloat_infinity_t infinity;
163164
cpfloat_rounding_t round;
164165
cpfloat_saturation_t saturation;
165166
cpfloat_subnormal_t subnormal;
@@ -322,13 +323,19 @@ static inline FPPARAMS COMPUTE_GLOBAL_PARAMS(const optstruct *fpopts,
322323

323324
FPTYPE xmax = ldexp(1., emax) * (2-ldexp(1., 1-precision));
324325
FPTYPE xbnd = ldexp(1., emax) * (2-ldexp(1., -precision));
325-
FPTYPE ofvalue = (fpopts->saturation == CPFLOAT_SAT_USE) ? xmax : INFINITY;
326+
/*
327+
* Here, fpopts->saturation takes precedence over fpopts->infinity. Therefore,
328+
* when saturation arithmetic is used, infinities are not produced even when
329+
* the target format supports them.
330+
*/
331+
FPTYPE ofvalue = (fpopts->saturation == CPFLOAT_SAT_USE) ? xmax :
332+
(fpopts->infinity == CPFLOAT_INF_USE ? INFINITY : NAN);
326333

327334
/* Bitmasks. */
328335
INTTYPE leadmask = FULLMASK << (DEFPREC-precision); /* To keep. */
329336
INTTYPE trailmask = leadmask ^ FULLMASK; /* To discard. */
330337

331-
FPPARAMS params = {precision, emin, emax, fpopts->round,
338+
FPPARAMS params = {precision, emin, emax, fpopts->infinity, fpopts->round,
332339
fpopts->saturation, fpopts->subnormal,
333340
ftzthreshold, ofvalue, xmin, xmax, xbnd,
334341
leadmask, trailmask, NULL, NULL};
@@ -656,7 +663,9 @@ static inline void UPDATE_LOCAL_PARAMS(const FPTYPE *A,
656663
numelem, p, lp) \
657664
PARALLEL_STRING(PARALLEL) \
658665
{ \
659-
if (p->emax == DEFEMAX && p->saturation == CPFLOAT_SAT_NO) { \
666+
if (p->emax == DEFEMAX \
667+
&& p->saturation == CPFLOAT_SAT_NO \
668+
&& p->infinity == CPFLOAT_INF_USE) { \
660669
FOR_STRING(PARALLEL) \
661670
for (size_t i=0; i<numelem; i++) { \
662671
DEPARENTHESIZE_MAYBE(PREPROC) \

test/cpfloat_test.m

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,18 @@
126126
assert_eq(fp.subnormal,1)
127127
assert_eq(fp.params, [53 -1022 1023])
128128

129+
clear fp
130+
fp.format = 'E4M3';
131+
[~,options] = cpfloat(pi,fp);
132+
assert_eq(options.format,'E4M3')
133+
assert_eq(options.infinity,0)
134+
assert_eq(options.params, [4 -6 8])
135+
[~,fp] = cpfloat;
136+
assert_eq(fp.format,'E4M3')
137+
assert_eq(fp.infinity,0)
138+
assert_eq(fp.params, [4 -6 8])
139+
140+
129141
clear fp
130142
fp.format = 'bfloat16';
131143
[~,options] = cpfloat(pi,fp);
@@ -323,6 +335,9 @@
323335
end
324336

325337
% Infinities tests.
338+
[~,fpopts] = cpfloat;
339+
prev_infinity = fpopts.infinity;
340+
options.infinity = 1;
326341
options.saturation = 0;
327342
for j = 1:6
328343
options.round = j;
@@ -400,6 +415,7 @@
400415
c = cpfloat(x,options);
401416
c_expected = [0 0 x(3:5) inf 1 1];
402417
assert_eq(c,c_expected)
418+
options.infinity = prev_infinity;
403419

404420
% Smallest normal number and spacing between the subnormal numbers.
405421
y = xmin; delta = xmin*2^(1-p);

0 commit comments

Comments
 (0)