Skip to content

Commit be9da6d

Browse files
Fixed triton support_test on 0.7.1.
1 parent 52ecd40 commit be9da6d

File tree

3 files changed

+103
-107
lines changed

3 files changed

+103
-107
lines changed

xla/backends/gpu/codegen/triton/support.cc

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ absl::flat_hash_set<HloOpcode> TritonSupportedUnaryElementwiseOps(
9797

9898
if (element_type != PrimitiveType::F8E5M2 &&
9999
element_type != PrimitiveType::F8E4M3FN &&
100-
element_type != PrimitiveType::F8E4M3B11FNUZ &&
101100
element_type != PrimitiveType::F8E5M2FNUZ &&
102101
element_type != PrimitiveType::F8E4M3FNUZ) {
103102
ret.insert(HloOpcode::kNegate);
@@ -147,9 +146,11 @@ CodegenDecision IsTritonSupportedConversion(
147146
return error_message();
148147
}
149148

150-
bool is_f8_conversion =
151-
any_is(PrimitiveType::F8E4M3FN) && any_is(PrimitiveType::F8E5M2);
152-
bool is_f8 = any_is(PrimitiveType::F8E4M3FN) || any_is(PrimitiveType::F8E5M2);
149+
auto supported_fp8_types = {F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E5M2FNUZ};
150+
bool is_input_fp8 = absl::c_linear_search(supported_fp8_types, input);
151+
bool is_output_fp8 = absl::c_linear_search(supported_fp8_types, output);
152+
bool is_f8_conversion = is_input_fp8 && is_output_fp8;
153+
bool is_f8 = is_input_fp8 || is_output_fp8;
153154
bool is_f16_or_f32 = any_is(PrimitiveType::F16) ||
154155
any_is(PrimitiveType::BF16) ||
155156
any_is(PrimitiveType::F32);
@@ -179,7 +180,6 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
179180
if (element_type == PrimitiveType::S4 || element_type == PrimitiveType::U16 ||
180181
element_type == PrimitiveType::F8E5M2 ||
181182
element_type == PrimitiveType::F8E4M3FN ||
182-
element_type == PrimitiveType::F8E4M3B11FNUZ ||
183183
element_type == PrimitiveType::F8E5M2FNUZ ||
184184
element_type == PrimitiveType::F8E4M3FNUZ) {
185185
return {};
@@ -217,6 +217,7 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
217217
ret.insert(HloOpcode::kAtan2);
218218
ret.insert(HloOpcode::kPower);
219219
ret.insert(HloOpcode::kRemainder);
220+
ret.insert(HloOpcode::kDivide);
220221
}
221222

222223
return ret;
@@ -231,7 +232,6 @@ absl::flat_hash_set<HloOpcode> TritonSupportedTernaryElementwiseOps(
231232

232233
if (element_type == PrimitiveType::F8E5M2 ||
233234
element_type == PrimitiveType::F8E4M3FN ||
234-
element_type == PrimitiveType::F8E4M3B11FNUZ ||
235235
element_type == PrimitiveType::F8E5M2FNUZ ||
236236
element_type == PrimitiveType::F8E4M3FNUZ) {
237237
return {HloOpcode::kSelect};
@@ -263,8 +263,8 @@ CodegenDecision CanTritonHandleReduce(
263263
if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN ||
264264
reduce.shape().element_type() == PrimitiveType::F8E5M2 ||
265265
reduce.shape().element_type() == PrimitiveType::F8E5M2FNUZ ||
266-
reduce.shape().element_type() == PrimitiveType::F8E4M3FNUZ ||
267-
reduce.shape().element_type() == PrimitiveType::F8E4M3B11FNUZ) {
266+
reduce.shape().element_type() == PrimitiveType::F8E4M3FNUZ /*||
267+
reduce.shape().element_type() == PrimitiveType::F8E4M3B11FNUZ*/) {
268268
return CodegenDecision::Forbid(
269269
"F8E4M3FN and F8E5M2 are not supported for reductions.");
270270
}
@@ -358,15 +358,15 @@ CodegenDecision AreTypesSupportedByAlgUnsetDot(
358358
}
359359
}
360360

361-
if (input_type == F8E4M3B11FNUZ || result_type == F8E4M3B11FNUZ ||
362-
input_type == F64) {
361+
if (input_type == F8E4M3B11FNUZ || result_type == F8E4M3B11FNUZ) {
363362
if (std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
364363
return CodegenDecision::Forbid(
365364
"Dot operation for F8E4M3B11FNUZ is not supported on ROCM.");
366365
}
367366
}
368367

369-
auto supported_float_types = {BF16, F16, F32, F64, F8E5M2};
368+
auto supported_float_types = {BF16, F16, F32, F8E4M3FN, F8E5M2, F8E4M3FNUZ,
369+
F8E5M2FNUZ};
370370
if (absl::c_linear_search(supported_float_types, input_type)) {
371371
return CodegenDecision::Allow();
372372
}
@@ -375,13 +375,15 @@ CodegenDecision AreTypesSupportedByAlgUnsetDot(
375375
return CodegenDecision::Allow();
376376
}
377377

378-
auto partially_supported_signed_types = {S4, S8, S16, S32, S64};
378+
auto partially_supported_signed_types = {S8, S16, S32, S64};
379379
if (absl::c_linear_search(partially_supported_signed_types, input_type)) {
380-
if (absl::c_linear_search(partially_supported_signed_types, result_type)) {
380+
if ((absl::c_linear_search(partially_supported_signed_types, result_type) &&
381+
!std::holds_alternative<se::RocmComputeCapability>(gpu_version))) {
381382
return CodegenDecision::Forbid(
382383
"Dot operation does not support these signed integer types.");
383384
}
384-
if (primitive_util::IsFloatingPointType(result_type)) {
385+
if (primitive_util::IsFloatingPointType(result_type) &&
386+
!std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
385387
return CodegenDecision::Forbid(
386388
"Dot operation does not support floating point input and signed "
387389
"integer result types.");
@@ -435,9 +437,9 @@ CodegenDecision AreDotAlgorithmInputAndOutputConversionsSupported(
435437
return forbid("Unsupported BF16 on GPUs before Blackwell");
436438
}
437439

438-
if (allowed_operands_types_or->front() == PrimitiveType::F64 &&
440+
if (algorithm == PrecisionConfig::ALG_DOT_F64_F64_F64 &&
439441
std::holds_alternative<se::RocmComputeCapability>(gpu_version)) {
440-
return forbid("Unsupported result conversion");
442+
return forbid("Unsupported BF16 on Rocm");
441443
}
442444

443445
if (allowed_operands_types_or->size() != 1) {
@@ -679,6 +681,13 @@ CodegenDecision IsTritonSupportedInstructionImpl(
679681
return CodegenDecision::Forbid(
680682
"dynamic slice is supported but not enabled yet");
681683
case HloOpcode::kBitcast:
684+
if (ShapeUtil::ElementsIn(instr.operand(0)->shape()) !=
685+
ShapeUtil::ElementsIn(instr.shape())) {
686+
return CodegenDecision::Forbid(
687+
"only bitcasts with the same number of elements are supported");
688+
}
689+
return CodegenDecision(instr.shape().element_type() != S4,
690+
"S4 is not supported.");
682691
case HloOpcode::kBroadcast:
683692
case HloOpcode::kReshape:
684693
case HloOpcode::kSlice:

xla/backends/gpu/codegen/triton/support_test.cc

Lines changed: 77 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -527,16 +527,17 @@ ENTRY triton_computation {
527527
any_is(PrimitiveType::F8E4M3FN) && any_is(PrimitiveType::F8E5M2);
528528
}
529529

530-
// Crashes due to unsupported/unspecified rounding mode.
531-
crashes_on_failure |= (data_type_in == PrimitiveType::F64 &&
532-
(data_type_out == PrimitiveType::F8E4M3FN ||
533-
data_type_out == PrimitiveType::F8E5M2));
534-
535-
// Crashes due to unsupported conversion.
536-
crashes_on_failure |= (data_type_out == PrimitiveType::F64 &&
537-
(data_type_in == PrimitiveType::F8E4M3FN ||
538-
data_type_in == PrimitiveType::F8E5M2));
539-
530+
if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
531+
// Crashes due to unsupported/unspecified rounding mode.
532+
crashes_on_failure |= (data_type_in == PrimitiveType::F64 &&
533+
(data_type_out == PrimitiveType::F8E4M3FN ||
534+
data_type_out == PrimitiveType::F8E5M2));
535+
536+
// Crashes due to unsupported conversion.
537+
crashes_on_failure |= (data_type_out == PrimitiveType::F64 &&
538+
(data_type_in == PrimitiveType::F8E4M3FN ||
539+
data_type_in == PrimitiveType::F8E5M2));
540+
}
540541
RunSupportTest(
541542
std::move(ti), /*output_tile_sizes=*/{1, 32}, cc,
542543
crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail);
@@ -577,15 +578,22 @@ ENTRY triton_computation {
577578
data_type, opcode));
578579

579580
ExpectedFailMode fail_mode = ExpectedFailMode::kFail;
580-
if (opcode == HloOpcode::kDivide &&
581-
(data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 ||
582-
data_type == PrimitiveType::F8E5M2 ||
583-
data_type == PrimitiveType::F8E4M3FN ||
584-
data_type == PrimitiveType::F8E4M3B11FNUZ ||
585-
data_type == PrimitiveType::F8E5M2FNUZ ||
586-
data_type == PrimitiveType::F8E4M3FNUZ)) {
587-
fail_mode = ExpectedFailMode::kCrash;
588-
};
581+
if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
582+
if (opcode == HloOpcode::kDivide &&
583+
(data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 ||
584+
data_type == PrimitiveType::F8E5M2 ||
585+
data_type == PrimitiveType::F8E4M3FN)) {
586+
fail_mode = ExpectedFailMode::kCrash;
587+
}
588+
} else {
589+
if (((opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum) &&
590+
(data_type == PrimitiveType::F8E5M2 ||
591+
data_type == PrimitiveType::F8E4M3FN ||
592+
data_type == PrimitiveType::F8E5M2FNUZ ||
593+
data_type == PrimitiveType::F8E4M3FNUZ))) {
594+
fail_mode = ExpectedFailMode::kFailOrCrash;
595+
}
596+
}
589597

590598
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc, fail_mode);
591599
}
@@ -614,14 +622,21 @@ ENTRY triton_computation {
614622
data_type, opcode));
615623

616624
ExpectedFailMode fail_mode = ExpectedFailMode::kFail;
617-
if (opcode == HloOpcode::kDivide &&
618-
(data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 ||
619-
data_type == PrimitiveType::F8E5M2 ||
620-
data_type == PrimitiveType::F8E4M3FN ||
621-
data_type == PrimitiveType::F8E4M3B11FNUZ ||
622-
data_type == PrimitiveType::F8E5M2FNUZ ||
623-
data_type == PrimitiveType::F8E4M3FNUZ)) {
624-
fail_mode = ExpectedFailMode::kCrash;
625+
if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
626+
if (opcode == HloOpcode::kDivide &&
627+
(data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 ||
628+
data_type == PrimitiveType::F8E5M2 ||
629+
data_type == PrimitiveType::F8E4M3FN)) {
630+
fail_mode = ExpectedFailMode::kCrash;
631+
}
632+
} else {
633+
if (((opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum) &&
634+
(data_type == PrimitiveType::F8E5M2 ||
635+
data_type == PrimitiveType::F8E4M3FN ||
636+
data_type == PrimitiveType::F8E5M2FNUZ ||
637+
data_type == PrimitiveType::F8E4M3FNUZ))) {
638+
fail_mode = ExpectedFailMode::kFailOrCrash;
639+
}
625640
}
626641

627642
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc, fail_mode);
@@ -675,7 +690,20 @@ ENTRY triton_computation {
675690
TF_ASSERT_OK_AND_ASSIGN(
676691
TestedInstruction ti,
677692
ParseTemplateAndGetInstruction(hlo_text, data_type, opcode));
678-
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc);
693+
694+
bool skip_failure_branch_to_avoid_crash = false;
695+
if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
696+
skip_failure_branch_to_avoid_crash =
697+
(opcode == HloOpcode::kClamp || opcode == HloOpcode::kSelect) &&
698+
(data_type == PrimitiveType::F8E5M2 ||
699+
data_type == PrimitiveType::F8E4M3FN ||
700+
data_type == PrimitiveType::F8E5M2FNUZ ||
701+
data_type == PrimitiveType::F8E4M3FNUZ);
702+
}
703+
704+
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc,
705+
skip_failure_branch_to_avoid_crash ? ExpectedFailMode::kFailOrCrash
706+
: ExpectedFailMode::kFail);
679707
}
680708

681709
constexpr std::array kTestedOpsTernaryElementwise = {HloOpcode::kSelect,
@@ -718,7 +746,9 @@ ENTRY triton_computation {
718746
TestedInstruction ti,
719747
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
720748
bool crashes_on_failure = data_type == PrimitiveType::F8E4M3FN ||
721-
data_type == PrimitiveType::F8E5M2;
749+
data_type == PrimitiveType::F8E5M2 ||
750+
data_type == PrimitiveType::F8E5M2FNUZ ||
751+
data_type == PrimitiveType::F8E4M3FNUZ;
722752
RunSupportTest(
723753
std::move(ti), /*output_tile_sizes=*/{1}, cc,
724754
crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail);
@@ -742,7 +772,7 @@ ENTRY triton_computation {
742772
ParseTemplateAndGetInstruction(kHloTestTemplate, F32,
743773
HloOpcode::kReduce));
744774
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{3, 4},
745-
se::CudaComputeCapability::Ampere());
775+
CudaAmpereOrRocm());
746776
}
747777

