Skip to content

Commit 5815ecd

Browse files
authored
Refine FFI to remove GemmConfig new & delete (#407)
1 parent 038c1b8 commit 5815ecd

File tree

6 files changed

+360
-309
lines changed

6 files changed

+360
-309
lines changed

third_party/openxla.patch

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3012,7 +3012,7 @@ index 7403531a8..cc3b8aadd 100644
30123012
// Extract the memory value returned from atomicCAS and store it as
30133013
// cas_old_output.
30143014
diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
3015-
index 4ba739dba..475cd1217 100644
3015+
index 4ba739dba..a84be445e 100644
30163016
--- a/xla/service/gpu/ir_emitter_unnested.cc
30173017
+++ b/xla/service/gpu/ir_emitter_unnested.cc
30183018
@@ -1,4 +1,6 @@
@@ -3240,7 +3240,7 @@ index 4ba739dba..475cd1217 100644
32403240
absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {
32413241
TF_ASSIGN_OR_RETURN(CholeskyOptions options,
32423242
instr->backend_config<CholeskyOptions>());
3243-
@@ -1301,7 +1333,108 @@ absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {
3243+
@@ -1301,7 +1333,144 @@ absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {
32443244

32453245
return absl::OkStatus();
32463246
}
@@ -3250,19 +3250,21 @@ index 4ba739dba..475cd1217 100644
32503250
+#ifdef TENSORFLOW_USE_SYCL
32513251
+absl::StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
32523252
+ const HloCustomCallInstruction* instr) {
3253-
+ // After 0.4.26, ffi support absl::span.
3254-
+ // Below attrs can be refine to absl::span for reducing key-value
32553253
+ CustomCallThunk::AttributesMap attrs;
32563254
+ if (IsCustomCallToDnnConvolution(*instr)) {
32573255
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
32583256
+ instr->backend_config<GpuBackendConfig>());
32593257
+ const CudnnConvBackendConfig& backend_config =
32603258
+ gpu_config.cudnn_conv_backend_config();
32613259
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
3262-
+ attrs["conv_result_scale"] = static_cast<float>(backend_config.conv_result_scale());
3263-
+ attrs["side_input_scale"] = static_cast<float>(backend_config.side_input_scale());
3264-
+ attrs["activation_mode"] = static_cast<int32_t>(backend_config.activation_mode());
3265-
+ attrs["leakyrelu_alpha"] = static_cast<float>(backend_config.leakyrelu_alpha());
3260+
+ attrs["conv_result_scale"] =
3261+
+ static_cast<float>(backend_config.conv_result_scale());
3262+
+ attrs["side_input_scale"] =
3263+
+ static_cast<float>(backend_config.side_input_scale());
3264+
+ attrs["activation_mode"] =
3265+
+ static_cast<int32_t>(backend_config.activation_mode());
3266+
+ attrs["leakyrelu_alpha"] =
3267+
+ static_cast<float>(backend_config.leakyrelu_alpha());
32663268
+
32673269
+ const Window& window = instr->window();
32683270
+ const ConvolutionDimensionNumbers& dnums =
@@ -3333,14 +3335,48 @@ index 4ba739dba..475cd1217 100644
33333335
+ instr->backend_config<xla::gpu::GpuBackendConfig>());
33343336
+ xla::gpu::GemmBackendConfig config = gpu_config.gemm_backend_config();
33353337
+ xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue();
3336-
+ TF_ASSIGN_OR_RETURN(
3337-
+ auto gemm_config,
3338-
+ GemmConfig::For(static_cast<const HloInstruction*>(instr)));
3339-
+ GemmConfig* gemm_config_ptr = new GemmConfig(gemm_config);
33403338
+ attrs["epilogue"] = static_cast<int32_t>(epilogue);
3341-
+ // SYCL TODO:
3342-
+ // gemm_config may be split into separate parameters and added to attrs later.
3343-
+ attrs["gemm_config_ptr"] = reinterpret_cast<int64_t>(gemm_config_ptr);
3339+
+
3340+
+ TF_ASSIGN_OR_RETURN(
3341+
+ auto gemm_config,
3342+
+ GemmConfig::For(static_cast<const HloInstruction*>(instr)));
3343+
+
3344+
+ attrs["lhs_layout_dtype"] =
3345+
+ static_cast<int32_t>(gemm_config.lhs_layout.dtype);
3346+
+ attrs["lhs_order"] = static_cast<int32_t>(gemm_config.lhs_layout.order);
3347+
+ attrs["lhs_num_cols"] = gemm_config.lhs_layout.num_cols;
3348+
+ attrs["lhs_num_rows"] = gemm_config.lhs_layout.num_rows;
3349+
+ attrs["lhs_batch_stride"] = gemm_config.lhs_layout.batch_stride;
3350+
+ attrs["lhs_leading_dim_stride"] = gemm_config.lhs_layout.leading_dim_stride;
3351+
+
3352+
+ attrs["rhs_layout_dtype"] =
3353+
+ static_cast<int32_t>(gemm_config.rhs_layout.dtype);
3354+
+ attrs["rhs_order"] = static_cast<int32_t>(gemm_config.rhs_layout.order);
3355+
+ attrs["rhs_num_cols"] = gemm_config.rhs_layout.num_cols;
3356+
+ attrs["rhs_num_rows"] = gemm_config.rhs_layout.num_rows;
3357+
+ attrs["rhs_batch_stride"] = gemm_config.rhs_layout.batch_stride;
3358+
+ attrs["rhs_leading_dim_stride"] = gemm_config.rhs_layout.leading_dim_stride;
3359+
+
3360+
+ attrs["output_layout_dtype"] =
3361+
+ static_cast<int32_t>(gemm_config.output_layout.dtype);
3362+
+ attrs["output_order"] =
3363+
+ static_cast<int32_t>(gemm_config.output_layout.order);
3364+
+ attrs["output_num_cols"] = gemm_config.output_layout.num_cols;
3365+
+ attrs["output_num_rows"] = gemm_config.output_layout.num_rows;
3366+
+ attrs["output_batch_stride"] = gemm_config.output_layout.batch_stride;
3367+
+ attrs["output_leading_dim_stride"] =
3368+
+ gemm_config.output_layout.leading_dim_stride;
3369+
+
3370+
+ attrs["batch_size"] =
3371+
+ static_cast<int64_t>(gemm_config.output_layout.batch_size);
3372+
+ attrs["alpha"] = static_cast<float>(gemm_config.alpha.real());
3373+
+ attrs["beta"] = static_cast<float>(gemm_config.beta);
3374+
+
3375+
+ // config.algorithm is less than 0, thus 0 means no algorithm
3376+
+ if (gemm_config.algorithm.has_value()) {
3377+
+ attrs["algorithm"] = static_cast<int64_t>(gemm_config.algorithm.value());
3378+
+ } else
3379+
+ attrs["algorithm"] = static_cast<int64_t>(0);
33443380
+ } else {
33453381
+ return absl::InternalError("Unknown CustomCall To SYCL FFI Call");
33463382
+ }
@@ -3350,7 +3386,7 @@ index 4ba739dba..475cd1217 100644
33503386

33513387
absl::Status IrEmitterUnnested::EmitCustomCallThunk(
33523388
const HloCustomCallInstruction* instr) {
3353-
@@ -1431,6 +1564,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3389+
@@ -1431,6 +1600,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
33543390
break;
33553391

33563392
case CustomCallApiVersion::API_VERSION_TYPED_FFI:
@@ -3363,7 +3399,7 @@ index 4ba739dba..475cd1217 100644
33633399
if (!backend_config_str.empty()) {
33643400
mlir::Attribute attr = mlir::parseAttribute(
33653401
backend_config_str, ir_emitter_context_->mlir_context());
3366-
@@ -1443,7 +1582,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3402+
@@ -1443,7 +1618,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
33673403
"dictionary attribute");
33683404
}
33693405
break;
@@ -3372,7 +3408,7 @@ index 4ba739dba..475cd1217 100644
33723408
default:
33733409
return Internal("Unknown custom-call API version enum value: %d",
33743410
instr->api_version());
3375-
@@ -1484,7 +1623,7 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) {
3411+
@@ -1484,7 +1659,7 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) {
33763412
return absl::OkStatus();
33773413
}
33783414

