@@ -41,6 +41,8 @@ cc_library(
4141 ] + if_cuda ([
4242 ":xla_gpu_device" ,
4343 ":xla_gpu_jit" ,
44+ ":jit_cuda_graph_mode_passes" ,
45+ "//tensorflow/compiler/jit/kernels:cuda_graph_mode_ops" ,
4446 ]),
4547 alwayslink = 1 ,
4648)
@@ -327,6 +329,17 @@ cc_library(
327329 alwayslink = 1 ,
328330)
329331
332+ cc_library (
333+ name = "jit_cuda_graph_mode_passes" ,
334+ srcs = ["jit_cuda_graph_mode_pass_registration.cc" ],
335+ visibility = ["//visibility:public" ],
336+ deps = [
337+ ":cuda_graph_mode_passes" ,
338+ "//tensorflow/core:core_cpu_internal" ,
339+ ] + tf_jit_compilation_passes_extra_deps (),
340+ alwayslink = 1 ,
341+ )
342+
330343cc_library (
331344 name = "xla_kernel_creator" ,
332345 srcs = [
@@ -578,6 +591,93 @@ cc_library(
578591 ],
579592)
580593
594+ cc_library (
595+ name = "cuda_graph_mode_passes" ,
596+ srcs = [
597+ "build_cuda_graph_mode_ops_pass.cc" ,
598+ "clone_constants_for_better_clustering.cc" ,
599+ "cluster_scoping_pass.cc" ,
600+ "deadness_analysis.cc" ,
601+ "deadness_analysis_internal.h" ,
602+ "encapsulate_subgraphs_pass.cc" ,
603+ "encapsulate_cuda_graph_mode_subgraphs_pass.cc" ,
604+ "encapsulate_xla_computations_pass.cc" ,
605+ "extract_outside_compilation_pass.cc" ,
606+ "increase_dynamism_for_auto_jit_pass.cc" ,
607+ "introduce_floating_point_jitter_pass.cc" ,
608+ "mark_for_cuda_graph_mode_pass.cc" ,
609+ "mark_for_cuda_graph_mode_pass_test_helper.cc" ,
610+ "partially_decluster_pass.cc" ,
611+ "report_clustering_info_pass.cc" ,
612+ "async_io_conversion_pass.cc" ,
613+ ],
614+ hdrs = [
615+ "build_cuda_graph_mode_ops_pass.h" ,
616+ "clone_constants_for_better_clustering.h" ,
617+ "cluster_scoping_pass.h" ,
618+ "deadness_analysis.h" ,
619+ "encapsulate_subgraphs_pass.h" ,
620+ "encapsulate_cuda_graph_mode_subgraphs_pass.h" ,
621+ "encapsulate_xla_computations_pass.h" ,
622+ "extract_outside_compilation_pass.h" ,
623+ "increase_dynamism_for_auto_jit_pass.h" ,
624+ "introduce_floating_point_jitter_pass.h" ,
625+ "mark_for_compilation_pass.h" ,
626+ "mark_for_compilation_pass_test_helper.h" ,
627+ "mark_for_cuda_graph_mode_pass.h" ,
628+ "mark_for_cuda_graph_mode_pass_test_helper.h" ,
629+ "partially_decluster_pass.h" ,
630+ "report_clustering_info_pass.h" ,
631+ "async_io_conversion_pass.h" ,
632+ ],
633+ deps = [
634+ "compilability_check_util" ,
635+ ":common" ,
636+ ":device_util" ,
637+ ":encapsulate_util" ,
638+ ":flags" ,
639+ ":resource_operation_safety_analysis" ,
640+ ":shape_inference_helpers" ,
641+ ":xla_activity_listener" ,
642+ ":xla_cluster_util" ,
643+ ":cuda_graph_mode_cluster_util" ,
644+ "//tensorflow/cc:cc_ops" ,
645+ "//tensorflow/cc:functional_ops" ,
646+ "//tensorflow/cc:ops" ,
647+ "//tensorflow/cc:scope" ,
648+ "//tensorflow/cc:scope_internal" ,
649+ "//tensorflow/compiler/jit/graphcycles" ,
650+ "//tensorflow/compiler/jit/ops:xla_ops" ,
651+ "//tensorflow/compiler/jit/ops:async_io_ops" ,
652+ "//tensorflow/compiler/tf2xla:resource_operation_table" ,
653+ "//tensorflow/compiler/tf2xla:side_effect_util" ,
654+ "//tensorflow/compiler/tf2xla:tf2xla_util" ,
655+ "//tensorflow/compiler/tf2xla:xla_compiler" ,
656+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops" ,
657+ "//tensorflow/compiler/tf2xla/cc:xla_ops" ,
658+ "//tensorflow/compiler/xla:status_macros" ,
659+ "//tensorflow/compiler/xla:statusor" ,
660+ "//tensorflow/compiler/xla:union_find" ,
661+ "//tensorflow/compiler/xla:util" ,
662+ "//tensorflow/core:core_cpu" ,
663+ "//tensorflow/core:core_cpu_internal" ,
664+ "//tensorflow/core:framework" ,
665+ "//tensorflow/core:framework_bounds_check" ,
666+ "//tensorflow/core:graph" ,
667+ "//tensorflow/core:lib" ,
668+ "//tensorflow/core:lib_internal" ,
669+ "//tensorflow/core:protos_all_cc" ,
670+ "//tensorflow/stream_executor/lib" ,
671+ "@com_google_absl//absl/algorithm:container" ,
672+ "@com_google_absl//absl/container:flat_hash_map" ,
673+ "@com_google_absl//absl/container:flat_hash_set" ,
674+ "@com_google_absl//absl/container:inlined_vector" ,
675+ "@com_google_absl//absl/memory" ,
676+ "@com_google_absl//absl/strings" ,
677+ "@com_google_absl//absl/types:optional" ,
678+ ],
679+ )
680+
581681cc_library (
582682 name = "xla_cluster_util" ,
583683 srcs = ["xla_cluster_util.cc" ],
@@ -603,6 +703,31 @@ cc_library(
603703 ],
604704)
605705
706+ cc_library (
707+ name = "cuda_graph_mode_cluster_util" ,
708+ srcs = ["cuda_graph_mode_cluster_util.cc" ],
709+ hdrs = ["cuda_graph_mode_cluster_util.h" ],
710+ deps = [
711+ ":flags" ,
712+ ":xla_activity_proto_cc" ,
713+ "//tensorflow/compiler/jit/graphcycles" ,
714+ "//tensorflow/compiler/xla:status_macros" ,
715+ "//tensorflow/compiler/xla:statusor" ,
716+ "//tensorflow/core:core_cpu" ,
717+ "//tensorflow/core:framework" ,
718+ "//tensorflow/core:framework_bounds_check" ,
719+ "//tensorflow/core:graph" ,
720+ "//tensorflow/core:protos_all_cc" ,
721+ "//tensorflow/stream_executor/lib" ,
722+ "@com_google_absl//absl/algorithm:container" ,
723+ "@com_google_absl//absl/container:flat_hash_map" ,
724+ "@com_google_absl//absl/container:flat_hash_set" ,
725+ "@com_google_absl//absl/strings" ,
726+ "@com_google_absl//absl/types:optional" ,
727+ "@com_google_absl//absl/types:span" ,
728+ ],
729+ )
730+
606731cc_library (
607732 name = "device_util" ,
608733 srcs = ["device_util.cc" ],
@@ -738,6 +863,63 @@ tf_cc_test(
738863 ],
739864)
740865
866+ tf_cc_test (
867+ name = "cuda_graph_mode_passes_test" ,
868+ size = "small" ,
869+ srcs = [
870+ "mark_for_cuda_graph_mode_pass_test.cc" ,
871+ "encapsulate_cuda_graph_mode_subgraphs_pass_test.cc" ,
872+ "build_cuda_graph_mode_ops_pass_test.cc" ,
873+ ],
874+ # TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
875+ # error.
876+ tags = ["nomsan" ] + tf_cuda_tests_tags (),
877+ deps = [
878+ ":common" ,
879+ ":cuda_graph_mode_passes" ,
880+ ":compilation_passes" ,
881+ ":compilation_passes_test_main" ,
882+ ":encapsulate_util" ,
883+ ":flags" ,
884+ ":node_matchers" ,
885+ ":xla_cluster_util" ,
886+ ":cuda_graph_mode_cluster_util" ,
887+ ":xla_cpu_device" ,
888+ ":xla_gpu_device" ,
889+ "//tensorflow/cc:cc_ops" ,
890+ "//tensorflow/cc:cc_ops_internal" ,
891+ "//tensorflow/cc:function_ops" ,
892+ "//tensorflow/cc:functional_ops" ,
893+ "//tensorflow/cc:ops" ,
894+ "//tensorflow/cc:resource_variable_ops" ,
895+ "//tensorflow/cc:scope" ,
896+ "//tensorflow/cc:sendrecv_ops" ,
897+ "//tensorflow/compiler/jit/kernels:xla_ops" ,
898+ "//tensorflow/compiler/jit/kernels:cuda_graph_mode_ops" ,
899+ "//tensorflow/compiler/tf2xla:rearrange_function_argument" ,
900+ "//tensorflow/compiler/tf2xla:side_effect_util" ,
901+ "//tensorflow/compiler/tf2xla:test_util" ,
902+ "//tensorflow/compiler/tf2xla:xla_compiler" ,
903+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops" ,
904+ "//tensorflow/compiler/tf2xla/cc:xla_ops" ,
905+ "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops" ,
906+ "//tensorflow/compiler/tf2xla/kernels:xla_ops" ,
907+ "//tensorflow/compiler/xla:test" ,
908+ "//tensorflow/core:core_cpu" ,
909+ "//tensorflow/core:framework" ,
910+ "//tensorflow/core:framework_internal" ,
911+ "//tensorflow/core:lib" ,
912+ "//tensorflow/core:protos_all_cc" ,
913+ "//tensorflow/core:session_options" ,
914+ "//tensorflow/core:test" ,
915+ "//tensorflow/core:testlib" ,
916+ "@com_google_absl//absl/container:flat_hash_map" ,
917+ "@com_google_absl//absl/memory" ,
918+ "@com_google_absl//absl/strings" ,
919+ "@com_google_absl//absl/types:span" ,
920+ ],
921+ )
922+
741923tf_cc_test (
742924 name = "xla_cluster_util_test" ,
743925 size = "small" ,
0 commit comments