Skip to content

Commit fdd293a

Browse files
WillFroomGoogle-ML-Automation
authored andcommitted
[XLA:GPU] Unconditionally emit func.func from triton emitter.
PiperOrigin-RevId: 820175948
1 parent 189c85b commit fdd293a

File tree

2 files changed

+10
-39
lines changed

2 files changed

+10
-39
lines changed

xla/backends/gpu/codegen/triton/fusion_emitter.cc

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1902,35 +1902,6 @@ void AppendFuncArgType(absl::Span<const int64_t> dims, Type ir_type,
19021902
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Global)));
19031903
}
19041904

1905-
// Legacy emitter works with tt.func. New emitter works with func.func.
1906-
// TODO(b/393299275): Remove legacy optionality once migration is complete.
1907-
mlir::FunctionOpInterface CreateFuncOp(EmitterLocOpBuilder b,
1908-
absl::string_view fn_name,
1909-
absl::string_view fusion_kind,
1910-
SmallVector<Type>& fn_arg_types) {
1911-
if (fusion_kind != kTritonGemmFusionKind) {
1912-
return b.create<mlir::func::FuncOp>(fn_name,
1913-
b.getFunctionType(fn_arg_types, {}));
1914-
}
1915-
auto func = b.create<ttir::FuncOp>(
1916-
fn_name, b.getFunctionType(fn_arg_types, mlir::TypeRange()));
1917-
auto divisibility_attr = b.getI32IntegerAttr(16);
1918-
for (int i = 0; i < func.getNumArguments(); ++i) {
1919-
func.setArgAttr(i, "tt.divisibility", divisibility_attr);
1920-
}
1921-
return func;
1922-
}
1923-
1924-
// Legacy emitter works with tt.return. New emitter works with func.return.
1925-
// TODO(b/393299275): Remove legacy optionality once migration is complete.
1926-
void EmitReturnOp(EmitterLocOpBuilder b, absl::string_view fusion_kind) {
1927-
if (fusion_kind == kTritonGemmFusionKind) {
1928-
b.create<ttir::ReturnOp>();
1929-
} else {
1930-
b.create<mlir::func::ReturnOp>();
1931-
}
1932-
}
1933-
19341905
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
19351906
absl::string_view fn_name, const HloFusionInstruction* fusion,
19361907
const se::DeviceDescription& device_info,
@@ -2276,8 +2247,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> EmitXTileModule(
22762247
AppendFuncArgType(s.shape.dimensions(), triton_ty, fn_arg_types);
22772248
}
22782249

2279-
mlir::FunctionOpInterface fn =
2280-
CreateFuncOp(b, fn_name, fusion_kind, fn_arg_types);
2250+
mlir::FunctionOpInterface fn = b.create<mlir::func::FuncOp>(
2251+
fn_name, b.getFunctionType(fn_arg_types, {}));
22812252

22822253
fn.addEntryBlock();
22832254
b.setInsertionPointToStart(&fn.front());
@@ -2306,7 +2277,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> EmitXTileModule(
23062277
return Internal("Unsupported fusion kind: %s", fusion_kind);
23072278
}
23082279

2309-
EmitReturnOp(b, fusion_kind);
2280+
b.create<mlir::func::ReturnOp>();
23102281

23112282
return triton_module;
23122283
}

xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ ENTRY e {
264264
})";
265265
TF_EXPECT_OK(
266266
CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm_r", R"(
267-
CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
267+
CHECK: func.func @triton_fn(%[[LHS:.*]]: !tt.ptr<i8>, %[[RHS:.*]]: !tt.ptr<f32>, %[[OUT:.*]]: !tt.ptr<f32>) {
268268
CHECK-DAG: %[[ZERO_KN:.*]] = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
269269
CHECK-DAG: %[[ZERO_MK:.*]] = arith.constant dense<0.000000e+00> : tensor<16x32xf32>
270270
CHECK-DAG: %[[ZERO_MN:.*]] = arith.constant dense<0.000000e+00> : tensor<16x64xf32>
@@ -327,7 +327,7 @@ CHECK: }
327327
CHECK: %[[OUT_PTR:.*]] = tt.make_tensor_ptr %[[OUT]], [%[[C80]], %[[SIZE_M]]], [%[[SIZE_M]], %[[C1]]], [%[[C0]], %[[C0]]] {order = array<i32: 1, 0>} : <tensor<16x64xf32>>
328328
CHECK: %[[OUT_OFFSET:.*]] = tt.advance %[[OUT_PTR]], [%[[TILE_OFFSET_M_LHS]], %[[TILE_OFFSET_N_RHS]]] : <tensor<16x64xf32>>
329329
CHECK: tt.store %[[OUT_OFFSET]], %[[FOR]]#2 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<16x64xf32>>
330-
CHECK: tt.return
330+
CHECK: return
331331
CHECK: }
332332
)"));
333333
}
@@ -356,7 +356,7 @@ ENTRY e {
356356

357357
TF_EXPECT_OK(
358358
CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"(
359-
CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
359+
CHECK: func.func @triton_fn(%[[LHS:.*]]: !tt.ptr<f32>, %[[RHS:.*]]: !tt.ptr<f32>, %[[OUT:.*]]: !tt.ptr<f32>) {
360360
CHECK-DAG: %[[ZERO_KN:.*]] = arith.constant dense<0.000000e+00> : tensor<32x16xf32>
361361
CHECK-DAG: %[[ZERO_MK:.*]] = arith.constant dense<0.000000e+00> : tensor<16x32xf32>
362362
CHECK-DAG: %[[ZERO_MN:.*]] = arith.constant dense<0.000000e+00> : tensor<16x16xf32>
@@ -417,7 +417,7 @@ CHECK: }
417417
CHECK: %[[OUT_PTR:.*]] = tt.make_tensor_ptr %[[OUT]], [%[[SIZE_M]], %[[C1]]], [%[[C1]], %[[C1]]], [%[[C0]], %[[C0]]] {order = array<i32: 1, 0>} : <tensor<16x16xf32>>
418418
CHECK: %[[OUT_OFFSET:.*]] = tt.advance %[[OUT_PTR]], [%[[TILE_OFFSET_M_LHS]], %[[TILE_OFFSET_N_RHS]]] : <tensor<16x16xf32>>
419419
CHECK: tt.store %[[OUT_OFFSET]], %[[FOR]]#2 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x16xf32>>
420-
CHECK: tt.return
420+
CHECK: return
421421
CHECK: }
422422
)"));
423423
}
@@ -491,7 +491,7 @@ ENTRY e {
491491

492492
TF_EXPECT_OK(
493493
CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm", R"(
494-
CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr<f32>
494+
CHECK: func.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr<f32>
495495
CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr<f32>
496496
CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr<f32>
497497
CHECK-DAG: %[[ARG_PTR:.*]] = arith.select %[[CONCAT_COND:.*]], %[[P1]], %[[P2]]
@@ -538,7 +538,7 @@ ENTRY e {
538538

539539
ASSERT_THAT(
540540
CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm", R"(
541-
CHECK: tt.func @triton_fn({{[^,]*}}, %[[DYNAMIC_SLICE_INPUT:[^:]*]]: !tt.ptr<f32> {{[^,]*}}, %[[START_INDEX0_PTR:[^:]*]]: !tt.ptr<i32>
541+
CHECK: func.func @triton_fn({{[^,]*}}, %[[DYNAMIC_SLICE_INPUT:[^:]*]]: !tt.ptr<f32>, %[[START_INDEX0_PTR:[^:]*]]: !tt.ptr<i32>
542542
CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32
543543
CHECK-DAG: %[[C1_i64:.*]] = arith.constant 1 : i64
544544
CHECK-DAG: %[[C2_i64:.*]] = arith.constant 2 : i64
@@ -1230,7 +1230,7 @@ ENTRY e {
12301230
})";
12311231
TF_EXPECT_OK(
12321232
CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm_r", R"(
1233-
CHECK: tt.func @triton_fn
1233+
CHECK: func.func @triton_fn
12341234
CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0>
12351235
CHECK-DAG: %[[FMIN:.*]] = arith.constant dense<-1.280000e+02>
12361236
CHECK-DAG: %[[IMIN:.*]] = arith.constant dense<-128>

0 commit comments

Comments
 (0)