@@ -527,16 +527,17 @@ ENTRY triton_computation {
527527 any_is (PrimitiveType::F8E4M3FN) && any_is (PrimitiveType::F8E5M2);
528528 }
529529
530- // Crashes due to unsupported/unspecified rounding mode.
531- crashes_on_failure |= (data_type_in == PrimitiveType::F64 &&
532- (data_type_out == PrimitiveType::F8E4M3FN ||
533- data_type_out == PrimitiveType::F8E5M2));
534-
535- // Crashes due to unsupported conversion.
536- crashes_on_failure |= (data_type_out == PrimitiveType::F64 &&
537- (data_type_in == PrimitiveType::F8E4M3FN ||
538- data_type_in == PrimitiveType::F8E5M2));
539-
530+ if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
531+ // Crashes due to unsupported/unspecified rounding mode.
532+ crashes_on_failure |= (data_type_in == PrimitiveType::F64 &&
533+ (data_type_out == PrimitiveType::F8E4M3FN ||
534+ data_type_out == PrimitiveType::F8E5M2));
535+
536+ // Crashes due to unsupported conversion.
537+ crashes_on_failure |= (data_type_out == PrimitiveType::F64 &&
538+ (data_type_in == PrimitiveType::F8E4M3FN ||
539+ data_type_in == PrimitiveType::F8E5M2));
540+ }
540541 RunSupportTest (
541542 std::move (ti), /* output_tile_sizes=*/ {1 , 32 }, cc,
542543 crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail );
@@ -577,15 +578,22 @@ ENTRY triton_computation {
577578 data_type, opcode));
578579
579580 ExpectedFailMode fail_mode = ExpectedFailMode::kFail ;
580- if (opcode == HloOpcode::kDivide &&
581- (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 ||
582- data_type == PrimitiveType::F8E5M2 ||
583- data_type == PrimitiveType::F8E4M3FN ||
584- data_type == PrimitiveType::F8E4M3B11FNUZ ||
585- data_type == PrimitiveType::F8E5M2FNUZ ||
586- data_type == PrimitiveType::F8E4M3FNUZ)) {
587- fail_mode = ExpectedFailMode::kCrash ;
588- };
581+ if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
582+ if (opcode == HloOpcode::kDivide &&
583+ (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 ||
584+ data_type == PrimitiveType::F8E5M2 ||
585+ data_type == PrimitiveType::F8E4M3FN)) {
586+ fail_mode = ExpectedFailMode::kCrash ;
587+ }
588+ } else {
589+ if (((opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum ) &&
590+ (data_type == PrimitiveType::F8E5M2 ||
591+ data_type == PrimitiveType::F8E4M3FN ||
592+ data_type == PrimitiveType::F8E5M2FNUZ ||
593+ data_type == PrimitiveType::F8E4M3FNUZ))) {
594+ fail_mode = ExpectedFailMode::kFailOrCrash ;
595+ }
596+ }
589597
590598 RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {1 , 32 }, cc, fail_mode);
591599}
@@ -614,14 +622,21 @@ ENTRY triton_computation {
614622 data_type, opcode));
615623
616624 ExpectedFailMode fail_mode = ExpectedFailMode::kFail ;
617- if (opcode == HloOpcode::kDivide &&
618- (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 ||
619- data_type == PrimitiveType::F8E5M2 ||
620- data_type == PrimitiveType::F8E4M3FN ||
621- data_type == PrimitiveType::F8E4M3B11FNUZ ||
622- data_type == PrimitiveType::F8E5M2FNUZ ||
623- data_type == PrimitiveType::F8E4M3FNUZ)) {
624- fail_mode = ExpectedFailMode::kCrash ;
625+ if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
626+ if (opcode == HloOpcode::kDivide &&
627+ (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 ||
628+ data_type == PrimitiveType::F8E5M2 ||
629+ data_type == PrimitiveType::F8E4M3FN)) {
630+ fail_mode = ExpectedFailMode::kCrash ;
631+ }
632+ } else {
633+ if (((opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum ) &&
634+ (data_type == PrimitiveType::F8E5M2 ||
635+ data_type == PrimitiveType::F8E4M3FN ||
636+ data_type == PrimitiveType::F8E5M2FNUZ ||
637+ data_type == PrimitiveType::F8E4M3FNUZ))) {
638+ fail_mode = ExpectedFailMode::kFailOrCrash ;
639+ }
625640 }
626641
627642 RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {}, cc, fail_mode);
@@ -675,7 +690,20 @@ ENTRY triton_computation {
675690 TF_ASSERT_OK_AND_ASSIGN (
676691 TestedInstruction ti,
677692 ParseTemplateAndGetInstruction (hlo_text, data_type, opcode));
678- RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {1 , 32 }, cc);
693+
694+ bool skip_failure_branch_to_avoid_crash = false ;
695+ if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
696+ skip_failure_branch_to_avoid_crash =
697+ (opcode == HloOpcode::kClamp || opcode == HloOpcode::kSelect ) &&
698+ (data_type == PrimitiveType::F8E5M2 ||
699+ data_type == PrimitiveType::F8E4M3FN ||
700+ data_type == PrimitiveType::F8E5M2FNUZ ||
701+ data_type == PrimitiveType::F8E4M3FNUZ);
702+ }
703+
704+ RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {1 , 32 }, cc,
705+ skip_failure_branch_to_avoid_crash ? ExpectedFailMode::kFailOrCrash
706+ : ExpectedFailMode::kFail );
679707}
680708
681709constexpr std::array kTestedOpsTernaryElementwise = {HloOpcode::kSelect ,
@@ -718,7 +746,9 @@ ENTRY triton_computation {
718746 TestedInstruction ti,
719747 ParseTemplateAndGetInstruction (kHloTestTemplate , data_type, opcode));
720748 bool crashes_on_failure = data_type == PrimitiveType::F8E4M3FN ||
721- data_type == PrimitiveType::F8E5M2;
749+ data_type == PrimitiveType::F8E5M2 ||
750+ data_type == PrimitiveType::F8E5M2FNUZ ||
751+ data_type == PrimitiveType::F8E4M3FNUZ;
722752 RunSupportTest (
723753 std::move (ti), /* output_tile_sizes=*/ {1 }, cc,
724754 crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail );
@@ -742,7 +772,7 @@ ENTRY triton_computation {
742772 ParseTemplateAndGetInstruction (kHloTestTemplate , F32,
743773 HloOpcode::kReduce ));
744774 RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {3 , 4 },
745- se::CudaComputeCapability::Ampere ());
775+ CudaAmpereOrRocm ());
746776}
747777
748778TEST_P (
@@ -789,7 +819,9 @@ ENTRY triton_computation {
789819 ParseTemplateAndGetInstruction (kHloTestTemplate , data_type, opcode));
790820
791821 bool crashes_on_failure = data_type == PrimitiveType::F8E4M3FN ||
792- data_type == PrimitiveType::F8E5M2;
822+ data_type == PrimitiveType::F8E5M2 ||
823+ data_type == PrimitiveType::F8E5M2FNUZ ||
824+ data_type == PrimitiveType::F8E4M3FNUZ;
793825 RunSupportTest (
794826 std::move (ti), /* output_tile_sizes=*/ {1 }, cc,
795827 crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail );
@@ -825,7 +857,7 @@ ENTRY triton_computation {
825857}
826858
827859TEST_F (ReduceTest, ReduceWithNonConstReduceValueIsSupportedWithTriton) {
828- const se::GpuComputeCapability cc = se::CudaComputeCapability::Ampere ();
860+ const se::GpuComputeCapability cc = CudaAmpereOrRocm ();
829861 const std::string kHloTestTemplate = R"(
830862add {
831863 Arg_0 = $0[] parameter(0)
@@ -905,12 +937,15 @@ ENTRY triton_computation {
905937
906938 // TODO(b/361526623): Reduce the cases where emitter crashes.
907939 ExpectedFailMode fail_mode = ExpectedFailMode::kFail ;
908- if (opcode == HloOpcode::kDivide && (data_type == BF16 || data_type == F16)) {
940+ if (opcode == HloOpcode::kDivide && (data_type == BF16 ||
941+ data_type == F16)) {
909942 fail_mode = ExpectedFailMode::kCrash ;
910943 }
911- if (data_type == F8E4M3FN || data_type == F8E5M2) {
944+ if (data_type == F8E4M3FN || data_type == F8E5M2 || data_type == PrimitiveType::F8E5M2FNUZ ||
945+ data_type == PrimitiveType::F8E4M3FNUZ) {
912946 fail_mode = ExpectedFailMode::kFailOrCrash ;
913947 }
948+
914949 RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {1 }, cc, fail_mode);
915950}
916951
@@ -1854,12 +1889,12 @@ TEST_P(DotTypesTest, Dot) {
18541889 fail_mode = ExpectedFailMode::kFailOrCrash ;
18551890 }
18561891 }
1857- if (absl::c_linear_search (std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, input_type) ||
1858- absl::c_linear_search (std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, result_type ) ||
1859- input_type == F64) {
1860- if (std::holds_alternative<se::RocmComputeCapability>(cc) ) {
1861- // Hits llvm::report_fatal_error during Triton compilation.
1862- fail_mode = ExpectedFailMode::kFailOrCrash ;
1892+ if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
1893+ if ( absl::c_linear_search (std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, input_type ) ||
1894+ absl::c_linear_search (std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, result_type) ||
1895+ input_type == F64 ) {
1896+ // Hits llvm::report_fatal_error during Triton compilation.
1897+ fail_mode = ExpectedFailMode::kFailOrCrash ;
18631898 }
18641899 }
18651900
@@ -2169,49 +2204,6 @@ ENTRY triton_computation {
21692204 CudaAmpereOrRocm ());
21702205}
21712206
2172- TEST_F (DotTest, SparsityConfiguration) {
2173- // Note that support rejects this HLO as u16 is not supported.
2174- const std::string kHloTestTemplate = R"(
2175- flhs {
2176- ROOT result = $0[128,128] parameter(0)
2177- }
2178-
2179- frhs {
2180- ROOT result = $0[256,512] parameter(0)
2181- }
2182-
2183- ENTRY triton_computation {
2184- p0 = $0[128,128] parameter(0)
2185- p1 = $0[256,512] parameter(1)
2186- lhs = $0[128,128] fusion(p0), kind=kCustom, calls=flhs, backend_config={
2187- "fusion_backend_config":{
2188- "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
2189- "output_tiles":[{"sizes":["16", "64"]}]
2190- }
2191- }
2192- }
2193- rhs = $0[256,512] fusion(p1), kind=kCustom, calls=frhs, backend_config={
2194- "fusion_backend_config":{
2195- "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
2196- "output_tiles":[{"sizes":["64", "32"]}]
2197- }
2198- }
2199- }
2200- meta = u16[128,16] parameter(2)
2201- ROOT result = $0[128,512] dot(lhs, rhs, meta),
2202- lhs_contracting_dims={1},
2203- rhs_contracting_dims={0},
2204- sparsity=L.1@2:4
2205- }
2206- )" ;
2207- TF_ASSERT_OK_AND_ASSIGN (
2208- TestedInstruction ti,
2209- ParseTemplateAndGetInstruction (kHloTestTemplate , F32, HloOpcode::kDot ,
2210- /* use_nested_gemm_fusions=*/ true ));
2211- RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {16 , 32 },
2212- CudaAmpereOrRocm ());
2213- }
2214-
22152207class DotPrecisionTest
22162208 : public DotTest,
22172209 public ::testing::WithParamInterface<
@@ -2272,14 +2264,9 @@ ENTRY triton_computation {
22722264 if (absl::c_linear_search (std::vector{F8E5M2, F8E4M3FN, S8}, data_type)) {
22732265 fail_mode = ExpectedFailMode::kFailOrCrash ;
22742266 }
2275- if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
2276- if (data_type == F64) {
2277- fail_mode = ExpectedFailMode::kFailOrCrash ;
2278- }
2279- }
22802267 if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
22812268 if (absl::c_linear_search (std::vector{F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN,
2282- F64 }, data_type)) {
2269+ S8, S16, S32, S64 }, data_type)) {
22832270 fail_mode = ExpectedFailMode::kFailOrCrash ;
22842271 }
22852272 }
@@ -2380,7 +2367,7 @@ ENTRY triton_computation {
23802367 if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
23812368 if (absl::c_linear_search (std::vector{F8E4M3FN, F8E5M2FNUZ, F8E4M3FNUZ, F64},
23822369 data_type) ||
2383- (absl::c_linear_search (std::vector{F16, S64, S32, S16, BF16, F32},
2370+ (absl::c_linear_search (std::vector{S64, S32, S16, BF16, F16 , F32},
23842371 data_type) &&
23852372 algorithm == xla::PrecisionConfig::ALG_DOT_F64_F64_F64)) {
23862373 fail_mode = ExpectedFailMode::kFailOrCrash ;
0 commit comments