Skip to content

Commit 4168481

Browse files
authored
fix ffi performance(main) (#382)
1 parent 3144e1f commit 4168481

File tree

6 files changed

+126
-199
lines changed

6 files changed

+126
-199
lines changed

third_party/openxla.patch

Lines changed: 85 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3232,7 +3232,7 @@ index c888aff30..0fd7254d9 100644
32323232
// Extract the memory value returned from atomicCAS and store it as
32333233
// cas_old_output.
32343234
diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
3235-
index 86426953a..620f092c7 100644
3235+
index 86426953a..afd144957 100644
32363236
--- a/xla/service/gpu/ir_emitter_unnested.cc
32373237
+++ b/xla/service/gpu/ir_emitter_unnested.cc
32383238
@@ -1,4 +1,6 @@
@@ -3852,7 +3852,7 @@ index 86426953a..620f092c7 100644
38523852
static absl::StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
38533853
mlir::DictionaryAttr dict) {
38543854
CustomCallThunk::AttributesMap attributes;
3855-
@@ -1314,6 +1337,103 @@ static absl::StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
3855+
@@ -1314,6 +1337,106 @@ static absl::StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
38563856
}
38573857
return attributes;
38583858
}
@@ -3862,9 +3862,17 @@ index 86426953a..620f092c7 100644
38623862
+ // After 0.4.26, ffi support absl::span.
38633863
+ // Below attrs can be refine to absl::span for reducing key-value
38643864
+ CustomCallThunk::AttributesMap attrs;
3865-
+ attrs["backend_config_str"] = instr->raw_backend_config_string();
38663865
+ if (IsCustomCallToDnnConvolution(*instr)) {
3866+
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
3867+
+ instr->backend_config<GpuBackendConfig>());
3868+
+ const CudnnConvBackendConfig& backend_config =
3869+
+ gpu_config.cudnn_conv_backend_config();
38673870
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
3871+
+ attrs["conv_result_scale"] = static_cast<float>(backend_config.conv_result_scale());
3872+
+ attrs["side_input_scale"] = static_cast<float>(backend_config.side_input_scale());
3873+
+ attrs["activation_mode"] = static_cast<int32_t>(backend_config.activation_mode());
3874+
+ attrs["leakyrelu_alpha"] = static_cast<float>(backend_config.leakyrelu_alpha());
3875+
+
38683876
+ const Window& window = instr->window();
38693877
+ const ConvolutionDimensionNumbers& dnums =
38703878
+ instr->convolution_dimension_numbers();
@@ -3930,23 +3938,18 @@ index 86426953a..620f092c7 100644
39303938
+ attrs["filter_dl"] = static_cast<int32_t>(filter_dl);
39313939
+ attrs["output_dl"] = static_cast<int32_t>(output_dl);
39323940
+ } else if (IsLegacyCublasMatmul(*instr) || IsCublasLtMatmul(*instr)) {
3933-
+ const Shape& lhs_shape = instr->operand(0)->shape();
3934-
+ const Shape& rhs_shape = instr->operand(1)->shape();
3935-
+ const Shape& output_shape = instr->shape().IsTuple()
3936-
+ ? instr->shape().tuple_shapes(0)
3937-
+ : instr->shape();
3938-
+ for (int i = 0; i < lhs_shape.layout().minor_to_major().size(); ++i) {
3939-
+ attrs["lhs_minor_to_major_" + std::to_string(i)] =
3940-
+ lhs_shape.layout().minor_to_major()[i];
3941-
+ }
3942-
+ for (int i = 0; i < rhs_shape.layout().minor_to_major().size(); ++i) {
3943-
+ attrs["rhs_minor_to_major_" + std::to_string(i)] =
3944-
+ rhs_shape.layout().minor_to_major()[i];
3945-
+ }
3946-
+ for (int i = 0; i < output_shape.layout().minor_to_major().size(); ++i) {
3947-
+ attrs["output_minor_to_major_" + std::to_string(i)] =
3948-
+ output_shape.layout().minor_to_major()[i];
3949-
+ }
3941+
+ TF_ASSIGN_OR_RETURN(const auto gpu_config,
3942+
+ instr->backend_config<xla::gpu::GpuBackendConfig>());
3943+
+ xla::gpu::GemmBackendConfig config = gpu_config.gemm_backend_config();
3944+
+ xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue();
3945+
+ TF_ASSIGN_OR_RETURN(
3946+
+ auto gemm_config,
3947+
+ GemmConfig::For(static_cast<const HloInstruction*>(instr)));
3948+
+ GemmConfig* gemm_config_ptr = new GemmConfig(gemm_config);
3949+
+ attrs["epilogue"] = static_cast<int32_t>(epilogue);
3950+
+ // SYCL TODO:
3951+
+ // gemm_config may be split into separate parameters and added to attrs later.
3952+
+ attrs["gemm_config_ptr"] = reinterpret_cast<int64_t>(gemm_config_ptr);
39503953
+ } else {
39513954
+ return absl::InternalError("Unknown CustomCall To SYCL FFI Call");
39523955
+ }
@@ -3956,15 +3959,15 @@ index 86426953a..620f092c7 100644
39563959

