Skip to content

Commit c729ff7

Browse files
akuegel authored and Google-ML-Automation committed
Move triton codegen to xla/backends/gpu/codegen/triton
PiperOrigin-RevId: 715389092
1 parent 2940811 commit c729ff7

File tree

76 files changed

+1666
-217
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1666
-217
lines changed

xla/service/gpu/fusions/triton/BUILD renamed to xla/backends/gpu/codegen/triton/BUILD

Lines changed: 64 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -105,18 +105,18 @@ cc_library(
105105
)
106106

107107
cc_library(
108-
name = "triton_fusion_emitter",
108+
name = "fusion_emitter",
109109
srcs = if_gpu_is_configured(
110-
["triton_fusion_emitter.cc"],
111-
["triton_fusion_emitter_stub.cc"],
110+
["fusion_emitter.cc"],
111+
["fusion_emitter_stub.cc"],
112112
),
113-
hdrs = ["triton_fusion_emitter.h"],
113+
hdrs = ["fusion_emitter.h"],
114114
deps = [
115115
":compilation_pipeline",
116116
":emitter_helpers",
117+
":fusion_emitter_legacy_matmul",
117118
":passes",
118-
":triton_fusion_emitter_legacy_matmul",
119-
":triton_support",
119+
":support",
120120
":xla_triton",
121121
":xla_triton_passes",
122122
"//xla:autotuning_proto_cc",
@@ -209,12 +209,12 @@ cc_library(
209209
)
210210

211211
cc_library(
212-
name = "triton_fusion_emitter_legacy_matmul",
212+
name = "fusion_emitter_legacy_matmul",
213213
srcs = if_gpu_is_configured(
214-
["triton_fusion_emitter_legacy_matmul.cc"],
215-
["triton_fusion_emitter_legacy_matmul_stub.cc"],
214+
["fusion_emitter_legacy_matmul.cc"],
215+
["fusion_emitter_legacy_matmul_stub.cc"],
216216
),
217-
hdrs = ["triton_fusion_emitter_legacy_matmul.h"],
217+
hdrs = ["fusion_emitter_legacy_matmul.h"],
218218
deps = [
219219
":emitter_helpers",
220220
":xla_triton",
@@ -271,16 +271,16 @@ cc_library(
271271
)
272272

273273
cc_library(
274-
name = "triton_fusion_emitter_stub_for_testing",
274+
name = "fusion_emitter_stub_for_testing",
275275
srcs = [
276276
"compilation_pipeline_stub.cc",
277-
"triton_fusion_emitter_legacy_matmul_stub.cc",
278-
"triton_fusion_emitter_stub.cc",
277+
"fusion_emitter_legacy_matmul_stub.cc",
278+
"fusion_emitter_stub.cc",
279279
],
280280
hdrs = [
281281
"compilation_pipeline.h",
282-
"triton_fusion_emitter.h",
283-
"triton_fusion_emitter_legacy_matmul.h",
282+
"fusion_emitter.h",
283+
"fusion_emitter_legacy_matmul.h",
284284
],
285285
deps = [
286286
"//xla:autotuning_proto_cc",
@@ -307,10 +307,10 @@ cc_library(
307307
)
308308

309309
xla_cc_test(
310-
name = "triton_fusion_emitter_stub_test",
311-
srcs = ["triton_fusion_emitter_stub_test.cc"],
310+
name = "fusion_emitter_stub_test",
311+
srcs = ["fusion_emitter_stub_test.cc"],
312312
deps = [
313-
":triton_fusion_emitter_stub_for_testing",
313+
":fusion_emitter_stub_for_testing",
314314
"//xla:literal",
315315
"//xla:literal_util",
316316
"//xla/hlo/ir:hlo",
@@ -417,7 +417,7 @@ cc_library(
417417
)
418418

419419
td_library(
420-
name = "xla_triton_td_files",
420+
name = "xla_td_files",
421421
srcs = glob(["*.td"]),
422422
includes = ["."],
423423
deps = [
@@ -441,7 +441,7 @@ gentbl_cc_library(
441441
],
442442
tblgen = "@llvm-project//mlir:mlir-tblgen",
443443
td_file = "xla_triton_dialect.td",
444-
deps = [":xla_triton_td_files"],
444+
deps = [":xla_td_files"],
445445
)
446446

447447
gentbl_cc_library(
@@ -460,7 +460,7 @@ gentbl_cc_library(
460460
tblgen = "@llvm-project//mlir:mlir-tblgen",
461461
td_file = "xla_triton_ops.td",
462462
deps = [
463-
":xla_triton_td_files",
463+
":xla_td_files",
464464
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
465465
"@llvm-project//mlir:OpBaseTdFiles",
466466
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
@@ -490,7 +490,7 @@ gentbl_cc_library(
490490
tblgen = "@llvm-project//mlir:mlir-tblgen",
491491
td_file = "xla_triton_attrs.td",
492492
deps = [
493-
":xla_triton_td_files",
493+
":xla_td_files",
494494
"@triton//:td_files",
495495
],
496496
)
@@ -516,11 +516,11 @@ cc_library(
516516
)
517517

518518
xla_test(
519-
name = "triton_fusion_emitter_deviceless_test",
520-
srcs = ["triton_fusion_emitter_deviceless_test.cc"],
519+
name = "fusion_emitter_deviceless_test",
520+
srcs = ["fusion_emitter_deviceless_test.cc"],
521521
backends = ["gpu"],
522522
deps = [
523-
":triton_fusion_emitter",
523+
":fusion_emitter",
524524
"//xla/hlo/ir:hlo",
525525
"//xla/hlo/testlib:filecheck",
526526
"//xla/service/gpu:gpu_device_info_for_tests",
@@ -539,8 +539,8 @@ xla_test(
539539
)
540540

541541
xla_test(
542-
name = "triton_fusion_emitter_device_legacy_test",
543-
srcs = if_gpu_is_configured(["triton_fusion_emitter_device_legacy_test.cc"]),
542+
name = "fusion_emitter_device_legacy_test",
543+
srcs = if_gpu_is_configured(["fusion_emitter_device_legacy_test.cc"]),
544544
# TODO(b/372714955): Fix the memory leak!
545545
backend_args = if_google(
546546
{
@@ -560,8 +560,8 @@ xla_test(
560560
"no_mac",
561561
],
562562
deps = [
563-
":triton_fusion_emitter",
564-
":triton_test_utils",
563+
":fusion_emitter",
564+
":test_utils",
565565
"//xla:autotuning_proto_cc",
566566
"//xla:error_spec",
567567
"//xla:xla_proto_cc",
@@ -592,8 +592,8 @@ xla_test(
592592
)
593593

594594
xla_test(
595-
name = "triton_fusion_emitter_int4_device_test",
596-
srcs = if_gpu_is_configured(["triton_fusion_emitter_int4_device_test.cc"]),
595+
name = "fusion_emitter_int4_device_test",
596+
srcs = if_gpu_is_configured(["fusion_emitter_int4_device_test.cc"]),
597597
# TODO(b/372714955): Fix the memory leak!
598598
backend_args = if_google(
599599
{
@@ -651,7 +651,7 @@ xla_test(
651651
],
652652
deps = [
653653
":kernel_name_tracer",
654-
":triton_test_utils",
654+
":test_utils",
655655
"//xla:autotuning_proto_cc",
656656
"//xla:error_spec",
657657
"//xla:literal",
@@ -682,8 +682,8 @@ xla_test(
682682
)
683683

684684
xla_test(
685-
name = "triton_fusion_emitter_device_test",
686-
srcs = if_gpu_is_configured(["triton_fusion_emitter_device_test.cc"]),
685+
name = "fusion_emitter_device_test",
686+
srcs = if_gpu_is_configured(["fusion_emitter_device_test.cc"]),
687687
backends = [
688688
"gpu_a100",
689689
"gpu_h100",
@@ -694,8 +694,8 @@ xla_test(
694694
"no_mac",
695695
],
696696
deps = [
697-
":triton_fusion_emitter",
698-
":triton_test_utils",
697+
":fusion_emitter",
698+
":test_utils",
699699
"//xla:autotuning_proto_cc",
700700
"//xla:error_spec",
701701
"//xla:shape_util",
@@ -757,12 +757,12 @@ cc_library(
757757
)
758758

759759
cc_library(
760-
name = "triton_test_utils",
760+
name = "test_utils",
761761
testonly = True,
762-
srcs = ["triton_test_utils.cc"],
763-
hdrs = ["triton_test_utils.h"],
762+
srcs = ["test_utils.cc"],
763+
hdrs = ["test_utils.h"],
764764
deps = [
765-
":triton_fusion_emitter",
765+
":fusion_emitter",
766766
"//xla:shape_util",
767767
"//xla:status_macros",
768768
"//xla/hlo/ir:hlo",
@@ -795,10 +795,10 @@ cc_library(
795795
)
796796

797797
xla_cc_test(
798-
name = "triton_fusion_emitter_mem_utils_test",
799-
srcs = if_cuda_is_configured(["triton_fusion_emitter_mem_utils_test.cc"]),
798+
name = "fusion_emitter_mem_utils_test",
799+
srcs = if_cuda_is_configured(["fusion_emitter_mem_utils_test.cc"]),
800800
deps = [
801-
":triton_fusion_emitter",
801+
":fusion_emitter",
802802
"//xla/hlo/ir:hlo",
803803
"//xla/hlo/utils:hlo_traversal",
804804
"//xla/service/gpu:gpu_device_info_for_tests",
@@ -825,8 +825,8 @@ xla_cc_test(
825825
)
826826

827827
xla_test(
828-
name = "triton_fusion_emitter_large_test",
829-
srcs = if_gpu_is_configured(["triton_fusion_emitter_large_test.cc"]),
828+
name = "fusion_emitter_large_test",
829+
srcs = if_gpu_is_configured(["fusion_emitter_large_test.cc"]),
830830
backends = [
831831
"gpu_a100",
832832
"gpu_h100",
@@ -853,8 +853,8 @@ xla_test(
853853
)
854854

855855
xla_test(
856-
name = "triton_fusion_emitter_parametrized_test",
857-
srcs = if_gpu_is_configured(["triton_fusion_emitter_parametrized_test.cc"]),
856+
name = "fusion_emitter_parametrized_test",
857+
srcs = if_gpu_is_configured(["fusion_emitter_parametrized_test.cc"]),
858858
backends = [
859859
"gpu_a100",
860860
"gpu_h100",
@@ -864,8 +864,8 @@ xla_test(
864864
shard_count = 10,
865865
tags = ["no_mac"],
866866
deps = [
867-
":triton_support",
868-
":triton_test_utils",
867+
":support",
868+
":test_utils",
869869
"//xla:comparison_util",
870870
"//xla:error_spec",
871871
"//xla:xla_data_proto_cc",
@@ -881,14 +881,14 @@ xla_test(
881881
)
882882

883883
cc_library(
884-
name = "triton_support",
884+
name = "support",
885885
srcs = [
886-
"triton_support.cc",
887-
"triton_support_legacy.cc",
886+
"support.cc",
887+
"support_legacy.cc",
888888
],
889889
hdrs = [
890-
"triton_support.h",
891-
"triton_support_legacy.h",
890+
"support.h",
891+
"support_legacy.h",
892892
],
893893
deps = [
894894
"//xla:shape_util",
@@ -911,16 +911,16 @@ cc_library(
911911
)
912912

913913
xla_cc_test(
914-
name = "triton_support_test",
915-
srcs = ["triton_support_test.cc"],
914+
name = "support_test",
915+
srcs = ["support_test.cc"],
916916
shard_count = 25,
917917
# TODO(b/353912594): this test does not need to run on GPU, but it is broken on CPU in OSS.
918918
# Force it to run on GPU temporarily in order to get important OSS coverage.
919919
tags = ["gpu"],
920920
deps = [
921-
":triton_fusion_emitter",
922-
":triton_support",
923-
":triton_test_utils",
921+
":fusion_emitter",
922+
":support",
923+
":test_utils",
924924
"//xla:shape_util",
925925
"//xla:xla_data_proto_cc",
926926
"//xla:xla_proto_cc",
@@ -941,8 +941,8 @@ xla_cc_test(
941941
)
942942

943943
xla_test(
944-
name = "triton_support_legacy_test",
945-
srcs = if_gpu_is_configured(["triton_support_legacy_test.cc"]),
944+
name = "support_legacy_test",
945+
srcs = if_gpu_is_configured(["support_legacy_test.cc"]),
946946
backends = [
947947
"gpu_a100",
948948
"gpu_h100",
@@ -951,10 +951,10 @@ xla_test(
951951
],
952952
tags = ["no_mac"],
953953
deps = [
954+
":fusion_emitter",
954955
":kernel_name_tracer",
955-
":triton_fusion_emitter",
956-
":triton_support",
957-
":triton_test_utils",
956+
":support",
957+
":test_utils",
958958
"//xla:error_spec",
959959
"//xla:shape_util",
960960
"//xla:xla_data_proto_cc",

xla/service/gpu/fusions/triton/compilation_pipeline.h renamed to xla/backends/gpu/codegen/triton/compilation_pipeline.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_
17-
#define XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_
16+
#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_COMPILATION_PIPELINE_H_
17+
#define XLA_BACKENDS_GPU_CODEGEN_TRITON_COMPILATION_PIPELINE_H_
1818

1919
#include <string>
2020

@@ -48,4 +48,4 @@ absl::Status CreateTritonPipeline(
4848
} // namespace gpu
4949
} // namespace xla
5050

51-
#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_
51+
#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_COMPILATION_PIPELINE_H_

xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc renamed to xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ limitations under the License.
2525
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
2626
#include "mlir/Pass/PassManager.h"
2727
#include "mlir/Transforms/Passes.h"
28-
#include "xla/service/gpu/fusions/triton/xla_triton_passes.h"
28+
#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h"
2929
#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h"
3030
#include "xla/service/hlo_module_config.h"
3131
#include "xla/stream_executor/device_description.h"

xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc renamed to xla/backends/gpu/codegen/triton/compilation_pipeline_stub.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ limitations under the License.
1717

1818
#include "absl/status/status.h"
1919
#include "mlir/Pass/PassManager.h"
20-
#include "xla/service/gpu/fusions/triton/compilation_pipeline.h"
20+
#include "xla/backends/gpu/codegen/triton/compilation_pipeline.h"
2121

2222
namespace xla {
2323
namespace gpu {

xla/service/gpu/fusions/triton/dot_algorithms_test.cc renamed to xla/backends/gpu/codegen/triton/dot_algorithms_test.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,8 @@ limitations under the License.
3838
#include "absl/strings/str_replace.h"
3939
#include "absl/strings/string_view.h"
4040
#include "xla/autotuning.pb.h"
41+
#include "xla/backends/gpu/codegen/triton/kernel_name_tracer.h"
42+
#include "xla/backends/gpu/codegen/triton/test_utils.h"
4143
#include "xla/error_spec.h"
4244
#include "xla/hlo/ir/hlo_computation.h"
4345
#include "xla/hlo/ir/hlo_instruction.h"
@@ -47,8 +49,6 @@ limitations under the License.
4749
#include "xla/literal_util.h"
4850
#include "xla/service/dump.h"
4951
#include "xla/service/gpu/backend_configs.pb.h"
50-
#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h"
51-
#include "xla/service/gpu/fusions/triton/triton_test_utils.h"
5252
#include "xla/service/gpu/tests/gpu_codegen_test.h"
5353
#include "xla/service/hlo_module_config.h"
5454
#include "xla/stream_executor/device_description.h"

xla/service/gpu/fusions/triton/emitter_helpers.cc renamed to xla/backends/gpu/codegen/triton/emitter_helpers.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "xla/service/gpu/fusions/triton/emitter_helpers.h"
16+
#include "xla/backends/gpu/codegen/triton/emitter_helpers.h"
1717

1818
#include <cstdint>
1919
#include <variant>

0 commit comments

Comments
 (0)