@@ -1073,31 +1073,26 @@ typedef void (fintrinsic_op1)(unsigned, jl_value_t*, void*, void*);
10731073static inline jl_value_t * jl_fintrinsic_1 (jl_value_t * ty , jl_value_t * a , const char * name , fintrinsic_op1 * bfloatop , fintrinsic_op1 * halfop , fintrinsic_op1 * floatop , fintrinsic_op1 * doubleop )
10741074{
10751075 jl_task_t * ct = jl_current_task ;
1076- if (!jl_is_primitivetype (jl_typeof (a )))
1076+ jl_datatype_t * aty = (jl_datatype_t * )jl_typeof (a );
1077+ if (!jl_is_primitivetype (aty ))
10771078 jl_errorf ("%s: value is not a primitive type" , name );
10781079 if (!jl_is_primitivetype (ty ))
10791080 jl_errorf ("%s: type is not a primitive type" , name );
10801081 unsigned sz2 = jl_datatype_size (ty );
10811082 jl_value_t * newv = jl_gc_alloc (ct -> ptls , sz2 , ty );
10821083 void * pa = jl_data_ptr (a ), * pr = jl_data_ptr (newv );
1083- unsigned sz = jl_datatype_size (jl_typeof (a ));
1084- switch (sz ) {
1085- /* choose the right size c-type operation based on the input */
1086- case 2 :
1087- if (jl_typeof (a ) == (jl_value_t * )jl_float16_type )
1088- halfop (sz2 * host_char_bit , ty , pa , pr );
1089- else /*if (jl_typeof(a) == (jl_value_t*)jl_bfloat16_type)*/
1090- bfloatop (sz2 * host_char_bit , ty , pa , pr );
1091- break ;
1092- case 4 :
1084+
1085+ if (aty == jl_float16_type )
1086+ halfop (sz2 * host_char_bit , ty , pa , pr );
1087+ else if (aty == jl_bfloat16_type )
1088+ bfloatop (sz2 * host_char_bit , ty , pa , pr );
1089+ else if (aty == jl_float32_type )
10931090 floatop (sz2 * host_char_bit , ty , pa , pr );
1094- break ;
1095- case 8 :
1091+ else if (aty == jl_float64_type )
10961092 doubleop (sz2 * host_char_bit , ty , pa , pr );
1097- break ;
1098- default :
1099- jl_errorf ("%s: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64" , name );
1100- }
1093+ else
1094+ jl_errorf ("%s: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64" , name );
1095+
11011096 return newv ;
11021097}
11031098
@@ -1273,30 +1268,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
12731268{ \
12741269 jl_task_t *ct = jl_current_task; \
12751270 jl_value_t *ty = jl_typeof(a); \
1271+ jl_datatype_t *aty = (jl_datatype_t *)ty; \
12761272 if (jl_typeof(b) != ty) \
12771273 jl_error(#name ": types of a and b must match"); \
12781274 if (!jl_is_primitivetype(ty)) \
12791275 jl_error(#name ": values are not primitive types"); \
12801276 int sz = jl_datatype_size(ty); \
12811277 jl_value_t *newv = jl_gc_alloc(ct->ptls, sz, ty); \
12821278 void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pr = jl_data_ptr(newv); \
1283- switch (sz) { \
1284- /* choose the right size c-type operation */ \
1285- case 2 : \
1286- if ((jl_datatype_t * )ty == jl_float16_type ) \
1287- jl_ ##name ##16(16, pa, pb, pr); \
1288- else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1289- jl_ ##name ##bf16(16, pa, pb, pr); \
1290- break; \
1291- case 4: \
1279+ if (aty == jl_float16_type) \
1280+ jl_##name##16(16, pa, pb, pr); \
1281+ else if (aty == jl_bfloat16_type) \
1282+ jl_##name##bf16(16, pa, pb, pr); \
1283+ else if (aty == jl_float32_type) \
12921284 jl_##name##32(32, pa, pb, pr); \
1293- break; \
1294- case 8: \
1285+ else if (aty == jl_float64_type) \
12951286 jl_##name##64(64, pa, pb, pr); \
1296- break; \
1297- default: \
1298- jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
1299- } \
1287+ else \
1288+ jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
13001289 return newv; \
13011290}
13021291
@@ -1308,30 +1297,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
13081297JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
13091298{ \
13101299 jl_value_t *ty = jl_typeof(a); \
1300+ jl_datatype_t *aty = (jl_datatype_t *)ty; \
13111301 if (jl_typeof(b) != ty) \
13121302 jl_error(#name ": types of a and b must match"); \
13131303 if (!jl_is_primitivetype(ty)) \
13141304 jl_error(#name ": values are not primitive types"); \
13151305 void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b); \
1316- int sz = jl_datatype_size(ty); \
13171306 int cmp; \
1318- switch (sz) { \
1319- /* choose the right size c-type operation */ \
1320- case 2 : \
1321- if ((jl_datatype_t * )ty == jl_float16_type ) \
1322- cmp = jl_ ##name ##16(16, pa, pb); \
1323- else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1324- cmp = jl_ ##name ##bf16(16, pa, pb); \
1325- break; \
1326- case 4: \
1307+ if (aty == jl_float16_type) \
1308+ cmp = jl_##name##16(16, pa, pb); \
1309+ else if (aty == jl_bfloat16_type) \
1310+ cmp = jl_##name##bf16(16, pa, pb); \
1311+ else if (aty == jl_float32_type) \
13271312 cmp = jl_##name##32(32, pa, pb); \
1328- break; \
1329- case 8: \
1313+ else if (aty == jl_float64_type) \
13301314 cmp = jl_##name##64(64, pa, pb); \
1331- break; \
1332- default: \
1333- jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64"); \
1334- } \
1315+ else \
1316+ jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
1317+ \
13351318 return cmp ? jl_true : jl_false; \
13361319}
13371320
@@ -1344,30 +1327,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
13441327{ \
13451328 jl_task_t *ct = jl_current_task; \
13461329 jl_value_t *ty = jl_typeof(a); \
1330+ jl_datatype_t *aty = (jl_datatype_t *)ty; \
13471331 if (jl_typeof(b) != ty || jl_typeof(c) != ty) \
13481332 jl_error(#name ": types of a, b, and c must match"); \
13491333 if (!jl_is_primitivetype(ty)) \
13501334 jl_error(#name ": values are not primitive types"); \
13511335 int sz = jl_datatype_size(ty); \
13521336 jl_value_t *newv = jl_gc_alloc(ct->ptls, sz, ty); \
13531337 void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pc = jl_data_ptr(c), *pr = jl_data_ptr(newv); \
1354- switch (sz) { \
1355- /* choose the right size c-type operation */ \
1356- case 2 : \
1357- if ((jl_datatype_t * )ty == jl_float16_type ) \
1338+ if (aty == jl_float16_type) \
13581339 jl_##name##16(16, pa, pb, pc, pr); \
1359- else /* if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1340+ else if (aty == jl_bfloat16_type) \
13601341 jl_##name##bf16(16, pa, pb, pc, pr); \
1361- break; \
1362- case 4: \
1342+ else if (aty == jl_float32_type) \
13631343 jl_##name##32(32, pa, pb, pc, pr); \
1364- break; \
1365- case 8: \
1344+ else if (aty == jl_float64_type) \
13661345 jl_##name##64(64, pa, pb, pc, pr); \
1367- break; \
1368- default: \
1369- jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
1370- } \
1346+ else \
1347+ jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
13711348 return newv; \
13721349}
13731350
@@ -1661,7 +1638,7 @@ static inline void fptrunc(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void
16611638 fptrunc_convert (float64 , bfloat16 );
16621639 fptrunc_convert (float64 , float32 );
16631640 else
1664- jl_error ("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64 " );
1641+ jl_error ("fptrunc: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64 " );
16651642#undef fptrunc_convert
16661643}
16671644
@@ -1685,7 +1662,7 @@ static inline void fpext(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void *
16851662 fpext_convert (bfloat16 , float64 );
16861663 fpext_convert (float32 , float64 );
16871664 else
1688- jl_error ("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64 " );
1665+ jl_error ("fptrunc: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64 " );
16891666#undef fpext_convert
16901667}
16911668
0 commit comments