39573960
absl::Status IrEmitterUnnested::EmitCustomCallThunk(
39583961
const HloCustomCallInstruction* instr) {
3959-
@@ -1433,6 +1553,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3962+
@@ -1433,6 +1556,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
39603963
}
39613964

39623965
auto& backend_config_str = instr->raw_backend_config_string();
39633966
+
39643967
switch (instr->api_version()) {
39653968
case CustomCallApiVersion::API_VERSION_ORIGINAL:
39663969
case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
3967-
@@ -1443,6 +1564,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3970+
@@ -1443,6 +1567,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
39683971
break;
39693972

39703973
case CustomCallApiVersion::API_VERSION_TYPED_FFI:
@@ -3977,7 +3980,7 @@ index 86426953a..620f092c7 100644
39773980
if (!backend_config_str.empty()) {
39783981
mlir::Attribute attr = mlir::parseAttribute(
39793982
backend_config_str, ir_emitter_context_->mlir_context());
3980-
@@ -1455,7 +1582,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3983+
@@ -1455,7 +1585,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
39813984
"dictionary attribute");
39823985
}
39833986
break;
@@ -3986,7 +3989,7 @@ index 86426953a..620f092c7 100644
39863989
default:
39873990
return Internal("Unknown custom-call API version enum value: %d",
39883991
instr->api_version());
3989-
@@ -1496,7 +1623,7 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) {
3992+
@@ -1496,7 +1626,7 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) {
39903993
return absl::OkStatus();
39913994
}
39923995

@@ -3995,7 +3998,7 @@ index 86426953a..620f092c7 100644
39953998

39963999
absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
39974000
const HloInstruction* instr) {
3998-
@@ -1576,7 +1703,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
4001+
@@ -1576,7 +1706,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
39994002
}
40004003
return absl::OkStatus();
40014004
}
@@ -4004,7 +4007,7 @@ index 86426953a..620f092c7 100644
40044007

40054008
absl::Status IrEmitterUnnested::EmitTopKCustomCall(
40064009
const HloCustomCallInstruction* instr) {
4007-
@@ -2602,33 +2729,33 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
4010+
@@ -2602,33 +2732,33 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
40084011
}
40094012

40104013
absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
@@ -4064,7 +4067,7 @@ index 86426953a..620f092c7 100644
40644067

