Skip to content

Commit 63f6b9a

Browse files
authored
[Runtime] Support CUDA Graph execution in JIT mode. (#531)
1. Supports using a JIT-based CUDA graph execution of TF operations. 2. The auto-clustering is enabled by default.
1 parent b1036b8 commit 63f6b9a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+8718
-183
lines changed

tensorflow/compiler/jit/BUILD

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ cc_library(
4141
] + if_cuda([
4242
":xla_gpu_device",
4343
":xla_gpu_jit",
44+
":jit_cuda_graph_mode_passes",
45+
"//tensorflow/compiler/jit/kernels:cuda_graph_mode_ops",
4446
]),
4547
alwayslink = 1,
4648
)
@@ -327,6 +329,17 @@ cc_library(
327329
alwayslink = 1,
328330
)
329331

332+
cc_library(
333+
name = "jit_cuda_graph_mode_passes",
334+
srcs = ["jit_cuda_graph_mode_pass_registration.cc"],
335+
visibility = ["//visibility:public"],
336+
deps = [
337+
":cuda_graph_mode_passes",
338+
"//tensorflow/core:core_cpu_internal",
339+
] + tf_jit_compilation_passes_extra_deps(),
340+
alwayslink = 1,
341+
)
342+
330343
cc_library(
331344
name = "xla_kernel_creator",
332345
srcs = [
@@ -578,6 +591,93 @@ cc_library(
578591
],
579592
)
580593

594+
cc_library(
595+
name = "cuda_graph_mode_passes",
596+
srcs = [
597+
"build_cuda_graph_mode_ops_pass.cc",
598+
"clone_constants_for_better_clustering.cc",
599+
"cluster_scoping_pass.cc",
600+
"deadness_analysis.cc",
601+
"deadness_analysis_internal.h",
602+
"encapsulate_subgraphs_pass.cc",
603+
"encapsulate_cuda_graph_mode_subgraphs_pass.cc",
604+
"encapsulate_xla_computations_pass.cc",
605+
"extract_outside_compilation_pass.cc",
606+
"increase_dynamism_for_auto_jit_pass.cc",
607+
"introduce_floating_point_jitter_pass.cc",
608+
"mark_for_cuda_graph_mode_pass.cc",
609+
"mark_for_cuda_graph_mode_pass_test_helper.cc",
610+
"partially_decluster_pass.cc",
611+
"report_clustering_info_pass.cc",
612+
"async_io_conversion_pass.cc",
613+
],
614+
hdrs = [
615+
"build_cuda_graph_mode_ops_pass.h",
616+
"clone_constants_for_better_clustering.h",
617+
"cluster_scoping_pass.h",
618+
"deadness_analysis.h",
619+
"encapsulate_subgraphs_pass.h",
620+
"encapsulate_cuda_graph_mode_subgraphs_pass.h",
621+
"encapsulate_xla_computations_pass.h",
622+
"extract_outside_compilation_pass.h",
623+
"increase_dynamism_for_auto_jit_pass.h",
624+
"introduce_floating_point_jitter_pass.h",
625+
"mark_for_compilation_pass.h",
626+
"mark_for_compilation_pass_test_helper.h",
627+
"mark_for_cuda_graph_mode_pass.h",
628+
"mark_for_cuda_graph_mode_pass_test_helper.h",
629+
"partially_decluster_pass.h",
630+
"report_clustering_info_pass.h",
631+
"async_io_conversion_pass.h",
632+
],
633+
deps = [
634+
"compilability_check_util",
635+
":common",
636+
":device_util",
637+
":encapsulate_util",
638+
":flags",
639+
":resource_operation_safety_analysis",
640+
":shape_inference_helpers",
641+
":xla_activity_listener",
642+
":xla_cluster_util",
643+
":cuda_graph_mode_cluster_util",
644+
"//tensorflow/cc:cc_ops",
645+
"//tensorflow/cc:functional_ops",
646+
"//tensorflow/cc:ops",
647+
"//tensorflow/cc:scope",
648+
"//tensorflow/cc:scope_internal",
649+
"//tensorflow/compiler/jit/graphcycles",
650+
"//tensorflow/compiler/jit/ops:xla_ops",
651+
"//tensorflow/compiler/jit/ops:async_io_ops",
652+
"//tensorflow/compiler/tf2xla:resource_operation_table",
653+
"//tensorflow/compiler/tf2xla:side_effect_util",
654+
"//tensorflow/compiler/tf2xla:tf2xla_util",
655+
"//tensorflow/compiler/tf2xla:xla_compiler",
656+
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
657+
"//tensorflow/compiler/tf2xla/cc:xla_ops",
658+
"//tensorflow/compiler/xla:status_macros",
659+
"//tensorflow/compiler/xla:statusor",
660+
"//tensorflow/compiler/xla:union_find",
661+
"//tensorflow/compiler/xla:util",
662+
"//tensorflow/core:core_cpu",
663+
"//tensorflow/core:core_cpu_internal",
664+
"//tensorflow/core:framework",
665+
"//tensorflow/core:framework_bounds_check",
666+
"//tensorflow/core:graph",
667+
"//tensorflow/core:lib",
668+
"//tensorflow/core:lib_internal",
669+
"//tensorflow/core:protos_all_cc",
670+
"//tensorflow/stream_executor/lib",
671+
"@com_google_absl//absl/algorithm:container",
672+
"@com_google_absl//absl/container:flat_hash_map",
673+
"@com_google_absl//absl/container:flat_hash_set",
674+
"@com_google_absl//absl/container:inlined_vector",
675+
"@com_google_absl//absl/memory",
676+
"@com_google_absl//absl/strings",
677+
"@com_google_absl//absl/types:optional",
678+
],
679+
)
680+
581681
cc_library(
582682
name = "xla_cluster_util",
583683
srcs = ["xla_cluster_util.cc"],
@@ -603,6 +703,31 @@ cc_library(
603703
],
604704
)
605705

706+
cc_library(
707+
name = "cuda_graph_mode_cluster_util",
708+
srcs = ["cuda_graph_mode_cluster_util.cc"],
709+
hdrs = ["cuda_graph_mode_cluster_util.h"],
710+
deps = [
711+
":flags",
712+
":xla_activity_proto_cc",
713+
"//tensorflow/compiler/jit/graphcycles",
714+
"//tensorflow/compiler/xla:status_macros",
715+
"//tensorflow/compiler/xla:statusor",
716+
"//tensorflow/core:core_cpu",
717+
"//tensorflow/core:framework",
718+
"//tensorflow/core:framework_bounds_check",
719+
"//tensorflow/core:graph",
720+
"//tensorflow/core:protos_all_cc",
721+
"//tensorflow/stream_executor/lib",
722+
"@com_google_absl//absl/algorithm:container",
723+
"@com_google_absl//absl/container:flat_hash_map",
724+
"@com_google_absl//absl/container:flat_hash_set",
725+
"@com_google_absl//absl/strings",
726+
"@com_google_absl//absl/types:optional",
727+
"@com_google_absl//absl/types:span",
728+
],
729+
)
730+
606731
cc_library(
607732
name = "device_util",
608733
srcs = ["device_util.cc"],
@@ -738,6 +863,63 @@ tf_cc_test(
738863
],
739864
)
740865

866+
tf_cc_test(
867+
name = "cuda_graph_mode_passes_test",
868+
size = "small",
869+
srcs = [
870+
"mark_for_cuda_graph_mode_pass_test.cc",
871+
"encapsulate_cuda_graph_mode_subgraphs_pass_test.cc",
872+
"build_cuda_graph_mode_ops_pass_test.cc",
873+
],
874+
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
875+
# error.
876+
tags = ["nomsan"] + tf_cuda_tests_tags(),
877+
deps = [
878+
":common",
879+
":cuda_graph_mode_passes",
880+
":compilation_passes",
881+
":compilation_passes_test_main",
882+
":encapsulate_util",
883+
":flags",
884+
":node_matchers",
885+
":xla_cluster_util",
886+
":cuda_graph_mode_cluster_util",
887+
":xla_cpu_device",
888+
":xla_gpu_device",
889+
"//tensorflow/cc:cc_ops",
890+
"//tensorflow/cc:cc_ops_internal",
891+
"//tensorflow/cc:function_ops",
892+
"//tensorflow/cc:functional_ops",
893+
"//tensorflow/cc:ops",
894+
"//tensorflow/cc:resource_variable_ops",
895+
"//tensorflow/cc:scope",
896+
"//tensorflow/cc:sendrecv_ops",
897+
"//tensorflow/compiler/jit/kernels:xla_ops",
898+
"//tensorflow/compiler/jit/kernels:cuda_graph_mode_ops",
899+
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
900+
"//tensorflow/compiler/tf2xla:side_effect_util",
901+
"//tensorflow/compiler/tf2xla:test_util",
902+
"//tensorflow/compiler/tf2xla:xla_compiler",
903+
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
904+
"//tensorflow/compiler/tf2xla/cc:xla_ops",
905+
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
906+
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
907+
"//tensorflow/compiler/xla:test",
908+
"//tensorflow/core:core_cpu",
909+
"//tensorflow/core:framework",
910+
"//tensorflow/core:framework_internal",
911+
"//tensorflow/core:lib",
912+
"//tensorflow/core:protos_all_cc",
913+
"//tensorflow/core:session_options",
914+
"//tensorflow/core:test",
915+
"//tensorflow/core:testlib",
916+
"@com_google_absl//absl/container:flat_hash_map",
917+
"@com_google_absl//absl/memory",
918+
"@com_google_absl//absl/strings",
919+
"@com_google_absl//absl/types:span",
920+
],
921+
)
922+
741923
tf_cc_test(
742924
name = "xla_cluster_util_test",
743925
size = "small",

0 commit comments

Comments
 (0)