@@ -42,6 +42,7 @@ void mexFunction(int nlhs,
4242 fpopts -> emin = -14 ;
4343 fpopts -> emax = 15 ;
4444 fpopts -> explim = CPFLOAT_EXPRANGE_TARG ;
45+ fpopts -> infinity = CPFLOAT_INF_USE ;
4546 fpopts -> round = CPFLOAT_RND_NE ;
4647 fpopts -> saturation = CPFLOAT_SAT_NO ;
4748 fpopts -> subnormal = CPFLOAT_SUBN_USE ;
@@ -57,6 +58,7 @@ void mexFunction(int nlhs,
5758 /* Parse second argument and populate fpopts structure. */
5859 if (nrhs > 1 ) {
5960 bool is_subn_rnd_default = false;
61+ bool is_inf_no_default = false;
6062 if (!mxIsEmpty (prhs [1 ]) && !mxIsStruct (prhs [1 ])) {
6163 mexErrMsgIdAndTxt ("cpfloat:invalidstruct" ,
6264 "Second argument must be a struct." );
@@ -83,6 +85,7 @@ void mexFunction(int nlhs,
8385 fpopts -> precision = 4 ;
8486 fpopts -> emin = -6 ;
8587 fpopts -> emax = 8 ;
88+ is_inf_no_default = true;
8689 } else if (!strcmp (fpopts -> format , "q52" ) ||
8790 !strcmp (fpopts -> format , "fp8-e5m2" ) ||
8891 !strcmp (fpopts -> format , "E5M2" )) {
@@ -137,6 +140,7 @@ void mexFunction(int nlhs,
137140 mexErrMsgIdAndTxt ("cpfloat:invalidformat" ,
138141 "Invalid floating-point format specified." );
139142 }
143+
140144 /* Set default values to be compatible with MATLAB chop. */
141145 tmp = mxGetField (prhs [1 ], 0 , "subnormal" );
142146 if (tmp != NULL ) {
@@ -147,32 +151,43 @@ void mexFunction(int nlhs,
147151 } else {
148152 if (is_subn_rnd_default )
149153 fpopts -> subnormal = CPFLOAT_SUBN_RND ; /* Default for bfloat16. */
150- else
151- fpopts -> subnormal = CPFLOAT_SUBN_USE ;
152154 }
155+
153156 tmp = mxGetField (prhs [1 ], 0 , "explim" );
154157 if (tmp != NULL ) {
155158 if (mxGetM (tmp ) == 0 && mxGetN (tmp ) == 0 )
156159 fpopts -> explim = 1 ;
157160 else if (mxGetClassID (tmp ) == mxDOUBLE_CLASS )
158161 fpopts -> explim = * ((double * )mxGetData (tmp ));
159162 }
163+
164+ tmp = mxGetField (prhs [1 ], 0 , "infinity" );
165+ if (tmp != NULL ) {
166+ if (mxGetM (tmp ) == 0 && mxGetN (tmp ) == 0 )
167+ fpopts -> infinity = CPFLOAT_INF_USE ;
168+ else if (mxGetClassID (tmp ) == mxDOUBLE_CLASS )
169+ fpopts -> infinity = * ((double * )mxGetData (tmp ));
170+ } else {
171+ if (is_inf_no_default )
172+ fpopts -> infinity = CPFLOAT_INF_NO ; /* Default for E4M5. */
173+ }
174+
160175 tmp = mxGetField (prhs [1 ], 0 , "round" );
161176 if (tmp != NULL ) {
162177 if (mxGetM (tmp ) == 0 && mxGetN (tmp ) == 0 )
163178 fpopts -> round = CPFLOAT_RND_NE ;
164179 else if (mxGetClassID (tmp ) == mxDOUBLE_CLASS )
165180 fpopts -> round = * ((double * )mxGetData (tmp ));
166181 }
182+
167183 tmp = mxGetField (prhs [1 ], 0 , "saturation" );
168184 if (tmp != NULL ) {
169185 if (mxGetM (tmp ) == 0 && mxGetN (tmp ) == 0 )
170186 fpopts -> saturation = CPFLOAT_SAT_NO ;
171187 else if (mxGetClassID (tmp ) == mxDOUBLE_CLASS )
172188 fpopts -> saturation = * ((double * )mxGetData (tmp ));
173- } else {
174- fpopts -> saturation = CPFLOAT_SAT_NO ;
175189 }
190+
176191 tmp = mxGetField (prhs [1 ], 0 , "subnormal" );
177192 if (tmp != NULL ) {
178193 if (mxGetM (tmp ) == 0 && mxGetN (tmp ) == 0 )
@@ -313,11 +328,11 @@ void mexFunction(int nlhs,
313328
314329 /* Allocate and return second output. */
315330 if (nlhs > 1 ) {
316- const char * field_names [] = {"format" , "params" , "explim" ,
331+ const char * field_names [] = {"format" , "params" , "explim" , "infinity" ,
317332 "round" , "saturation" , "subnormal" ,
318333 "flip" , "p" };
319334 mwSize dims [2 ] = {1 , 1 };
320- plhs [1 ] = mxCreateStructArray (2 , dims , 8 , field_names );
335+ plhs [1 ] = mxCreateStructArray (2 , dims , 9 , field_names );
321336 mxSetFieldByNumber (plhs [1 ], 0 , 0 , mxCreateString (fpopts -> format ));
322337
323338 mxArray * outparams = mxCreateDoubleMatrix (1 ,3 ,mxREAL );
@@ -332,30 +347,35 @@ void mexFunction(int nlhs,
332347 outexplimptr [0 ] = fpopts -> explim ;
333348 mxSetFieldByNumber (plhs [1 ], 0 , 2 , outexplim );
334349
350+ mxArray * outinfinity = mxCreateDoubleMatrix (1 , 1 , mxREAL );
351+ double * outinfinityptr = mxGetData (outinfinity );
352+ outinfinityptr [0 ] = fpopts -> infinity ;
353+ mxSetFieldByNumber (plhs [1 ], 0 , 3 , outinfinity );
354+
335355 mxArray * outround = mxCreateDoubleMatrix (1 ,1 ,mxREAL );
336356 double * outroundptr = mxGetData (outround );
337357 outroundptr [0 ] = fpopts -> round ;
338- mxSetFieldByNumber (plhs [1 ], 0 , 3 , outround );
358+ mxSetFieldByNumber (plhs [1 ], 0 , 4 , outround );
339359
340360 mxArray * outsaturation = mxCreateDoubleMatrix (1 ,1 ,mxREAL );
341361 double * outsaturationptr = mxGetData (outsaturation );
342362 outsaturationptr [0 ] = fpopts -> saturation ;
343- mxSetFieldByNumber (plhs [1 ], 0 , 4 , outsaturation );
363+ mxSetFieldByNumber (plhs [1 ], 0 , 5 , outsaturation );
344364
345365 mxArray * outsubnormal = mxCreateDoubleMatrix (1 ,1 ,mxREAL );
346366 double * outsubnormalptr = mxGetData (outsubnormal );
347367 outsubnormalptr [0 ] = fpopts -> subnormal ;
348- mxSetFieldByNumber (plhs [1 ], 0 , 5 , outsubnormal );
368+ mxSetFieldByNumber (plhs [1 ], 0 , 6 , outsubnormal );
349369
350370 mxArray * outflip = mxCreateDoubleMatrix (1 ,1 ,mxREAL );
351371 double * outflipptr = mxGetData (outflip );
352372 outflipptr [0 ] = fpopts -> flip ;
353- mxSetFieldByNumber (plhs [1 ], 0 , 6 , outflip );
373+ mxSetFieldByNumber (plhs [1 ], 0 , 7 , outflip );
354374
355375 mxArray * outp = mxCreateDoubleMatrix (1 ,1 ,mxREAL );
356376 double * outpptr = mxGetData (outp );
357377 outpptr [0 ] = fpopts -> p ;
358- mxSetFieldByNumber (plhs [1 ], 0 , 7 , outp );
378+ mxSetFieldByNumber (plhs [1 ], 0 , 8 , outp );
359379
360380 }
361381 if (nlhs > 2 )
0 commit comments