@@ -3381,7 +3417,7 @@ index 4ba739dba..475cd1217 100644
33813417

33823418
absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
33833419
const HloInstruction* instr) {
3384-
@@ -1564,7 +1703,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
3420+
@@ -1564,7 +1739,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
33853421
}
33863422
return absl::OkStatus();
33873423
}
@@ -3390,15 +3426,15 @@ index 4ba739dba..475cd1217 100644
33903426

33913427
absl::Status IrEmitterUnnested::EmitTopKCustomCall(
33923428
const HloCustomCallInstruction* instr) {
3393-
@@ -2617,6 +2756,7 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
3429+
@@ -2617,6 +2792,7 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
33943430
}
33953431

33963432
absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
33973433
+#if 0
33983434
if (!instr->channel_id().has_value())
33993435
return absl::InternalError("Unknown send instruction channel id");
34003436

3401-
@@ -2669,12 +2809,14 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
3437+
@@ -2669,12 +2845,14 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
34023438
*instr->channel_id(), send_recv_events_,
34033439
ConvertFrontendAttributes(instr->frontend_attributes()),
34043440
DeviceConstraint(instr)));
@@ -3414,7 +3450,7 @@ index 4ba739dba..475cd1217 100644
34143450
if (!instr->channel_id().has_value())
34153451
return absl::InternalError("Unknown send done instruction channel id");
34163452