40654068
return absl::OkStatus();
40664069
}
4067-
@@ -2650,33 +2777,33 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk(
4070+
@@ -2650,33 +2780,33 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk(
40684071
}
40694072

40704073
absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
@@ -4124,7 +4127,7 @@ index 86426953a..620f092c7 100644
41244127

41254128
return absl::OkStatus();
41264129
}
4127-
@@ -2798,13 +2925,31 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
4130+
@@ -2798,13 +2928,31 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
41284131
case HloOpcode::kCustomCall: {
41294132
auto* custom_call = Cast<HloCustomCallInstruction>(instr);
41304133
if (IsLegacyCublasMatmul(*instr)) {
@@ -4158,7 +4161,7 @@ index 86426953a..620f092c7 100644
41584161
#if GOOGLE_CUDA
41594162
if (IsCublasLtMatmulF8(*instr)) {
41604163
return EmitCublasLtMatmulThunkF8(custom_call);
4161-
@@ -2815,30 +2960,32 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
4164+
@@ -2815,30 +2963,32 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
41624165
if (IsCustomCallToDnnNorm(*instr)) {
41634166
return EmitNormThunk(custom_call);
41644167
}
@@ -4753,10 +4756,18 @@ index 15e6e6692..db8ebb0e1 100644
47534756
} // namespace xla::gpu
47544757

47554758
diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD
4756-
index 4f298bbe0..39b81b2c7 100644
4759+
index 4f298bbe0..1a45c9a12 100644
47574760
--- a/xla/service/gpu/runtime/BUILD
47584761
+++ b/xla/service/gpu/runtime/BUILD
4759-
@@ -427,6 +427,7 @@ cc_library(
4762+
@@ -381,6 +381,7 @@ cc_library(
4763+
"//xla/service:custom_call_status",
4764+
"//xla/service:custom_call_status_internal",
4765+
"//xla/service:executable",
4766+
+ "//xla/service/gpu:matmul_utils",
4767+
"//xla/service/gpu:thunk",
4768+
"//xla/stream_executor:device_memory",
4769+
"//xla/stream_executor/gpu:gpu_stream_header",
4770+
@@ -427,6 +428,7 @@ cc_library(
47604771
"//xla/service/gpu:thunk",
47614772
"//xla/stream_executor",
47624773
"@com_google_absl//absl/container:flat_hash_map",
@@ -4922,10 +4933,17 @@ index 02aecd464..df9213bae 100644
49224933
absl::flat_hash_map<const stream_executor::Stream*,
49234934
std::unique_ptr<GenericConvRunner>>
49244935
diff --git a/xla/service/gpu/runtime/custom_call_thunk.cc b/xla/service/gpu/runtime/custom_call_thunk.cc
4925-
index 28a7dcebf..4c8727f6b 100644
4936+
index 28a7dcebf..faaa07689 100644
49264937
--- a/xla/service/gpu/runtime/custom_call_thunk.cc
49274938
+++ b/xla/service/gpu/runtime/custom_call_thunk.cc
4928-
@@ -36,7 +36,7 @@ limitations under the License.
4939+
@@ -30,13 +30,14 @@ limitations under the License.
4940+
#include "xla/service/buffer_assignment.h"
4941+
#include "xla/service/custom_call_status.h"
4942+
#include "xla/service/custom_call_status_internal.h"
4943+
+#include "xla/service/gpu/matmul_utils.h"
4944+
#include "xla/service/gpu/thunk.h"
4945+
#include "xla/service/service_executable_run_options.h"
4946+
#include "xla/status.h"
49294947
#include "xla/stream_executor/device_memory.h"
49304948
#include "xla/util.h"
49314949

@@ -4934,7 +4952,26 @@ index 28a7dcebf..4c8727f6b 100644
49344952
#include "xla/stream_executor/gpu/gpu_stream.h"
49354953
#endif
49364954

4937-
@@ -89,7 +89,7 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4955+
@@ -70,6 +71,18 @@ CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler,
4956+
attributes_(std::move(attributes)),
4957+
called_computation_(called_computation) {}
4958+
4959+
+#ifdef TENSORFLOW_USE_SYCL
4960+
+ CustomCallThunk::~CustomCallThunk(){
4961+
+ if(attributes_.find("gemm_config_ptr") != attributes_.end()){
4962+
+ GemmConfig* gemm_config_ptr =
4963+
+ reinterpret_cast<GemmConfig*>(std::get<int64_t>(attributes_["gemm_config_ptr"]));
4964+
+ if(gemm_config_ptr != nullptr){
4965+
+ delete gemm_config_ptr;
4966+
+ }
4967+
+ }
4968+
+ }
4969+
+#endif
4970+
+
4971+
absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4972+
// gpu_stream is CUstream or e.g. the equivalent type in ROCm.
4973+
std::vector<void*> buffers;
4974+
@@ -89,7 +102,7 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
49384975
}
49394976
}
49404977

@@ -4943,7 +4980,7 @@ index 28a7dcebf..4c8727f6b 100644
49434980
auto gpu_stream = se::gpu::AsGpuStreamValue(params.stream);
49444981
XlaCustomCallStatus custom_call_status;
49454982
call_target_(gpu_stream, buffers.data(), opaque_.data(), opaque_.size(),
4946-
@@ -100,11 +100,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4983+
@@ -100,11 +113,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
49474984
} else {
49484985
return absl::OkStatus();
49494986
}
@@ -4957,7 +4994,7 @@ index 28a7dcebf..4c8727f6b 100644
49574994
}
49584995

49594996
absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
4960-
@@ -139,6 +139,10 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
4997+
@@ -139,6 +152,10 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
49614998
// execution context, as apparently it's not easily accessible from Thunk.
49624999
ExecutableRunOptions run_options;
49635000
run_options.set_stream(params.stream);
@@ -4969,7 +5006,7 @@ index 28a7dcebf..4c8727f6b 100644
49695006

49705007
CallOptions options = {&service_run_options, called_computation_};
49715008
diff --git a/xla/service/gpu/runtime/custom_call_thunk.h b/xla/service/gpu/runtime/custom_call_thunk.h
4972-
index 5fa1dce32..e75b61636 100644
5009+
index 5fa1dce32..03a61bd0c 100644
49735010
--- a/xla/service/gpu/runtime/custom_call_thunk.h
49745011
+++ b/xla/service/gpu/runtime/custom_call_thunk.h
49755012
@@ -35,7 +35,7 @@ limitations under the License.
@@ -4996,6 +5033,17 @@ index 5fa1dce32..e75b61636 100644
49965033

49975034
using CustomCallTarget = std::function<void(Stream, void**, const char*,
49985035
size_t, XlaCustomCallStatus*)>;
5036+
@@ -94,6 +94,10 @@ class CustomCallThunk : public Thunk {
5037+
const std::vector<std::optional<Slice>>& results() const { return results_; }
5038+
absl::string_view opaque() const { return opaque_; }
5039+
5040+
+#ifdef TENSORFLOW_USE_SYCL
5041+
+ ~CustomCallThunk();
5042+
+#endif
5043+
+
5044+
private:
5045+
absl::Status ExecuteCustomCall(const ExecuteParams& params);
5046+
absl::Status ExecuteFfiHandler(const ExecuteParams& params);
49995047
diff --git a/xla/service/gpu/runtime/fft_thunk.cc b/xla/service/gpu/runtime/fft_thunk.cc
50005048
index 728c36752..fccde5793 100644
50015049
--- a/xla/service/gpu/runtime/fft_thunk.cc

xla/service/gpu/onednn_gpu_conv_runner.cc

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ int64_t GetVectCSize(FilterLayout layout) {
6666
absl::Status CreateOneDnnPrimitive(
6767
OneDnnConvPrimitive* onednn_primitive, // NOLINT
6868
const ffi::Dictionary& dict,
69-
absl::flat_hash_map<std::string, std::string>& backend_dict,
7069
absl::Span<const ffi::BufferBase> operand_buffers,
7170
ffi::BufferBase result_buffer, se::Stream* stream,
7271
se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind) {
@@ -89,7 +88,7 @@ absl::Status CreateOneDnnPrimitive(
8988
void* bias_data = nullptr;
9089
void* side_input_data = nullptr;
9190

92-
float conv_result_scale = std::stof(backend_dict["conv_result_scale"]);
91+
float conv_result_scale = *dict.get<float>("conv_result_scale");
9392
bool conv_result_scale_one = (fabs(conv_result_scale - 1.0f) < 1e-6);
9493

9594
switch (conv_kind) {
@@ -139,7 +138,7 @@ absl::Status CreateOneDnnPrimitive(
139138
bias_data = const_cast<void*>(operand_buffers[2].data.opaque());
140139
if (operand_buffers.size() >= 4) {
141140
side_input_data = const_cast<void*>(operand_buffers[3].data.opaque());
142-
side_input_scale = std::stof(backend_dict["side_input_scale"]);
141+
side_input_scale = *dict.get<float>("side_input_scale");
143142
side_input_scale_zero = (fabs(side_input_scale - 0.0f) < 1e-6);
144143
}
145144
}
@@ -457,22 +456,30 @@ absl::Status CreateOneDnnPrimitive(
457456
onednn_primitive->bias_memory});
458457
}
459458
if (conv_kind == CudnnConvKind::kForwardActivation) {
460-
if (backend_dict["activation_mode"] == "kNone") {
461-
} else if (backend_dict["activation_mode"] == "kSigmoid") {
462-
po.append_eltwise(dnnl::algorithm::eltwise_logistic, 1, 0);
463-
} else if (backend_dict["activation_mode"] == "kRelu") {
464-
po.append_eltwise(dnnl::algorithm::eltwise_relu, 0, 0);
465-
} else if (backend_dict["activation_mode"] == "kRelu6") {
466-
po.append_eltwise(dnnl::algorithm::eltwise_clip_v2, 0, 6);
467-
} else if (backend_dict["activation_mode"] == "kTanh") {
468-
po.append_eltwise(dnnl::algorithm::eltwise_tanh, 0, 0);
469-
} else if (backend_dict["activation_mode"] == "kElu") {
470-
po.append_eltwise(dnnl::algorithm::eltwise_elu, 1, 0);
471-
} else if (backend_dict["activation_mode"] == "kLeakyRelu") {
472-
float leakyrelu_alpha = std::stof(backend_dict["leakyrelu_alpha"]);
473-
po.append_eltwise(dnnl::algorithm::eltwise_relu, leakyrelu_alpha, 0);
474-
} else {
475-
return Internal("Unsupported Activation mode");
459+
auto activation_mode = static_cast<stream_executor::dnn::ActivationMode>(*dict.get<int32_t>("activation_mode"));
460+
switch (activation_mode) {
461+
case stream_executor::dnn::kSigmoid:
462+
po.append_eltwise(dnnl::algorithm::eltwise_logistic, 1, 0);
463+
break;
464+
case stream_executor::dnn::kRelu:
465+
po.append_eltwise(dnnl::algorithm::eltwise_relu, 0, 0);
466+
break;
467+
case stream_executor::dnn::kRelu6:
468+
po.append_eltwise(dnnl::algorithm::eltwise_clip_v2, 0, 6);
469+
break;
470+
case stream_executor::dnn::kTanh:
471+
po.append_eltwise(dnnl::algorithm::eltwise_tanh, 0, 0);
472+
break;
473+
case stream_executor::dnn::kElu:
474+
po.append_eltwise(dnnl::algorithm::eltwise_elu, 1, 0);
475+
break;
476+
case stream_executor::dnn::kLeakyRelu:
477+
po.append_eltwise(dnnl::algorithm::eltwise_relu, *dict.get<float>("leakyrelu_alpha"), 0);
478+
break;
479+
case stream_executor::dnn::kNone:
480+
break;
481+
default:
482+
return Internal("Unsupported Activation mode");
476483
}
477484
}
478485
post_ops_attr.set_post_ops(po);
@@ -673,12 +680,11 @@ absl::Status CreateOneDnnPrimitive(
673680

674681
absl::StatusOr<OneDnnConvPrimitive> GetOrCreateOneDnnConvPrimitive(
675682
se::Stream* stream, const ffi::Dictionary& dict,
676-
absl::flat_hash_map<std::string, std::string>& backend_dict,
677683
const std::vector<ffi::BufferBase>& operand_se_buffers,
678684
const ffi::BufferBase& result_buffer,
679685
se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind) {
680686
OneDnnConvPrimitive primitive;
681-
auto status = CreateOneDnnPrimitive(&primitive, dict, backend_dict,
687+
auto status = CreateOneDnnPrimitive(&primitive, dict,
682688
absl::MakeSpan(operand_se_buffers),
683689
result_buffer, stream, scratch_allocator,
684690
conv_kind);

xla/service/gpu/onednn_gpu_conv_runner.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ typedef struct OneDnnConvPrimitive {
5353

5454
absl::StatusOr<OneDnnConvPrimitive> GetOrCreateOneDnnConvPrimitive(
5555
se::Stream*, const ffi::Dictionary& dict,
56-
absl::flat_hash_map<std::string, std::string>& backend_dict,
5756
const std::vector<ffi::BufferBase>& operand_se_buffers,
5857
const ffi::BufferBase& result_buffer,
5958
se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind);

0 commit comments

Comments
 (0)