@@ -156,15 +156,6 @@ std::vector<xla::PrimitiveType> AllOpSupportedTypes(HloOpcode opcode) {
156156 return result;
157157}
158158
159- std::vector<xla::PrimitiveType> AllIntegralDataTypes () {
160- std::vector<xla::PrimitiveType> result;
161- absl::c_copy_if (AllXlaDataTypes (), std::back_inserter (result),
162- [&](PrimitiveType data_type) {
163- return primitive_util::IsIntegralType (data_type);
164- });
165- return result;
166- }
167-
168159std::vector<PrecisionConfig::Algorithm> AllPrecisionAlgorithms () {
169160 std::vector<PrecisionConfig::Algorithm> algorithms;
170161 const tsl::protobuf::EnumDescriptor* algorithm_descriptor =
@@ -3099,54 +3090,6 @@ INSTANTIATE_TEST_SUITE_P(SortSuite, SortTest,
30993090 AllTestCombinationsForOpcodes ({HloOpcode::kSort }),
31003091 TritonSupportTestTypeAndOpcodeAndDeviceToString);
31013092
3102- using DynamicSliceTest = TritonSupportTestWithTypeAndDeviceParam;
3103-
3104- TEST_P (DynamicSliceTest, OperandTypes) {
3105- auto [data_type, cc] = GetParam ();
3106- const std::string kHloTestTemplate = R"(
3107- ENTRY triton_computation {
3108- operand = $0[256,256] parameter(0)
3109- start_1 = s32[] parameter(1)
3110- start_2 = s32[] constant(0)
3111- ROOT dynamic_slice_op = $0[32,256] dynamic-slice(operand, start_1, start_2),
3112- dynamic_slice_sizes={32,256}
3113- })" ;
3114- TF_ASSERT_OK_AND_ASSIGN (TestedInstruction ti, ParseTemplateAndGetInstruction (
3115- kHloTestTemplate , data_type,
3116- HloOpcode::kDynamicSlice ));
3117- RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {2 , 4 }, cc);
3118- }
3119-
3120- INSTANTIATE_TEST_SUITE_P (
3121- DynamicSliceSuite, DynamicSliceTest,
3122- ::testing::Combine (::testing::ValuesIn(AllXlaDataTypes()),
3123- ::testing::ValuesIn(AllDevicesToTest())),
3124- TritonSupportTestTypeAndDeviceToString);
3125-
3126- using DynamicSliceOffsetTypesTest = TritonSupportTestWithTypeAndDeviceParam;
3127-
3128- TEST_P (DynamicSliceOffsetTypesTest, DynamicSlice2D) {
3129- auto [data_type, cc] = GetParam ();
3130- const std::string kHloTestTemplate = R"(
3131- ENTRY triton_computation {
3132- operand = f32[256,256] parameter(0)
3133- start_1 = $0[] parameter(1)
3134- start_2 = $0[] parameter(2)
3135- ROOT dynamic_slice_op = f32[32,64] dynamic-slice(operand, start_1, start_2),
3136- dynamic_slice_sizes={32,64}
3137- })" ;
3138- TF_ASSERT_OK_AND_ASSIGN (TestedInstruction ti, ParseTemplateAndGetInstruction (
3139- kHloTestTemplate , data_type,
3140- HloOpcode::kDynamicSlice ));
3141- RunSupportTest (std::move (ti), /* output_tile_sizes=*/ {2 , 4 }, cc);
3142- }
3143-
3144- INSTANTIATE_TEST_SUITE_P (
3145- DynamicSliceOffsetTypesSuite, DynamicSliceOffsetTypesTest,
3146- ::testing::Combine (::testing::ValuesIn(AllIntegralDataTypes()),
3147- ::testing::ValuesIn(AllDevicesToTest())),
3148- TritonSupportTestTypeAndDeviceToString);
3149-
31503093using RecvOpsTest = TritonSupportTestWithTypeAndDeviceParam;
31513094
31523095TEST_P (RecvOpsTest, RecvAndRecvDone) {
@@ -3534,6 +3477,7 @@ constexpr std::array kUnsupportedOps = {
35343477 // clang-format off
35353478 // go/keep-sorted start
35363479 HloOpcode::kDynamicReshape ,
3480+ HloOpcode::kDynamicSlice ,
35373481 HloOpcode::kDynamicUpdateSlice ,
35383482 HloOpcode::kGather ,
35393483 HloOpcode::kRaggedDot ,
@@ -3593,7 +3537,6 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
35933537 ret.emplace (HloOpcode::kCustomCall );
35943538 ret.emplace (HloOpcode::kDomain );
35953539 ret.emplace (HloOpcode::kDot );
3596- ret.emplace (HloOpcode::kDynamicSlice );
35973540 ret.emplace (HloOpcode::kFft );
35983541 ret.emplace (HloOpcode::kFusion );
35993542 ret.emplace (HloOpcode::kGetDimensionSize );
0 commit comments