3417-
@@ -2685,11 +2827,13 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk(
3453+
@@ -2685,11 +2863,13 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk(
34183454
AddThunkToThunkSequence(std::make_unique<SendDoneThunk>(
34193455
Thunk::ThunkInfo::WithProfileAnnotation(instr), *instr->channel_id(),
34203456
send_recv_events_, DeviceConstraint(instr)));
@@ -3429,7 +3465,7 @@ index 4ba739dba..475cd1217 100644
34293465
if (!instr->channel_id().has_value())
34303466
return absl::InternalError("Unknown recv instruction channel id");
34313467
TF_RET_CHECK(instr->shape().IsTuple());
3432-
@@ -2744,11 +2888,13 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
3468+
@@ -2744,11 +2924,13 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
34333469
ConvertFrontendAttributes(instr->frontend_attributes()),
34343470
DeviceConstraint(instr)));
34353471

@@ -3444,7 +3480,7 @@ index 4ba739dba..475cd1217 100644
34443480
if (!instr->channel_id().has_value())
34453481
return absl::InternalError("Unknown recv done instruction channel id");
34463482

3447-
@@ -2759,8 +2905,9 @@ absl::Status IrEmitterUnnested::EmitRecvDoneThunk(
3483+
@@ -2759,8 +2941,9 @@ absl::Status IrEmitterUnnested::EmitRecvDoneThunk(
34483484
AddThunkToThunkSequence(std::make_unique<RecvDoneThunk>(
34493485
Thunk::ThunkInfo::WithProfileAnnotation(instr), *instr->channel_id(),
34503486
send_recv_events_, DeviceConstraint(instr)));
@@ -3455,7 +3491,7 @@ index 4ba739dba..475cd1217 100644
34553491
}
34563492

34573493
absl::Status IrEmitterUnnested::EmitHloInstruction(
3458-
@@ -2871,47 +3018,67 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
3494+
@@ -2871,47 +3054,67 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
34593495
case HloOpcode::kCustomCall: {
34603496
auto* custom_call = Cast<HloCustomCallInstruction>(instr);
34613497
if (IsLegacyCublasMatmul(*instr)) {
@@ -4816,7 +4852,7 @@ index 3f9db4ea2..2aa8d4030 100644
48164852
absl::flat_hash_map<const stream_executor::Stream*,
48174853
std::unique_ptr<GenericConvRunner>>
48184854
diff --git a/xla/service/gpu/runtime/custom_call_thunk.cc b/xla/service/gpu/runtime/custom_call_thunk.cc
4819-
index 0eaf0aaf4..172ed916b 100644
4855+
index 0eaf0aaf4..d9c4b733c 100644
48204856
--- a/xla/service/gpu/runtime/custom_call_thunk.cc
48214857
+++ b/xla/service/gpu/runtime/custom_call_thunk.cc
48224858
@@ -38,6 +38,7 @@ limitations under the License.
@@ -4836,28 +4872,7 @@ index 0eaf0aaf4..172ed916b 100644
48364872
#include "xla/stream_executor/gpu/gpu_stream.h"
48374873
#endif
48384874

4839-
@@ -79,6 +80,20 @@ CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler,
4840-
attributes_(std::move(attributes)),
4841-
called_computation_(called_computation) {}
4842-
4843-
+#ifdef TENSORFLOW_USE_SYCL
4844-
+ CustomCallThunk::~CustomCallThunk(){
4845-
+ if(attributes_.find("gemm_config_ptr") != attributes_.end()){
4846-
+ GemmConfig* gemm_config_ptr =
4847-
+ reinterpret_cast<GemmConfig*>(
4848-
+ std::get<int64_t>(
4849-
+ std::get<std::variant<int32_t, int64_t, float>>(attributes_["gemm_config_ptr"])));
4850-
+ if(gemm_config_ptr != nullptr){
4851-
+ delete gemm_config_ptr;
4852-
+ }
4853-
+ }
4854-
+ }
4855-
+#endif
4856-
+
4857-
absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4858-
// gpu_stream is CUstream or e.g. the equivalent type in ROCm.
4859-
std::vector<void*> buffers;
4860-
@@ -98,7 +113,7 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4875+
@@ -98,7 +99,7 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
48614876
}
48624877
}
48634878

@@ -4866,7 +4881,7 @@ index 0eaf0aaf4..172ed916b 100644
48664881
auto gpu_stream = se::gpu::AsGpuStreamValue(params.stream);
48674882
XlaCustomCallStatus custom_call_status;
48684883
call_target_(gpu_stream, buffers.data(), opaque_.data(), opaque_.size(),
4869-
@@ -109,11 +124,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4884+
@@ -109,11 +110,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
48704885
} else {
48714886
return absl::OkStatus();
48724887
}
@@ -4881,7 +4896,7 @@ index 0eaf0aaf4..172ed916b 100644
48814896

48824897
absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
48834898
diff --git a/xla/service/gpu/runtime/custom_call_thunk.h b/xla/service/gpu/runtime/custom_call_thunk.h
4884-
index 02679d2e0..2bd5a73c6 100644
4899+
index 02679d2e0..1bcb07264 100644
48854900
--- a/xla/service/gpu/runtime/custom_call_thunk.h
48864901
+++ b/xla/service/gpu/runtime/custom_call_thunk.h
48874902
@@ -32,7 +32,7 @@ limitations under the License.
@@ -4908,17 +4923,6 @@ index 02679d2e0..2bd5a73c6 100644
49084923

49094924
using CustomCallTarget = std::function<void(Stream, void**, const char*,
49104925
size_t, XlaCustomCallStatus*)>;
4911-
@@ -91,6 +91,10 @@ class CustomCallThunk : public Thunk {
4912-
const std::vector<std::optional<Slice>>& results() const { return results_; }
4913-
absl::string_view opaque() const { return opaque_; }
4914-
4915-
+#ifdef TENSORFLOW_USE_SYCL
4916-
+ ~CustomCallThunk();
4917-
+#endif
4918-
+
4919-
private:
4920-
absl::Status ExecuteCustomCall(const ExecuteParams& params);
4921-
absl::Status ExecuteFfiHandler(const ExecuteParams& params);
49224926
diff --git a/xla/service/gpu/runtime/fft_thunk.cc b/xla/service/gpu/runtime/fft_thunk.cc
49234927
index 728c36752..fccde5793 100644
49244928
--- a/xla/service/gpu/runtime/fft_thunk.cc

xla/service/gpu/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ xetla_library(
4747
":scratch_allocator",
4848
"//xla/service:onednn_util",
4949
"//xla/service/gpu/xetla/gemm:gemm_kernel",
50+
"@xla//xla/ffi",
51+
"@xla//xla/ffi:ffi_api",
5052
"@xla//xla/service/gpu:matmul_utils",
5153
],
5254
)
@@ -58,6 +60,8 @@ cc_library(
5860
deps = [
5961
":sycl_onednn",
6062
"//xla/stream_executor/sycl:hw_info",
63+
"@xla//xla/ffi",
64+
"@xla//xla/ffi:ffi_api",
6165
"@com_google_absl//absl/algorithm:container",
6266
"@tsl//tsl/platform:errors",
6367
"@tsl//tsl/platform:logging",

0 commit comments

Comments
 (0)