@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- #include < array>
1716#include < cmath>
1817#include < cstdint>
1918#include < memory>
@@ -22,16 +21,13 @@ limitations under the License.
2221#include < utility>
2322#include < vector>
2423
25- #include " absl/container/flat_hash_map.h"
26- #include " absl/functional/any_invocable.h"
2724#include " absl/log/check.h"
2825#include " absl/log/log.h"
2926#include " absl/status/status.h"
3027#include " absl/status/statusor.h"
3128#include " absl/strings/match.h"
3229#include " absl/strings/str_cat.h"
3330#include " absl/strings/str_format.h"
34- #include " absl/strings/str_replace.h"
3531#include " absl/strings/string_view.h"
3632#include " absl/types/span.h"
3733#include " xla/array.h"
@@ -81,13 +77,6 @@ bool IsAsync(const HloInstruction* inst) {
8177
8278class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
8379 public:
84- CollectiveOpsTestE2E () {
85- replacements_[kF8E4M3DatatypePlaceholder ] =
86- Capability ().IsCuda () ? " f8e4m3fn" : " f8e4m3fnuz" ;
87- replacements_[kF8E5M2DatatypePlaceholder ] =
88- Capability ().IsCuda () ? " f8e5m2" : " f8e5m2fnuz" ;
89- }
90-
9180 bool HasFp8Support () {
9281 if (Capability ().IsCuda ()) {
9382 return Capability ().cuda_compute_capability ()->IsAtLeast (8 , 9 );
@@ -123,13 +112,6 @@ class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
123112 EXPECT_EQ (gemm_op->custom_call_target (), " __cublas$lt$matmul$f8" );
124113 }
125114 }
126-
127- protected:
128- absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
129-
130- private:
131- static constexpr const char * kF8E4M3DatatypePlaceholder {" <<F8E4M3>>" };
132- static constexpr const char * kF8E5M2DatatypePlaceholder {" <<F8E5M2>>" };
133115};
134116
135117class AsyncCollectiveOps : public CollectiveOpsWithFlagsBase ,
@@ -1439,9 +1421,8 @@ ENTRY main {
14391421
14401422 // Disable the dot merger pass which can prevent the creation of FP8 GEMM
14411423 // Custom Calls.
1442- CollectiveOpsCompareWindowedNonWindowed (
1443- absl::StrReplaceAll (kModuleReplicatedStr , replacements_),
1444- /* disable_dot_merger=*/ true );
1424+ CollectiveOpsCompareWindowedNonWindowed (kModuleReplicatedStr ,
1425+ /* disable_dot_merger=*/ true );
14451426
14461427 // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer
14471428 // architectures.
@@ -1451,8 +1432,7 @@ ENTRY main {
14511432 opts.set_xla_gpu_graph_min_graph_size (200 );
14521433 opts.set_xla_gpu_enable_triton_gemm (false );
14531434 opts.add_xla_disable_hlo_passes (" dot-merger" );
1454- CollectiveOpsVerifyF8Matmul (
1455- absl::StrReplaceAll (kModuleReplicatedStr , replacements_), opts);
1435+ CollectiveOpsVerifyF8Matmul (kModuleReplicatedStr , opts);
14561436}
14571437
14581438TEST_F (CollectiveOpsTestE2EWindowedNonWindowed,
@@ -1693,8 +1673,7 @@ ENTRY entry {
16931673 GetModuleConfigForTest (/* replica_count=*/ kNumReplicas );
16941674 auto opts = GetDebugOptionsForTest ();
16951675 opts.set_xla_gpu_enable_triton_gemm (false );
1696- CollectiveOpsVerifyF8Matmul (
1697- absl::StrReplaceAll (kModuleReplicatedStr , replacements_), opts);
1676+ CollectiveOpsVerifyF8Matmul (kModuleReplicatedStr , opts);
16981677}
16991678
17001679// E2E tests comparing the results with and without pipelining of collectives.
0 commit comments