748778
TEST_P(
@@ -789,7 +819,9 @@ ENTRY triton_computation {
789819
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
790820

791821
bool crashes_on_failure = data_type == PrimitiveType::F8E4M3FN ||
792-
data_type == PrimitiveType::F8E5M2;
822+
data_type == PrimitiveType::F8E5M2 ||
823+
data_type == PrimitiveType::F8E5M2FNUZ ||
824+
data_type == PrimitiveType::F8E4M3FNUZ;
793825
RunSupportTest(
794826
std::move(ti), /*output_tile_sizes=*/{1}, cc,
795827
crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail);
@@ -825,7 +857,7 @@ ENTRY triton_computation {
825857
}
826858

827859
TEST_F(ReduceTest, ReduceWithNonConstReduceValueIsSupportedWithTriton) {
828-
const se::GpuComputeCapability cc = se::CudaComputeCapability::Ampere();
860+
const se::GpuComputeCapability cc = CudaAmpereOrRocm();
829861
const std::string kHloTestTemplate = R"(
830862
add {
831863
Arg_0 = $0[] parameter(0)
@@ -905,12 +937,15 @@ ENTRY triton_computation {
905937

906938
// TODO(b/361526623): Reduce the cases where emitter crashes.
907939
ExpectedFailMode fail_mode = ExpectedFailMode::kFail;
908-
if (opcode == HloOpcode::kDivide && (data_type == BF16 || data_type == F16)) {
940+
if (opcode == HloOpcode::kDivide && (data_type == BF16 ||
941+
data_type == F16)) {
909942
fail_mode = ExpectedFailMode::kCrash;
910943
}
911-
if (data_type == F8E4M3FN || data_type == F8E5M2) {
944+
if (data_type == F8E4M3FN || data_type == F8E5M2 || data_type == PrimitiveType::F8E5M2FNUZ ||
945+
data_type == PrimitiveType::F8E4M3FNUZ) {
912946
fail_mode = ExpectedFailMode::kFailOrCrash;
913947
}
948+
914949
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc, fail_mode);
915950
}
916951

@@ -1854,12 +1889,12 @@ TEST_P(DotTypesTest, Dot) {
18541889
fail_mode = ExpectedFailMode::kFailOrCrash;
18551890
}
18561891
}
1857-
if (absl::c_linear_search(std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, input_type) ||
1858-
absl::c_linear_search(std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, result_type) ||
1859-
input_type == F64) {
1860-
if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
1861-
// Hits llvm::report_fatal_error during Triton compilation.
1862-
fail_mode = ExpectedFailMode::kFailOrCrash;
1892+
if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
1893+
if (absl::c_linear_search(std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, input_type) ||
1894+
absl::c_linear_search(std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, result_type) ||
1895+
input_type == F64) {
1896+
// Hits llvm::report_fatal_error during Triton compilation.
1897+
fail_mode = ExpectedFailMode::kFailOrCrash;
18631898
}
18641899
}
18651900

@@ -2169,49 +2204,6 @@ ENTRY triton_computation {
21692204
CudaAmpereOrRocm());
21702205
}
21712206

2172-
TEST_F(DotTest, SparsityConfiguration) {
2173-
// Note that support rejects this HLO as u16 is not supported.
2174-
const std::string kHloTestTemplate = R"(
2175-
flhs {
2176-
ROOT result = $0[128,128] parameter(0)
2177-
}
2178-
2179-
frhs {
2180-
ROOT result = $0[256,512] parameter(0)
2181-
}
2182-
2183-
ENTRY triton_computation {
2184-
p0 = $0[128,128] parameter(0)
2185-
p1 = $0[256,512] parameter(1)
2186-
lhs = $0[128,128] fusion(p0), kind=kCustom, calls=flhs, backend_config={
2187-
"fusion_backend_config":{
2188-
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
2189-
"output_tiles":[{"sizes":["16", "64"]}]
2190-
}
2191-
}
2192-
}
2193-
rhs = $0[256,512] fusion(p1), kind=kCustom, calls=frhs, backend_config={
2194-
"fusion_backend_config":{
2195-
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
2196-
"output_tiles":[{"sizes":["64", "32"]}]
2197-
}
2198-
}
2199-
}
2200-
meta = u16[128,16] parameter(2)
2201-
ROOT result = $0[128,512] dot(lhs, rhs, meta),
2202-
lhs_contracting_dims={1},
2203-
rhs_contracting_dims={0},
2204-
sparsity=L.1@2:4
2205-
}
2206-
)";
2207-
TF_ASSERT_OK_AND_ASSIGN(
2208-
TestedInstruction ti,
2209-
ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot,
2210-
/* use_nested_gemm_fusions=*/true));
2211-
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32},
2212-
CudaAmpereOrRocm());
2213-
}
2214-
22152207
class DotPrecisionTest
22162208
: public DotTest,
22172209
public ::testing::WithParamInterface<
@@ -2272,14 +2264,9 @@ ENTRY triton_computation {
22722264
if (absl::c_linear_search(std::vector{F8E5M2, F8E4M3FN, S8}, data_type)) {
22732265
fail_mode = ExpectedFailMode::kFailOrCrash;
22742266
}
2275-
if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
2276-
if (data_type == F64) {
2277-
fail_mode = ExpectedFailMode::kFailOrCrash;
2278-
}
2279-
}
22802267
if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
22812268
if (absl::c_linear_search(std::vector{F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN,
2282-
F64}, data_type)) {
2269+
S8, S16, S32, S64}, data_type)) {
22832270
fail_mode = ExpectedFailMode::kFailOrCrash;
22842271
}
22852272
}
@@ -2380,7 +2367,7 @@ ENTRY triton_computation {
23802367
if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
23812368
if (absl::c_linear_search(std::vector{F8E4M3FN, F8E5M2FNUZ, F8E4M3FNUZ, F64},
23822369
data_type) ||
2383-
(absl::c_linear_search(std::vector{F16, S64, S32, S16, BF16, F32},
2370+
(absl::c_linear_search(std::vector{S64, S32, S16, BF16, F16, F32},
23842371
data_type) &&
23852372
algorithm == xla::PrecisionConfig::ALG_DOT_F64_F64_F64)) {
23862373
fail_mode = ExpectedFailMode::kFailOrCrash;

xla/service/gpu/gpu_device_info_for_tests.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ stream_executor::DeviceDescription TestGpuDeviceInfo::AMDMI210DeviceInfo() {
8080
b.set_threads_per_block_limit(1024);
8181
b.set_threads_per_warp(64);
8282
b.set_shared_memory_per_block(64 * 1024);
83-
b.set_shared_memory_per_block_optin(0);
83+
b.set_shared_memory_per_block_optin(64 * 1024);
8484
b.set_shared_memory_per_core(64 * 1024);
8585
b.set_threads_per_core_limit(2048);
8686
b.set_core_count(104);

0 commit comments

Comments
 (0)