@@ -336,15 +336,19 @@ func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: v
336336
337337// -----
338338
339- func.func @scaling_truncf_propagate_rounding_mode (%arg0 : vector <4 xf16 >, %arg1: vector <4 xf 8 E 8 M 0 FNU >) -> vector <4 xf6 E3 M2 FN> {
340- %0 = arith.scaling_truncf %arg0 , %arg1 to_nearest_even : vector <4 xf16 >, vector <4 xf 8 E 8 M 0 FNU > to vector <4 xf6 E3 M2 FN>
339+ func.func @scaling_truncf_propagate_rounding_mode_fast_math (%arg0 : vector <4 xf16 >, %arg1: vector <4 x f16 >) -> vector <4 xf6 E3 M2 FN> {
340+ %0 = arith.scaling_truncf %arg0 , %arg1 to_nearest_even fastmath < fast > : vector <4 xf16 >, vector <4 x f16 > to vector <4 xf6 E3 M2 FN>
341341 return %0 : vector <4 xf6 E3 M2 FN>
342342}
343- // SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode
344- // SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even : vector<4xf16> to vector<4xf6E3M2FN>
343+ // SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
344+ // SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
345+ // SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
346+ // SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
347+ // SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
345348// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
346349
347350// -----
351+
348352func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales (%arg0: f16 , %arg1 : f16 ) -> f4E2M1FN {
349353 %0 = arith.scaling_truncf %arg0 , %arg1 : f16 , f16 to f4E2M1FN
350354 return %0 : f4E2M1FN
@@ -353,6 +357,15 @@ func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f
353357// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN
354358// SCHECK: return
355359
360+ // -----
361+ func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales (%arg0: vector <4 xf16 >, %arg1 : vector <4 xf16 >) -> vector <4 xf4 E2 M1 FN> {
362+ %0 = arith.scaling_truncf %arg0 , %arg1 : vector <4 xf16 >, vector <4 xf16 > to vector <4 xf4 E2 M1 FN>
363+ return %0 : vector <4 xf4 E2 M1 FN>
364+ }
365+ // SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
366+ // SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
367+ // SCHECK: return
368+
356369// -----
357370
358371func.func @invalid_scaling_truncf_to_f4E2M1FN (%arg0: f16 , %arg1 : f8E5M2FNUZ ) -> f4E2M1FN {
@@ -507,6 +520,34 @@ func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector
507520
508521// -----
509522
523+ func.func @scaling_extf_vector_to_f32_using_f16_scales (%arg0: vector <4 xf4 E2 M1 FN>, %arg1 : vector <4 xf16 >) -> vector <4 xf32 > {
524+ %0 = arith.scaling_extf %arg0 , %arg1 : vector <4 xf4 E2 M1 FN>, vector <4 xf16 > to vector <4 xf32 >
525+ return %0 : vector <4 xf32 >
526+ }
527+
528+ // SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
529+ // SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
530+ // SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
531+ // SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
532+ // SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
533+ // SCHECK: return %[[RESULT]]
534+
535+ // -----
536+
537+ func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath (%arg0: vector <4 xf4 E2 M1 FN>, %arg1 : vector <4 xf16 >) -> vector <4 xf32 > {
538+ %0 = arith.scaling_extf %arg0 , %arg1 fastmath <fast > : vector <4 xf4 E2 M1 FN>, vector <4 xf16 > to vector <4 xf32 >
539+ return %0 : vector <4 xf32 >
540+ }
541+
542+ // SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
543+ // SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
544+ // SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
545+ // SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
546+ // SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
547+ // SCHECK: return %[[RESULT]]
548+
549+ // -----
550+
510551func.func @maxsi (%a: i32 , %b: i32 ) -> i32 {
511552 %result = arith.maxsi %a , %b : i32
512553 return %result : i32
0 commit comments